
Recurrent Neural Networks in PyTorch (2): MNIST Handwritten Digit Classification with an LSTM

程序员文章站 2022-03-23 09:12:47

First, the complete training and testing code:

# -*- coding: utf-8 -*-
"""
Created on Fri Aug  7 15:10:16 2020

@author: wj
"""
import torch
from torch import nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader
# torch.manual_seed(1)    # reproducible


num_epoches = 10
BATCH_SIZE = 128        # mini-batch size
TIME_STEP = 28          # sequence length (number of image rows)
INPUT_SIZE = 28         # feature vector length (pixels per row)
LR = 0.01               # learning rate

# Download the MNIST dataset
train_dataset = datasets.MNIST(
    root='./data', train=True, transform=transforms.ToTensor(), download=True)

test_dataset = datasets.MNIST(
    root='./data', train=False, transform=transforms.ToTensor())

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)


# Define the network
class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()
        self.rnn = nn.LSTM(input_size=INPUT_SIZE,  # a plain nn.RNN() hardly learns here
                           hidden_size=64,         # number of hidden units
                           num_layers=1,           # number of stacked LSTM layers
                           batch_first=True,       # input & output have batch as the 1st dimension, e.g. (batch, seq, input_size)
                          )
        self.out = nn.Linear(64, 10)  # 10 classes

    def forward(self, x):
        # None means the initial hidden and cell states default to zeros
        r_out, (h_n, c_n) = self.rnn(x, None)
        # use the output at the last time step for classification
        out = self.out(r_out[:, -1, :])
        return out

rnn = RNN()  # instantiate the model
rnn = rnn.cuda()
print(rnn)   # inspect the model structure

optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)   # optimize all model parameters
criterion = nn.CrossEntropyLoss()                       # expects class indices, not one-hot labels


for epoch in range(num_epoches):
    print('epoch {}'.format(epoch + 1))
    print('*' * 10)
    running_loss = 0.0
    running_acc = 0.0

    rnn.train()
    for imgs, labels in train_loader:
        imgs = imgs.squeeze(1)          # (N, 1, 28, 28) -> (N, 28, 28)
        imgs = imgs.cuda()
        labels = labels.cuda()
        # forward pass
        out = rnn(imgs)
        loss = criterion(out, labels)
        running_loss += loss.item() * labels.size(0)
        _, pred = torch.max(out, 1)
        num_correct = (pred == labels).sum()
        running_acc += num_correct.item()
        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print('Finish {} epoch, Loss: {:.6f}, Acc: {:.6f}'.format(
        epoch + 1, running_loss / len(train_dataset), running_acc / len(train_dataset)))
    
    rnn.eval()
    eval_loss = 0.0
    eval_acc = 0.0
    with torch.no_grad():               # no gradients needed during evaluation
        for imgs, labels in test_loader:
            imgs = imgs.squeeze(1)      # (N, 1, 28, 28) -> (N, 28, 28)
            imgs = imgs.cuda()
            labels = labels.cuda()

            out = rnn(imgs)
            loss = criterion(out, labels)
            eval_loss += loss.item() * labels.size(0)
            _, pred = torch.max(out, 1)
            num_correct = (pred == labels).sum()
            eval_acc += num_correct.item()
    print('Test Loss: {:.6f}, Acc: {:.6f}'.format(
        eval_loss / len(test_dataset), eval_acc / len(test_dataset)))
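The shape handling above is the key idea of this model: each 28×28 image is fed to the LSTM as a sequence of 28 rows, where each row is a 28-dimensional feature vector, and the classifier consumes only the output at the final time step. This can be checked in isolation on the CPU with a random tensor standing in for a batch of images (a minimal sketch, not part of the training script):

```python
import torch
from torch import nn

# Same configuration as the model above, but standalone and on the CPU
lstm = nn.LSTM(input_size=28, hidden_size=64, num_layers=1, batch_first=True)
fc = nn.Linear(64, 10)

x = torch.randn(4, 1, 28, 28)    # a fake batch of 4 grayscale "images"
x = x.squeeze(1)                 # (4, 28, 28): (batch, seq_len, input_size)
r_out, (h_n, c_n) = lstm(x)      # r_out: (4, 28, 64), outputs at every time step
logits = fc(r_out[:, -1, :])     # classify from the last time step only
print(logits.shape)              # torch.Size([4, 10])
```

Using only the last time step makes sense because by then the LSTM has read all 28 rows of the image; for a single-layer unidirectional LSTM, `r_out[:, -1, :]` is the same tensor as the final hidden state `h_n[-1]`.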

Finally, the training results:

(screenshot of the per-epoch training and test log omitted)

Original article: https://blog.csdn.net/weixin_45738220/article/details/107881269