Pytorch中的LSTM详细代码解读
程序员文章站
2022-03-01 20:54:21
...
小白撸代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
torch.manual_seed(1)
lstm = nn.LSTM(3,3) #输入dim = 3 输出dim = 3 格式是输入的是3列 隐藏层数也是3列
#print(lstm)
#输入
inputs = [torch.randn(1,3) for _ in range(5)] #序列长度为5 1行3列的张量;3列是LSTM输入的3列
#print(inputs)
#初始化隐层状态
hidden = (torch.randn(1,1,3),#;1大行2大列矩阵,每大列矩阵是1行3列
torch.randn(1,1,3))
#print(hidden)
for i in inputs:
#print(i ,"================")
#print(i.view(1,1,-1),"+++++++++++")
#一步一个元素地通过序列。
#在每一步之后,隐藏层包含了隐藏的状态
#输出 在view中1个矩阵1行-1是每个矩阵的所有元素
#print(hidden,"~~~~~~~~~~~~~~~~~~")
out, hidden = lstm(i.view(1,1,-1),hidden)
#print(out,"$$$$$$$$$$$$$$$$$$$$$")
#view 返回一个有相同数据但大小不同的tensor。
#返回的tensor必须与原tensor相同的数据和相同数目的元素,但可以有不同的大小。
inputs = torch.cat(inputs).view(len(inputs),1,-1)
#print(len(inputs))
hidden = (torch.randn(1,1,3),torch.randn(1,1,3))#清理hidden状态
#输出
out,hidden =lstm(inputs ,hidden)
#print(out)
#print(hidden)
上一篇: Android BLE蓝牙详细解读
下一篇: 详细解读ORBSLAM中的描述子提取过程