Implementing MNIST Handwritten Digit Recognition with PyTorch
程序员文章站
2022-07-18 21:53:36
Experimental environment
Windows 10 + Anaconda + Jupyter Notebook
PyTorch 1.1.0
Python 3.7
GPU (optional)
Introduction to the MNIST dataset
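MNIST contains 60,000 training images and 10,000 test images. Each one is a 28×28 grayscale image of a handwritten digit. It is often called the "hello world" of computer vision. The CNN used in this article reaches about 99% recognition accuracy on MNIST (in the results below, test accuracy settles between 98.9% and 99.1%). Let's get started.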
Import packages
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

torch.__version__  # check the installed PyTorch version
```
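Define hyperparameters
```python
batch_size = 512  # samples per mini-batch
epochs = 20       # number of passes over the training set
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use the GPU if one is available
```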
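With batch_size=512, one epoch over the 60,000 training images is split into 118 mini-batches, and the last one is smaller. The plain-Python sketch below works this out, assuming DataLoader's default drop_last=False. It also reproduces the sample counts and percentages that the training function prints every 30 batches in the results. The helper names `batches_per_epoch` and `progress` are hypothetical and not part of PyTorch.

```python
import math

NUM_TRAIN = 60000   # size of the MNIST training set
BATCH_SIZE = 512    # same value as batch_size above

def batches_per_epoch(num_samples, batch_size):
    """Number of mini-batches when the last partial batch is kept (drop_last=False)."""
    return math.ceil(num_samples / batch_size)

def progress(batch_idx, batch_size, n_batches):
    """Mirror train()'s progress line: samples seen before this batch, and the rounded percentage."""
    return batch_idx * batch_size, round(100.0 * batch_idx / n_batches)

n_batches = batches_per_epoch(NUM_TRAIN, BATCH_SIZE)
last_batch = NUM_TRAIN - (n_batches - 1) * BATCH_SIZE

# train() logs when (batch_idx + 1) % 30 == 0, i.e. at batch_idx 29, 59 and 89
logged = [progress(i, BATCH_SIZE, n_batches) for i in range(n_batches) if (i + 1) % 30 == 0]

print(n_batches, last_batch)  # 118 96
print(logged)                 # [(14848, 25), (30208, 50), (45568, 75)]
```

This explains why each epoch in the log has exactly three training lines, reported at 25%, 50% and 75%.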
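Dataset
We use the MNIST dataset that ships with torchvision and read the training and test data separately with DataLoader. If you have already downloaded the dataset, you can set download to False.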
```python
# Training set (downloaded to ./data on the first run)
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),                      # PIL image -> float tensor in [0, 1]
                       transforms.Normalize((0.1307,), (0.3081,))  # standardize with the MNIST mean and std
                   ])),
    batch_size=batch_size, shuffle=True)

# Test set (shuffling is unnecessary for evaluation)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=False)
```
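transforms.ToTensor() scales 8-bit pixel values from 0..255 into [0, 1]. transforms.Normalize((0.1307,), (0.3081,)) then computes (x - mean) / std, where 0.1307 and 0.3081 are the mean and standard deviation of the MNIST training pixels. The plain-Python function below sketches that arithmetic for a single pixel. Its name, `normalize_pixel`, is hypothetical. It shows that a black background pixel maps to about -0.42 and a fully white stroke pixel to about 2.82.

```python
MEAN, STD = 0.1307, 0.3081  # the constants passed to transforms.Normalize above

def normalize_pixel(value_0_255):
    """Sketch of ToTensor + Normalize applied to one 8-bit grayscale pixel."""
    x = value_0_255 / 255.0   # ToTensor: 0..255 -> 0.0..1.0
    return (x - MEAN) / STD   # Normalize: (x - mean) / std

print(round(normalize_pixel(0), 4))    # -0.4242
print(round(normalize_pixel(255), 4))  # 2.8215
```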
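Define the network
The network consists of two convolutional layers and two fully connected (linear) layers. The last layer outputs 10 values, one for each digit 0-9, and log_softmax turns them into log-probabilities.
```python
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)   # input: (1, 28, 28)  output: (10, 24, 24)
        self.conv2 = nn.Conv2d(10, 20, 3)  # input: (10, 12, 12) output: (20, 10, 10)
        self.fc1 = nn.Linear(20 * 10 * 10, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        in_size = x.size(0)              # batch size N
        out = self.conv1(x)              # (N, 10, 24, 24)
        out = F.relu(out)
        out = F.max_pool2d(out, 2, 2)    # (N, 10, 12, 12)
        out = self.conv2(out)            # (N, 20, 10, 10)
        out = F.relu(out)
        out = out.view(in_size, -1)      # flatten to (N, 2000)
        out = self.fc1(out)              # (N, 500)
        out = F.relu(out)
        out = self.fc2(out)              # (N, 10)
        out = F.log_softmax(out, dim=1)  # log-probabilities for the digits 0-9
        return out
```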
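The layer sizes in the comments follow from the output-size formula for a convolution or pooling layer with no padding: out = (in - kernel) // stride + 1. The plain-Python sketch below traces a 28×28 input through the network and confirms that fc1 needs 20 × 10 × 10 = 2000 inputs. It also counts the trainable parameters (weights plus biases) of each layer. The helper `out_size` is hypothetical.

```python
def out_size(in_size, kernel, stride=1):
    """Spatial output size of a conv/pool layer with no padding and no dilation."""
    return (in_size - kernel) // stride + 1

h = 28
h = out_size(h, 5)            # conv1, 5x5 kernel      -> 24
h = out_size(h, 2, stride=2)  # max_pool2d(2, 2)       -> 12
h = out_size(h, 3)            # conv2, 3x3 kernel      -> 10
fc1_in = 20 * h * h           # 20 channels, flattened -> 2000

# conv: in_ch * out_ch * k * k + out_ch; linear: in * out + out
params = {
    "conv1": 1 * 10 * 5 * 5 + 10,
    "conv2": 10 * 20 * 3 * 3 + 20,
    "fc1": fc1_in * 500 + 500,
    "fc2": 500 * 10 + 10,
}
print(h, fc1_in)             # 10 2000
print(params)
print(sum(params.values()))  # 1007590
```

fc1 alone holds 1,000,500 of the roughly one million parameters (about 99%), so it is the layer to shrink first if the model needs to be smaller.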
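Instantiate the network
```python
model = ConvNet().to(device)                # move the network to the GPU (or CPU if none is available)
optimizer = optim.Adam(model.parameters())  # Adam optimizer with default settings (lr=0.001)
```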
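optim.Adam(model.parameters()) uses PyTorch's defaults: lr=0.001, betas=(0.9, 0.999) and eps=1e-8. The plain-Python function below sketches what one optimizer.step() does to a single weight. It implements the standard Adam update with bias correction but no weight decay or amsgrad. It is only an illustration, not PyTorch's own code, and `adam_step` is a hypothetical name. On the first step the bias-corrected ratio m_hat / sqrt(v_hat) is +1 or -1, so each weight moves by about lr whatever the size of its gradient.

```python
import math

def adam_step(w, grad, state, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar weight; `state` holds the moment estimates and step count."""
    state["t"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad         # first moment (running mean of gradients)
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad * grad  # second moment (running mean of squared gradients)
    m_hat = state["m"] / (1 - beta1 ** state["t"])               # bias correction
    v_hat = state["v"] / (1 - beta2 ** state["t"])
    return w - lr * m_hat / (math.sqrt(v_hat) + eps)

fresh = {"m": 0.0, "v": 0.0, "t": 0}
w_small = adam_step(1.0, 0.5, dict(fresh))   # small gradient
w_large = adam_step(1.0, 50.0, dict(fresh))  # gradient 100x larger
print(w_small, w_large)  # both very close to 0.999
```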
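Define the training function
```python
def train(model, device, train_loader, optimizer, epoch):
    model.train()  # enable training mode
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()              # clear gradients from the previous step
        output = model(data)               # log-probabilities, shape (N, 10)
        loss = F.nll_loss(output, target)  # negative log-likelihood of the correct digit
        loss.backward()                    # backpropagate
        optimizer.step()                   # update the weights
        if (batch_idx + 1) % 30 == 0:      # log every 30 batches
            print('train epoch: {} [{}/{} ({:.0f}%)]\tloss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
```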
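The network ends in F.log_softmax, and the training loop applies F.nll_loss. That loss takes -log p[target] for each sample and averages over the batch, so together the two compute the usual cross-entropy loss (PyTorch's F.cross_entropy combines both steps). The plain-Python sketch below checks this on a tiny batch of raw scores. It needs no PyTorch, and the helper names are hypothetical.

```python
import math

def log_softmax(row):
    """Numerically stable log-softmax of a list of raw scores."""
    m = max(row)
    log_sum = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_sum for x in row]

def nll_loss(log_probs, targets):
    """Mean negative log-likelihood, like F.nll_loss with the default reduction='mean'."""
    return -sum(lp[t] for lp, t in zip(log_probs, targets)) / len(targets)

def cross_entropy(logits, targets):
    """Cross-entropy computed directly from softmax probabilities."""
    total = 0.0
    for row, t in zip(logits, targets):
        denom = sum(math.exp(y) for y in row)
        total += -math.log(math.exp(row[t]) / denom)
    return total / len(targets)

logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]  # 2 samples, 3 classes (MNIST has 10)
targets = [0, 2]
a = nll_loss([log_softmax(r) for r in logits], targets)
b = cross_entropy(logits, targets)
print(abs(a - b) < 1e-12)  # True
```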
Define the test function
```python
def test(model, device, test_loader):
    model.eval()  # switch to evaluation mode
    test_loss = 0
    correct = 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # add up the losses of the batch
            pred = output.max(1, keepdim=True)[1]  # index of the highest probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)  # average loss per test sample
    print('\ntest set: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
```
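test() sums the per-sample losses with reduction='sum' and divides by len(test_loader.dataset) once at the end. The reported number is therefore the exact per-sample average, even though the last test batch is smaller (10,000 = 19 × 512 + 272). Averaging the per-batch means instead would over-weight that last batch. The plain-Python sketch below shows the difference on made-up, purely illustrative loss values. It also mirrors how `correct` is counted with an argmax. The helper `count_correct` is hypothetical.

```python
# Illustrative per-sample losses for two batches of unequal size (made-up numbers)
batches = [
    [0.1, 0.1, 0.1, 0.1],  # a full batch of 4 samples
    [0.9],                 # a smaller final batch
]

n = sum(len(b) for b in batches)
sum_then_divide = sum(sum(b) for b in batches) / n                  # what test() does
mean_of_means = sum(sum(b) / len(b) for b in batches) / len(batches)

print(round(sum_then_divide, 4))  # 0.26
print(round(mean_of_means, 4))    # 0.5

def count_correct(outputs, targets):
    """Count rows whose highest score sits at the target index (like output.max(1)[1] == target)."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in outputs]
    return sum(p == t for p, t in zip(preds, targets))

print(count_correct([[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]], [1, 2]))  # 1
```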
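Start training
```python
for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch)  # one pass over the training set
    test(model, device, test_loader)                      # evaluate on the test set after each epoch
```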
Experimental results
```
train epoch: 1 [14848/60000 (25%)] loss: 0.375058
train epoch: 1 [30208/60000 (50%)] loss: 0.255248
train epoch: 1 [45568/60000 (75%)] loss: 0.128060

test set: average loss: 0.0992, accuracy: 9690/10000 (97%)

train epoch: 2 [14848/60000 (25%)] loss: 0.093066
train epoch: 2 [30208/60000 (50%)] loss: 0.087888
train epoch: 2 [45568/60000 (75%)] loss: 0.068078

test set: average loss: 0.0599, accuracy: 9816/10000 (98%)

train epoch: 3 [14848/60000 (25%)] loss: 0.043926
train epoch: 3 [30208/60000 (50%)] loss: 0.037321
train epoch: 3 [45568/60000 (75%)] loss: 0.068404

test set: average loss: 0.0416, accuracy: 9859/10000 (99%)

train epoch: 4 [14848/60000 (25%)] loss: 0.031654
train epoch: 4 [30208/60000 (50%)] loss: 0.041341
train epoch: 4 [45568/60000 (75%)] loss: 0.036493

test set: average loss: 0.0361, accuracy: 9873/10000 (99%)

train epoch: 5 [14848/60000 (25%)] loss: 0.027688
train epoch: 5 [30208/60000 (50%)] loss: 0.019488
train epoch: 5 [45568/60000 (75%)] loss: 0.018023

test set: average loss: 0.0344, accuracy: 9875/10000 (99%)

train epoch: 6 [14848/60000 (25%)] loss: 0.024212
train epoch: 6 [30208/60000 (50%)] loss: 0.018689
train epoch: 6 [45568/60000 (75%)] loss: 0.040412

test set: average loss: 0.0350, accuracy: 9879/10000 (99%)

train epoch: 7 [14848/60000 (25%)] loss: 0.030426
train epoch: 7 [30208/60000 (50%)] loss: 0.026939
train epoch: 7 [45568/60000 (75%)] loss: 0.010722

test set: average loss: 0.0287, accuracy: 9892/10000 (99%)

train epoch: 8 [14848/60000 (25%)] loss: 0.021109
train epoch: 8 [30208/60000 (50%)] loss: 0.034845
train epoch: 8 [45568/60000 (75%)] loss: 0.011223

test set: average loss: 0.0299, accuracy: 9904/10000 (99%)

train epoch: 9 [14848/60000 (25%)] loss: 0.011391
train epoch: 9 [30208/60000 (50%)] loss: 0.008091
train epoch: 9 [45568/60000 (75%)] loss: 0.039870

test set: average loss: 0.0341, accuracy: 9890/10000 (99%)

train epoch: 10 [14848/60000 (25%)] loss: 0.026813
train epoch: 10 [30208/60000 (50%)] loss: 0.011159
train epoch: 10 [45568/60000 (75%)] loss: 0.024884

test set: average loss: 0.0286, accuracy: 9901/10000 (99%)

train epoch: 11 [14848/60000 (25%)] loss: 0.006420
train epoch: 11 [30208/60000 (50%)] loss: 0.003641
train epoch: 11 [45568/60000 (75%)] loss: 0.003402

test set: average loss: 0.0377, accuracy: 9894/10000 (99%)

train epoch: 12 [14848/60000 (25%)] loss: 0.006866
train epoch: 12 [30208/60000 (50%)] loss: 0.012617
train epoch: 12 [45568/60000 (75%)] loss: 0.008548

test set: average loss: 0.0311, accuracy: 9908/10000 (99%)

train epoch: 13 [14848/60000 (25%)] loss: 0.010539
train epoch: 13 [30208/60000 (50%)] loss: 0.002952
train epoch: 13 [45568/60000 (75%)] loss: 0.002313

test set: average loss: 0.0293, accuracy: 9905/10000 (99%)

train epoch: 14 [14848/60000 (25%)] loss: 0.002100
train epoch: 14 [30208/60000 (50%)] loss: 0.000779
train epoch: 14 [45568/60000 (75%)] loss: 0.005952

test set: average loss: 0.0335, accuracy: 9897/10000 (99%)

train epoch: 15 [14848/60000 (25%)] loss: 0.006053
train epoch: 15 [30208/60000 (50%)] loss: 0.002559
train epoch: 15 [45568/60000 (75%)] loss: 0.002555

test set: average loss: 0.0357, accuracy: 9894/10000 (99%)

train epoch: 16 [14848/60000 (25%)] loss: 0.000895
train epoch: 16 [30208/60000 (50%)] loss: 0.004923
train epoch: 16 [45568/60000 (75%)] loss: 0.002339

test set: average loss: 0.0400, accuracy: 9893/10000 (99%)

train epoch: 17 [14848/60000 (25%)] loss: 0.004136
train epoch: 17 [30208/60000 (50%)] loss: 0.000927
train epoch: 17 [45568/60000 (75%)] loss: 0.002084

test set: average loss: 0.0353, accuracy: 9895/10000 (99%)

train epoch: 18 [14848/60000 (25%)] loss: 0.004508
train epoch: 18 [30208/60000 (50%)] loss: 0.001272
train epoch: 18 [45568/60000 (75%)] loss: 0.000543

test set: average loss: 0.0380, accuracy: 9894/10000 (99%)

train epoch: 19 [14848/60000 (25%)] loss: 0.001699
train epoch: 19 [30208/60000 (50%)] loss: 0.000661
train epoch: 19 [45568/60000 (75%)] loss: 0.000275

test set: average loss: 0.0339, accuracy: 9905/10000 (99%)

train epoch: 20 [14848/60000 (25%)] loss: 0.000441
train epoch: 20 [30208/60000 (50%)] loss: 0.000695
train epoch: 20 [45568/60000 (75%)] loss: 0.000467

test set: average loss: 0.0396, accuracy: 9894/10000 (99%)
```
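In the log, test accuracy passes 98.5% by epoch 3. From epoch 7 onward it stays between 98.90% and 99.08%. Over the same period the training loss keeps shrinking to below 0.001, while the average test loss reaches its minimum of 0.0286 at epoch 10 and then rises again, ending at 0.0396. This pattern suggests the model starts to overfit in the later epochs. The sketch below copies the per-epoch test results from the log above and picks out the best epochs. Saving a checkpoint at such an epoch would be an addition to the original code, which does not do this.

```python
# (epoch, average test loss, correct out of 10,000), copied from the log above
results = [
    (1, 0.0992, 9690), (2, 0.0599, 9816), (3, 0.0416, 9859), (4, 0.0361, 9873),
    (5, 0.0344, 9875), (6, 0.0350, 9879), (7, 0.0287, 9892), (8, 0.0299, 9904),
    (9, 0.0341, 9890), (10, 0.0286, 9901), (11, 0.0377, 9894), (12, 0.0311, 9908),
    (13, 0.0293, 9905), (14, 0.0335, 9897), (15, 0.0357, 9894), (16, 0.0400, 9893),
    (17, 0.0353, 9895), (18, 0.0380, 9894), (19, 0.0339, 9905), (20, 0.0396, 9894),
]

best_by_loss = min(results, key=lambda r: r[1])  # epoch with the lowest average test loss
best_by_acc = max(results, key=lambda r: r[2])   # epoch with the most correct predictions

print("lowest test loss:", best_by_loss)                                 # lowest test loss: (10, 0.0286, 9901)
print("highest accuracy:", best_by_acc, f"{best_by_acc[2] / 100:.2f}%")  # highest accuracy: (12, 0.0311, 9908) 99.08%
```

The two criteria disagree (epoch 10 versus epoch 12). Choosing between them is exactly the kind of decision the summary below describes: adjust the training setup based on the results.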
Summary
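The workflow of a real project: find a dataset, preprocess the data, define the model, set the hyperparameters, train and test, and then use the training results to adjust the hyperparameters or the model.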