欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

pytorch nn.LSTM()参数详解

程序员文章站 2024-03-24 23:27:04
...
import torch
import torch.nn as nn
from torch.autograd import Variable

rnn = nn.LSTM(10,20,2)  #构建网络模型---输入矩阵特征数input_size、输出矩阵特征数hidden_size、层数num_layers
'''
input of shape (seq_len, batch, input_size): tensor containing the features of the input sequence. 
The input can also be a packed variable length sequence.
input shape(a,b,c)
a:seq_len  -> 序列长度
b:batch
c:input_size   输入特征数目 
'''
inputs = torch.randn(5,3,10)
print(inputs)
'''
tensor containing the initial hidden state for each element in the batch
h_0 of shape:num_layers * num_directions, batch, hidden_size
单向LSTM:num_directions=1   双向LSTM:num_directions=2
'''
h0 = torch.randn(2,3,20)
'''
c_0 of shape (num_layers * num_directions, batch, hidden_size)
'''
c0 = torch.randn(2,3,20)

'''
Outputs: output, (h_n, c_n)
output of shape (seq_len, batch, num_directions * hidden_size)
h_n of shape (num_layers * num_directions, batch, hidden_size)
c_n (num_layers * num_directions, batch, hidden_size) 
'''
output,(hn,cn) = rnn(inputs,(h0,c0))
相关标签: pytorch LSTM