欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

pytorch 中的torch.nn.LSTM函数

程序员文章站 2024-03-24 23:10:16
...

LSTM是RNN的一种变体
主要包括以下几个参数:
input_size:输入的input中的参数维度,即文本中的embedding_dim
hidden_size:隐藏层的维度
num_layers:LSTM的层数,一般为2-3层,默认为1
bias:是否使用偏置向,默认为True
batch_first:是否输入的input第一个为batch_size,pytorch默认False,即输入的input的三维张量是seq_len放在第一个
dropout:是否丢弃部分神经元,默认为0
bidirectional:是否使用双向LSTM ,默认False

输入:inputs,(h0,c0)
其中inputst是一个三维张量
主要包括[batch_size,seq_len,input_size]
h0是0时刻的隐层,默认为全0
c0是0时刻的cell状态,默认为全0
h0,c0的维度都为:[batch_size,num_layers*num_directions,hidden_size]

输出:outputs,(hn,cn)
output的维度[batch_size,seq_len,num_directions*hidden_size]
hn和cn是第n时刻的隐层和cell状态,维度和h0,c0相同。

下面是代码示例:

Talk is cheap.Show me the code.

input:
假设输入是[64,512,100]
LSTM = nn.LSTM(100,128,batch_first=True)
x1 = torch.randn([64,512,100)
output,(hn,cn) = LSTM(x1)

output.shape的shape[batch,seq_len,num_directions*hidden_size])
[64, 512, 128]
hn,cn的维度均为[num_layers * num_directions,batch,hidden_size]
[1,64,128]

如果是LSTM = nn.LSTM(100,128,batch_first=True,directional=True)
LSTM = nn.LSTM(100,128,batch_first=True)
x1 = torch.randn([64,512,100)
output,(hn,cn) = LSTM(x1)

那么output的维度将变成[64,512,256]
hn,cn的维度会变成[2,64,128]

相关标签: 算法