PyTorch nn.LSTM() parameters explained

程序员文章站 2024-03-24 23:27:04
import torch
import torch.nn as nn

rnn = nn.LSTM(10, 20, 2)  # build the model: input_size=10 (input features), hidden_size=20 (output features), num_layers=2 (stacked layers)
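input of shape (seq_len, batch, input_size): tensor containing the features of the input sequence (this layout applies with the default batch_first=False).
The input can also be a packed variable-length sequence.
input shape (a, b, c)
a: seq_len -> sequence length
b: batch -> batch size
c: input_size -> number of input features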
inputs = torch.randn(5,3,10)
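Since the input can also be a packed variable-length sequence, here is a minimal sketch of that path using `pack_padded_sequence` and `pad_packed_sequence` from `torch.nn.utils.rnn`. The lengths `[5, 3, 2]` and the names `padded` and `lengths` are assumptions made for illustration; they do not come from the original example.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
rnn = nn.LSTM(10, 20, 2)

# Padded batch of shape (seq_len=5, batch=3, input_size=10)
padded = torch.randn(5, 3, 10)
# Assumed true lengths of the 3 sequences, sorted in descending order
# (required by the default enforce_sorted=True)
lengths = torch.tensor([5, 3, 2])

packed = pack_padded_sequence(padded, lengths)
# When (h0, c0) is omitted, both default to zeros
packed_out, (hn, cn) = rnn(packed)

# Unpack back to a padded tensor; positions past each length are zero-filled
output, out_lengths = pad_packed_sequence(packed_out)
print(output.shape)   # torch.Size([5, 3, 20])
print(out_lengths)    # tensor([5, 3, 2])
print(hn.shape)       # torch.Size([2, 3, 20])
```

With packed input, the LSTM skips the padded time steps, so `hn` holds each sequence's state at its own last valid step rather than at step 5.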
h_0: tensor containing the initial hidden state for each element in the batch
h_0 of shape (num_layers * num_directions, batch, hidden_size)
Unidirectional LSTM: num_directions=1; bidirectional LSTM: num_directions=2
h0 = torch.randn(2,3,20)
c_0 of shape (num_layers * num_directions, batch, hidden_size): tensor containing the initial cell state for each element in the batch
c0 = torch.randn(2,3,20)
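To make the `num_layers * num_directions` rule concrete, the sketch below builds a bidirectional variant of the same model. The name `bi_rnn` and the helper variables are assumptions for illustration only.

```python
import torch
import torch.nn as nn

num_layers, hidden_size, batch = 2, 20, 3
num_directions = 2  # bidirectional=True doubles the direction count

bi_rnn = nn.LSTM(10, hidden_size, num_layers, bidirectional=True)

inputs = torch.randn(5, batch, 10)
# The first dimension of the states is num_layers * num_directions = 4
h0 = torch.randn(num_layers * num_directions, batch, hidden_size)
c0 = torch.randn(num_layers * num_directions, batch, hidden_size)

output, (hn, cn) = bi_rnn(inputs, (h0, c0))
print(output.shape)  # torch.Size([5, 3, 40]): both directions concatenated
print(hn.shape)      # torch.Size([4, 3, 20])
print(cn.shape)      # torch.Size([4, 3, 20])
```

Passing states whose first dimension is 2 instead of 4 to this bidirectional model raises a shape error, which is a common mistake when switching `bidirectional` on.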

Outputs: output, (h_n, c_n)
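output of shape (seq_len, batch, num_directions * hidden_size): output features h_t from the last layer, for every time step t
h_n of shape (num_layers * num_directions, batch, hidden_size): final hidden state of each layer and direction
c_n of shape (num_layers * num_directions, batch, hidden_size): final cell state of each layer and direction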
output,(hn,cn) = rnn(inputs,(h0,c0))
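As a quick check of the shapes listed above, the following sketch repeats the same call with the same sizes and prints each result. It also illustrates one relationship that holds for a unidirectional LSTM: the last time step of `output` matches the last layer's entry in `h_n`. The variable `same` is an assumption added for the check.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.LSTM(10, 20, 2)        # input_size=10, hidden_size=20, num_layers=2
inputs = torch.randn(5, 3, 10)  # (seq_len, batch, input_size)
h0 = torch.randn(2, 3, 20)      # (num_layers * num_directions, batch, hidden_size)
c0 = torch.randn(2, 3, 20)

output, (hn, cn) = rnn(inputs, (h0, c0))
print(output.shape)  # torch.Size([5, 3, 20])
print(hn.shape)      # torch.Size([2, 3, 20])
print(cn.shape)      # torch.Size([2, 3, 20])

# output[-1] is the top layer's hidden state at the final time step,
# which is what hn[-1] stores for a unidirectional LSTM
same = torch.allclose(output[-1], hn[-1])
print(same)  # True
```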