欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Pytorch中的LSTM详细代码解读

程序员文章站 2022-03-01 20:54:21
...

小白撸代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)

lstm = nn.LSTM(3,3) #输入dim = 3 输出dim = 3 格式是输入的是3列 隐藏层数也是3列
#print(lstm)
#输入
inputs = [torch.randn(1,3) for _ in range(5)] #序列长度为5  1行3列的张量;3列是LSTM输入的3列
#print(inputs)

#初始化隐层状态
hidden = (torch.randn(1,1,3),#;1大行2大列矩阵,每大列矩阵是1行3列
          torch.randn(1,1,3))
#print(hidden)
for i in inputs:
    #print(i ,"================")

    #print(i.view(1,1,-1),"+++++++++++")
          #一步一个元素地通过序列。
    #在每一步之后,隐藏层包含了隐藏的状态
    #输出   在view中1个矩阵1行-1是每个矩阵的所有元素
    #print(hidden,"~~~~~~~~~~~~~~~~~~")
    out, hidden = lstm(i.view(1,1,-1),hidden)
    #print(out,"$$$$$$$$$$$$$$$$$$$$$")

#view 返回一个有相同数据但大小不同的tensor。
#返回的tensor必须与原tensor相同的数据和相同数目的元素,但可以有不同的大小。
inputs = torch.cat(inputs).view(len(inputs),1,-1)
#print(len(inputs))
hidden = (torch.randn(1,1,3),torch.randn(1,1,3))#清理hidden状态
#输出
out,hidden =lstm(inputs ,hidden)
#print(out)
#print(hidden)
相关标签: pytorch