nn.LSTM层出来的out和hn的关系
程序员文章站
2024-03-24 23:31:52
...
import torch
import torch.nn as nn
单层 LSTM(双向)
# Single-layer bidirectional LSTM: how `out` relates to the final hidden state `h`.
lstm = nn.LSTM(input_size=100, hidden_size=200, bidirectional=True, batch_first=True)
a = torch.randn(32, 512, 100)  # (batch, seq_len, input_size) because batch_first=True
# Only shapes and values are inspected, so skip building the autograd graph.
with torch.no_grad():
    out, (h, c) = lstm(a)
print(out.shape)  # 32, 512, 400 -> (batch, seq_len, 2 * hidden_size)
# batch_first does not apply to h/c: they stay (num_directions, batch, hidden_size).
print(h.shape)  # 2, 32, 200
# A bare `==` would print a 200-element boolean tensor; torch.allclose collapses
# the comparison to the single bool the demo is meant to show.
# The forward direction finishes at the LAST time step -> first half of out[:, -1].
forward_matches = torch.allclose(out[0, -1, :200], h[0, 0, :])
print(forward_matches)  # True
# The backward direction finishes at the FIRST time step -> second half of out[:, 0].
backward_matches = torch.allclose(out[0, 0, 200:], h[1, 0, :])
print(backward_matches)  # True
多层 LSTM(3 层,双向)
# Stacked (3-layer) bidirectional LSTM: `out` only exposes the LAST layer, while
# `h` holds the final hidden state of every layer and direction.
lstm = nn.LSTM(input_size=100, hidden_size=200, num_layers=3, bidirectional=True, batch_first=True)
a = torch.randn(32, 512, 100)  # (batch, seq_len, input_size) because batch_first=True
# Only shapes and values are inspected, so skip building the autograd graph.
with torch.no_grad():
    out, (h, c) = lstm(a)
print(out.shape)  # 32, 512, 400 -> (batch, seq_len, 2 * hidden_size)
# First dim of h is num_layers * num_directions = 3 * 2.
print(h.shape)  # 6, 32, 200
# A bare `==` would print a 200-element boolean tensor; torch.allclose collapses
# the comparison to the single bool the demo is meant to show.
# h[-2] is the last layer's forward direction, which finishes at the LAST time step.
forward_matches = torch.allclose(out[0, -1, :200], h[-2, 0, :])
print(forward_matches)  # True
# h[-1] is the last layer's backward direction, which finishes at the FIRST time step.
backward_matches = torch.allclose(out[0, 0, 200:], h[-1, 0, :])
print(backward_matches)  # True