nn.LSTM quick notes

程序员文章站 2024-03-24 23:44:10
...

Adapted from: https://blog.csdn.net/m0_45478865/article/details/104455978

import torch
import torch.nn as nn

# nn.LSTM(input_size, hidden_size, num_layers, ...)
lstm = nn.LSTM(768, 768, 1, bidirectional=True, batch_first=True)
# With batch_first=True the input is (batch_size, seq_len, input_size)
x = torch.randn(2, 512, 768)
result = lstm(x)
len(result)  # 2: the LSTM returns (output, (hn, cn))
output, (hn, cn) = result
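Setting batch_first=True only changes the layout of the input and output tensors; hn and cn keep the (num_layers * num_directions, batch_size, hidden_size) layout either way. Below is a minimal sketch that checks this by giving a batch_first=True LSTM and a batch_first=False LSTM identical weights. The small sizes (batch 2, sequence length 5, input 4, hidden 3) are hypothetical, chosen only so the check runs quickly; the 768/512 configuration above behaves the same way.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical small sizes for a quick check
batch_size, seq_len, input_size, hidden_size = 2, 5, 4, 3

lstm_bf = nn.LSTM(input_size, hidden_size, 1, bidirectional=True, batch_first=True)
lstm_sf = nn.LSTM(input_size, hidden_size, 1, bidirectional=True, batch_first=False)
# Copy the weights so the two modules differ only in tensor layout
lstm_sf.load_state_dict(lstm_bf.state_dict())

x = torch.randn(batch_size, seq_len, input_size)  # (batch, seq, feature)

out_bf, (hn_bf, cn_bf) = lstm_bf(x)
out_sf, (hn_sf, cn_sf) = lstm_sf(x.transpose(0, 1))  # (seq, batch, feature)

print(out_bf.shape)  # torch.Size([2, 5, 6])
print(out_sf.shape)  # torch.Size([5, 2, 6])
print(hn_bf.shape, hn_sf.shape)  # both torch.Size([2, 2, 3])
```

Transposing out_sf back to batch-first gives the same values as out_bf, and the hidden and cell states match directly without any transpose.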


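output.shape  # torch.Size([2, 512, 1536]) = (batch_size, seq_len, 2 * hidden_size): both directions concatenated
# hn: final hidden state per direction; hn[0] = forward (last time step), hn[1] = backward (first time step)
hn.shape  # torch.Size([2, 2, 768]) = (num_layers * 2, batch_size, hidden_size); batch_first does not apply here
cn.shape  # torch.Size([2, 2, 768]): final cell state, same layout as hn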
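In a bidirectional LSTM, the last dimension of output is the forward features followed by the backward features. The forward pass ends at the last time step, but the backward pass ends at the first one. The minimal sketch below, again using hypothetical small sizes, shows where each slice of hn appears inside output for a single-layer LSTM.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical small sizes, single layer, bidirectional, batch_first
batch_size, seq_len, input_size, hidden_size = 2, 5, 4, 3
lstm = nn.LSTM(input_size, hidden_size, 1, bidirectional=True, batch_first=True)
x = torch.randn(batch_size, seq_len, input_size)
output, (hn, cn) = lstm(x)

# Split the last dimension into [forward | backward] features
forward_out = output[:, :, :hidden_size]
backward_out = output[:, :, hidden_size:]

# hn[0]: forward direction, which finishes at the LAST time step
print(torch.allclose(hn[0], forward_out[:, -1, :]))  # True
# hn[1]: backward direction, which finishes at the FIRST time step
print(torch.allclose(hn[1], backward_out[:, 0, :]))  # True
```

Because of this, output[:, -1, :] is in general not the same as concatenating hn[0] and hn[1]. When a sentence-level vector is built from both directions, taking the two slices of hn is the usual choice.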