欢迎您访问程序员文章站！本站旨在为大家提供、分享程序员计算机编程知识！
您现在的位置是: 首页

nn.LSTM 层输出的 out 与 h_n 的关系

程序员文章站 2024-03-24 23:31:52
...
import torch
import torch.nn as nn

单层双向 LSTM

# Single-layer bidirectional LSTM.
# With batch_first=True, `out` has shape (batch, seq_len, 2 * hidden_size):
# the forward direction fills features [:hidden_size] and the backward
# direction fills [hidden_size:]. batch_first does NOT apply to the returned
# states: h_n / c_n are (num_layers * num_directions, batch, hidden_size).
lstm = nn.LSTM(input_size=100, hidden_size=200, bidirectional=True, batch_first=True)
a = torch.randn(32, 512, 100)
# Inference-only demo: no backward pass, so skip building an autograd graph.
with torch.no_grad():
    out, (h, c) = lstm(a)
hidden = lstm.hidden_size
print(out.shape)  # torch.Size([32, 512, 400])
print(h.shape)    # torch.Size([2, 32, 200]) -- one row per direction, each hidden_size wide
# `==` would print an element-wise bool tensor; allclose reduces it to a single
# bool and checks every sample in the batch, not just sample 0.
# Forward direction finishes at the LAST time step -> h_n[0].
print(torch.allclose(out[:, -1, :hidden], h[0]))  # True
# Backward direction reads the sequence in reverse, so it finishes at
# time step 0 -> h_n[1].
print(torch.allclose(out[:, 0, hidden:], h[1]))   # True

多层（3 层）双向 LSTM：out 只包含最后一层的输出，因此它对应 h_n 的最后两行（最后一层的前向与后向）。

# Three-layer bidirectional LSTM.
# `out` holds only the TOP layer's outputs: (batch, seq_len, 2 * hidden_size).
# h_n is (num_layers * num_directions, batch, hidden_size); viewing it as
# (num_layers, num_directions, batch, hidden_size) separates layers from
# directions, so h[-2] is the top layer's forward state and h[-1] its
# backward state.
lstm = nn.LSTM(input_size=100, hidden_size=200, num_layers=3, bidirectional=True, batch_first=True)
a = torch.randn(32, 512, 100)
# Inference-only demo: no backward pass, so skip building an autograd graph.
with torch.no_grad():
    out, (h, c) = lstm(a)
hidden = lstm.hidden_size
print(out.shape)  # torch.Size([32, 512, 400])
print(h.shape)    # torch.Size([6, 32, 200])
# allclose yields a single bool (`==` would print an element-wise tensor)
# and checks the whole batch.
# Top-layer forward direction ends at the last time step.
print(torch.allclose(out[:, -1, :hidden], h[-2]))  # True
# Top-layer backward direction ends at time step 0.
print(torch.allclose(out[:, 0, hidden:], h[-1]))   # True