欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

pytorch——LSTM

程序员文章站 2024-03-24 23:36:04
...

nn.LSTM

在LSTM中,c和h的size是一样的

pytorch——LSTM

pytorch——LSTM

import  torch
from  torch import nn
import numpy as np

lstm = nn.LSTM(input_size=100, hidden_size=20, num_layers=4)
print(lstm)
x = torch.randn(10,3,100) #3个句子,每个句子10个单词,每个单词encoding成100维的vector
out,(h,c) = lstm(x)
print(out.shape,h.shape,c.shape)

pytorch——LSTM

 

nn.LSTMCell

第二种方式,灵活性更大的cell,人为来决定每一次喂数据

pytorch——LSTM

pytorch——LSTM

单层

import  torch
from  torch import nn
import numpy as np

print('one layer lstm')
cell=nn.LSTMCell(input_size=100, hidden_size=20)
h=torch.zeros(3,20)
c=torch.zeros(3,20)
x = torch.randn(10,3,100)
for xt in x: 
	h,c = cell(xt, [h,c])

print('h.shape: ',h.shape)
print('c.shape: ',c.shape)

pytorch——LSTM

双层

import  torch
from  torch import nn
import numpy as np

x = torch.randn(10,3,100)
print('two layer lstm')
cell1=nn.LSTMCell(input_size=100, hidden_size=30)
cell2=nn.LSTMCell(input_size=30, hidden_size=20)
h1=torch. zeros(3,30)
c1=torch. zeros(3,30)
h2=torch. zeros(3,20)
c2=torch. zeros(3,20)
for xt in x: 
	h1,c1=cell1(xt,[h1, c1])
	h2,c2=cell2(h1,[h2, c2])
print('h.shape: ',h2.shape)
print('c.shape: ',c2.shape)

pytorch——LSTM

 

相关标签: pytorch LSTM