欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Pytorch LSTM

程序员文章站 2024-03-24 23:06:28
...

 

#定义模型
class rnn_classify(nn.Module):
   def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
     super(rnn_classify, self).__init__()
     self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)  #使用两层lstm
     self.classifier = nn.Linear(hidden_feature, num_class)      #将最后一个的rnn使用全连接的到最后的输出结果
     
   def forward(self, x):
     #x的大小为(batch,1,28,28),所以我们需要将其转化为rnn的输入格式(28,batch,28)
     x = x.squeeze()         #去掉(batch,1, 28,28)中的1,变成(batch,28, 28)
     x = x.permute(2, 0, 1)  #将最后一维放到第一维,变成(28, batch,28)
     out, _ = self.rnn(x)    #使用默认的隐藏状态,得到的out是(28, batch, hidden_feature)
     out = out[-1,:,:]       #取序列中的最后一个,大小是(batch, hidden_feature)
     out = self.classifier(out)   #得到分类结果
     return out
     
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)

#author: www

 

squeeze()

相关标签: PyTorch