class torch.nn.LSTM( args, * kwargs)[source]详细解读代码
程序员文章站
2022-03-16 17:17:58
...
小白代码解读
import torch.nn as nn
import torch
from torch.autograd import Variable
lstm = nn.LSTM(10,20,2) # (1)输入的特征维度10列 (2)隐状态的特征维度20列 (3)num_layers = 2层
# print(lstm)
# print("***************************************************")
# 输入
input = Variable(torch.randn(5,3,10)) # 5行矩阵 每个矩阵是3行10列;10列是根据LSTM输入规定10列
#print(input)
# print("++++++++++++++++++++++++++++++++++++++++++++++++++")
# 保存着batch中每个元素的初始化隐状态的Tensor
h0 = Variable(torch.randn(2,3,20)) # 2行矩阵,每个矩阵是3行20列;2是LSTM中2层规定,20列是LSTM输入规定20列
# 保存着batch中每个元素的初始化细胞状态的Tensor
c0 = Variable(torch.randn(2,3,20))# 2行矩阵,每个矩阵是3行20列;2是LSTM中2层规定,20列是LSTM输入规定20列
# 输出
output , hn = lstm(input,(h0,c0))
# print(output)
# print("--------------------------------------------------")
print(hn)
上一篇: ORB_SLAM2源码解析-框架
下一篇: Android BLE蓝牙使用 一