PyTorch nn.LSTM() parameters explained
程序员文章站
2024-03-24 23:27:04
...
import torch
import torch.nn as nn
rnn = nn.LSTM(10, 20, 2)  # build the model: input_size=10 (features per input step), hidden_size=20 (features in the hidden state), num_layers=2 (stacked LSTM layers)
'''
input of shape (seq_len, batch, input_size): tensor containing the features of the input sequence.
The input can also be a packed variable length sequence.
input shape (a, b, c), with the default batch_first=False:
a: seq_len -> sequence length
b: batch -> batch size
c: input_size -> number of input features
'''
inputs = torch.randn(5, 3, 10)  # seq_len=5, batch=3, input_size=10
print(inputs)
'''
tensor containing the initial hidden state for each element in the batch
h_0 of shape (num_layers * num_directions, batch, hidden_size)
unidirectional LSTM: num_directions=1; bidirectional LSTM: num_directions=2
'''
h0 = torch.randn(2, 3, 20)  # num_layers * num_directions = 2 * 1, batch=3, hidden_size=20
'''
c_0 of shape (num_layers * num_directions, batch, hidden_size)
'''
c0 = torch.randn(2, 3, 20)  # same shape as h0
'''
Outputs: output, (h_n, c_n)
output of shape (seq_len, batch, num_directions * hidden_size)
h_n of shape (num_layers * num_directions, batch, hidden_size)
c_n of shape (num_layers * num_directions, batch, hidden_size)
'''
output, (hn, cn) = rnn(inputs, (h0, c0))  # output: (5, 3, 20); hn and cn: (2, 3, 20)
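A quick way to confirm the shapes listed in the docstrings is to print them. The sketch below reruns the same example (the fixed seed is an addition, only for reproducibility) and also checks two behaviours described in the PyTorch documentation: for a unidirectional LSTM the last time step of `output` equals the last layer's entry in `h_n`, and omitting `(h0, c0)` gives the same result as passing zero tensors.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # fixed seed so repeated runs give the same numbers

# Same configuration as above: input_size=10, hidden_size=20, num_layers=2
rnn = nn.LSTM(10, 20, 2)
inputs = torch.randn(5, 3, 10)   # (seq_len, batch, input_size)
h0 = torch.randn(2, 3, 20)       # (num_layers * num_directions, batch, hidden_size)
c0 = torch.randn(2, 3, 20)

output, (hn, cn) = rnn(inputs, (h0, c0))

print(output.shape)  # torch.Size([5, 3, 20])
print(hn.shape)      # torch.Size([2, 3, 20])
print(cn.shape)      # torch.Size([2, 3, 20])

# output holds the LAST layer's hidden state at every time step, so its final
# time step matches the last layer's slice of h_n (unidirectional case only).
print(torch.allclose(output[-1], hn[-1]))  # True

# When (h0, c0) is not passed, PyTorch starts from zero states.
zeros = (torch.zeros(2, 3, 20), torch.zeros(2, 3, 20))
out_default, _ = rnn(inputs)
out_zeros, _ = rnn(inputs, zeros)
print(torch.allclose(out_default, out_zeros))  # True
```

In other words, `output` covers every time step but only the top layer, while `h_n`/`c_n` cover every layer but only the final time step.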
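The num_directions factor only matters once `bidirectional=True` is set. The sketch below is the same two-layer model made bidirectional (the `bidirectional=True` setting is the only change); it shows that the initial states need `num_layers * 2` slices and that `output` concatenates the forward and backward features, so its last dimension doubles. The reshapes follow the layouts given in the PyTorch documentation: `(seq_len, batch, num_directions, hidden_size)` for `output` and `(num_layers, num_directions, batch, hidden_size)` for `h_n`.

```python
import torch
import torch.nn as nn

# Bidirectional variant: num_directions = 2
birnn = nn.LSTM(input_size=10, hidden_size=20, num_layers=2, bidirectional=True)
inputs = torch.randn(5, 3, 10)   # (seq_len, batch, input_size)
h0 = torch.randn(2 * 2, 3, 20)   # (num_layers * num_directions, batch, hidden_size)
c0 = torch.randn(2 * 2, 3, 20)

output, (hn, cn) = birnn(inputs, (h0, c0))
print(output.shape)  # torch.Size([5, 3, 40]): forward and backward features side by side
print(hn.shape)      # torch.Size([4, 3, 20])

# Separate the two directions
out_dirs = output.reshape(5, 3, 2, 20)  # (seq_len, batch, num_directions, hidden_size)
h_dirs = hn.reshape(2, 2, 3, 20)        # (num_layers, num_directions, batch, hidden_size)

# The forward pass of the last layer finishes at the last time step,
# the backward pass finishes at the first time step.
print(torch.allclose(out_dirs[-1, :, 0], h_dirs[-1, 0]))  # True
print(torch.allclose(out_dirs[0, :, 1], h_dirs[-1, 1]))   # True
```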
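The docstring notes that the input "can also be a packed variable length sequence." A minimal sketch of that case, assuming a padded batch whose true lengths are 5, 3 and 2 (these lengths are made up for illustration), uses `torch.nn.utils.rnn.pack_padded_sequence` and `pad_packed_sequence`. With packed input, `h_n` holds each sequence's state at its own last valid step rather than at the padded final step.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

rnn = nn.LSTM(10, 20, 2)
inputs = torch.randn(5, 3, 10)      # padded batch: (max seq_len, batch, input_size)
lengths = torch.tensor([5, 3, 2])   # hypothetical true lengths, sorted in descending order

packed = pack_padded_sequence(inputs, lengths)  # batch_first=False by default
packed_out, (hn, cn) = rnn(packed)              # initial states default to zeros

# Convert back to a padded tensor; positions past each length are filled with 0
output, out_lengths = pad_packed_sequence(packed_out)
print(output.shape)  # torch.Size([5, 3, 20])
print(out_lengths)   # tensor([5, 3, 2])
print(hn.shape)      # torch.Size([2, 3, 20])

# h_n[-1, b] equals the top-layer output at step lengths[b] - 1, not at step 4
for b, n in enumerate(lengths.tolist()):
    assert torch.allclose(hn[-1, b], output[n - 1, b])
```

Packing keeps the padding steps from being fed through the LSTM, which is why `h_n` is not contaminated by the zeros at the end of the shorter sequences.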
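All shapes above assume the default `batch_first=False`. As a short sketch of the alternative, setting `batch_first=True` puts the batch dimension first in `input` and `output` only; according to the PyTorch documentation, `h_0`, `c_0`, `h_n` and `c_n` keep the `(num_layers * num_directions, batch, hidden_size)` layout.

```python
import torch
import torch.nn as nn

# batch_first=True swaps the first two dimensions of input and output only
rnn = nn.LSTM(10, 20, 2, batch_first=True)
inputs = torch.randn(3, 5, 10)   # (batch, seq_len, input_size)
h0 = torch.zeros(2, 3, 20)       # still (num_layers * num_directions, batch, hidden_size)
c0 = torch.zeros(2, 3, 20)

output, (hn, cn) = rnn(inputs, (h0, c0))
print(output.shape)  # torch.Size([3, 5, 20]): (batch, seq_len, hidden_size)
print(hn.shape)      # torch.Size([2, 3, 20]): unaffected by batch_first
```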