Detailed Explanation of torch.nn.LSTM() Dimensions
程序员文章站
2024-03-24 23:48:52
...
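import torch
import torch.nn as nn

# input_size=10, hidden_size=20, num_layers=2
lstm = nn.LSTM(10, 20, 2)
# input x: (seq_len=5, batch=3, input_size=10)
x = torch.randn(5, 3, 10)
# initial states: (num_layers * num_directions=2, batch=3, hidden_size=20)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = lstm(x, (h0, c0))
# output.shape  torch.Size([5, 3, 20])
# hn.shape      torch.Size([2, 3, 20])
# cn.shape      torch.Size([2, 3, 20])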
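The two results hold different slices of the same computation: output has the top layer's hidden state at every time step, while hn has every layer's hidden state at the last time step. So for a single-direction LSTM, output[-1] should equal hn[-1]. The sketch below checks this using the sizes from the example above; it omits (h0, c0), in which case PyTorch uses zero initial states, purely to keep it short.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
x = torch.randn(5, 3, 10)  # (seq_len, batch, input_size)

# Omitting the (h0, c0) tuple makes nn.LSTM start from zero states.
output, (hn, cn) = lstm(x)

# output: last layer's hidden state at every time step -> (5, 3, 20)
# hn:     every layer's hidden state at the last step  -> (2, 3, 20)
last_step_top_layer = output[-1]  # (batch, hidden_size) = (3, 20)
top_layer_final = hn[-1]          # (batch, hidden_size) = (3, 20)

print(torch.allclose(last_step_top_layer, top_layer_final))  # True
```

This equality only holds for the unidirectional case; with two directions, the backward direction's final state corresponds to time step 0, not the last one.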
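General shapes, assuming the default batch_first=False; num_directions is 2 when the LSTM is created with bidirectional=True, otherwise 1:
lstm = nn.LSTM(input_size, hidden_size, num_layers)
x       (seq_len, batch, input_size)
h0      (num_layers * num_directions, batch, hidden_size)
c0      (num_layers * num_directions, batch, hidden_size)
output  (seq_len, batch, num_directions * hidden_size)
hn      (num_layers * num_directions, batch, hidden_size)
cn      (num_layers * num_directions, batch, hidden_size)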
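To see where num_directions enters the table, the sketch below computes the expected shapes from the table and compares them with what nn.LSTM returns, once unidirectional and once with bidirectional=True. The helper expected_shapes is a hypothetical name written for this illustration, not part of PyTorch, and the default batch_first=False layout is assumed.

```python
import torch
import torch.nn as nn

def expected_shapes(seq_len, batch, hidden_size, num_layers, bidirectional):
    """Hypothetical helper: shapes from the table above (batch_first=False)."""
    num_directions = 2 if bidirectional else 1
    output = (seq_len, batch, num_directions * hidden_size)
    state = (num_layers * num_directions, batch, hidden_size)  # hn and cn
    return output, state, state

seq_len, batch, input_size, hidden_size, num_layers = 5, 3, 10, 20, 2

for bidirectional in (False, True):
    lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional)
    x = torch.randn(seq_len, batch, input_size)
    output, (hn, cn) = lstm(x)  # zero initial states
    exp_out, exp_h, exp_c = expected_shapes(seq_len, batch, hidden_size,
                                            num_layers, bidirectional)
    assert output.shape == exp_out and hn.shape == exp_h and cn.shape == exp_c
    print(bidirectional, tuple(output.shape), tuple(hn.shape))
# False (5, 3, 20) (2, 3, 20)
# True (5, 3, 40) (4, 3, 20)
```

In the bidirectional case the forward and backward outputs are concatenated along the last dimension of output (hence 40), while hn and cn keep hidden_size and instead double their first dimension.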