欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

class torch.nn.LSTM( args, * kwargs)[source]详细解读代码

程序员文章站 2022-03-16 17:17:58
...

小白代码解读

import torch.nn as nn
import torch
from torch.autograd import Variable
lstm =  nn.LSTM(10,20,2) # (1)输入的特征维度10列 (2)隐状态的特征维度20列 (3)num_layers = 2层
# print(lstm)
# print("***************************************************")
# 输入
input = Variable(torch.randn(5,3,10)) # 5行矩阵 每个矩阵是3行10列;10列是根据LSTM输入规定10列
#print(input)
# print("++++++++++++++++++++++++++++++++++++++++++++++++++")
# 保存着batch中每个元素的初始化隐状态的Tensor
h0 = Variable(torch.randn(2,3,20)) # 2行矩阵,每个矩阵是3行20列;2是LSTM中2层规定,20列是LSTM输入规定20列
# 保存着batch中每个元素的初始化细胞状态的Tensor
c0 = Variable(torch.randn(2,3,20))# 2行矩阵,每个矩阵是3行20列;2是LSTM中2层规定,20列是LSTM输入规定20列
# 输出
output , hn = lstm(input,(h0,c0))
# print(output)
# print("--------------------------------------------------")
print(hn)
相关标签: pytorch pytorch