pytorch——LSTM
程序员文章站
2024-03-24 23:36:04
...
nn.LSTM
在LSTM中，细胞状态c和隐藏状态h的size是一样的：对于单向的nn.LSTM，返回的h和c都是 [num_layers, batch, hidden_size]。
import torch
from torch import nn
import numpy as np

# nn.LSTM runs the whole sequence in a single call. With the default
# batch_first=False the input layout is (seq_len, batch, input_size).
SEQ_LEN = 10      # words per sentence (time steps)
BATCH = 3         # number of sentences
INPUT_SIZE = 100  # each word is encoded as a 100-dim vector
HIDDEN_SIZE = 20
NUM_LAYERS = 4

lstm = nn.LSTM(input_size=INPUT_SIZE, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS)
print(lstm)

# 3 sentences, 10 words each, every word encoded as a 100-dim vector.
x = torch.randn(SEQ_LEN, BATCH, INPUT_SIZE)

# No initial state is passed, so h_0 / c_0 default to zeros.
# out:  (seq_len, batch, hidden_size) -- last layer's hidden state at every step
# h, c: (num_layers, batch, hidden_size) -- final state of every layer;
#       h and c always have the same shape.
out, (h, c) = lstm(x)
print(out.shape, h.shape, c.shape)
nn.LSTMCell
第二种方式：使用灵活性更大的nn.LSTMCell。它每次调用只处理一个时间步，由人来决定每一步喂入什么数据（需要自己写循环，并在每一步之间传递h和c）。
单层
import torch
from torch import nn
import numpy as np

# nn.LSTMCell processes one time step per call, so the caller drives the
# loop over the sequence and carries the (h, c) state between steps.
SEQ_LEN, BATCH, INPUT_SIZE, HIDDEN_SIZE = 10, 3, 100, 20

print('one layer lstm')
cell = nn.LSTMCell(input_size=INPUT_SIZE, hidden_size=HIDDEN_SIZE)
# Initial state: h and c are both (batch, hidden_size). The sizes come from
# the same constants as the cell, so they cannot drift apart.
h = torch.zeros(BATCH, HIDDEN_SIZE)
c = torch.zeros(BATCH, HIDDEN_SIZE)
# Layout (seq_len, batch, input_size): 10 time steps for a batch of 3.
x = torch.randn(SEQ_LEN, BATCH, INPUT_SIZE)
# Iterating a tensor walks its first dimension, so xt is one time step
# of shape (batch, input_size).
for xt in x:
    # LSTMCell documents its state argument as a tuple (h_0, c_0).
    h, c = cell(xt, (h, c))
print('h.shape: ', h.shape)
print('c.shape: ', c.shape)
双层
import torch
from torch import nn
import numpy as np

# Stack two LSTMCells by hand: at every time step the first cell's hidden
# state h1 becomes the input of the second cell.
SEQ_LEN, BATCH, INPUT_SIZE = 10, 3, 100
HIDDEN1, HIDDEN2 = 30, 20  # hidden sizes of layer 1 and layer 2

x = torch.randn(SEQ_LEN, BATCH, INPUT_SIZE)
print('two layer lstm')
cell1 = nn.LSTMCell(input_size=INPUT_SIZE, hidden_size=HIDDEN1)
# Layer 2 consumes h1, so its input size must equal layer 1's hidden size.
cell2 = nn.LSTMCell(input_size=HIDDEN1, hidden_size=HIDDEN2)
# Each layer carries its own (h, c) state of shape (batch, hidden_size).
h1 = torch.zeros(BATCH, HIDDEN1)
c1 = torch.zeros(BATCH, HIDDEN1)
h2 = torch.zeros(BATCH, HIDDEN2)
c2 = torch.zeros(BATCH, HIDDEN2)
for xt in x:
    # xt is one time step of shape (batch, input_size); states are passed
    # as the documented (h, c) tuple.
    h1, c1 = cell1(xt, (h1, c1))
    h2, c2 = cell2(h1, (h2, c2))
print('h.shape: ', h2.shape)
print('c.shape: ', c2.shape)