Classifying the MNIST Dataset with PyTorch
Introduction
For this task I use a LeNet-style convolutional neural network. Training for 10 epochs reaches roughly 99% accuracy on the training set, and the test set also reaches about 99%. The implementation makes a small modification to the original LeNet architecture; the model is defined in Build_LeNet_for_mnist.py. The dataset is not downloaded from the internet but loaded from a local copy via mnist_loader.py, which is adapted from the torchvision.datasets.MNIST source. You need to change the paths in the urls list to point at your own copy of the data; on my machine, for example, the first entry is file:///E:/PyCharmWorkSpace/Image_Set/mnist_data/train-images-idx3-ubyte.gz. The code runs on the GPU, and the network hyperparameters are set as shown in the code below.
Code Implementation
1. Importing the required packages
import torch
import mnist_loader
import Build_LeNet_for_mnist
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
import csv
import copy
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
2. Loading and preprocessing the dataset
# Load the dataset
use_cuda = torch.cuda.is_available()  # check whether a GPU is available
batch_size = test_batch_size = 32
kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {}
# Training data loader
train_loader = torch.utils.data.DataLoader(
    mnist_loader.MNIST('./mnist_data',
                       train=True,
                       download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),  # first positional argument: the dataset object
    batch_size=batch_size,
    shuffle=True,  # shuffle the data
    **kwargs)  # kwargs carries the GPU settings from above
# Test data loader
test_loader = torch.utils.data.DataLoader(
    mnist_loader.MNIST('./mnist_data',
                       train=False,  # if False, build the dataset from test.pt
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
    batch_size=test_batch_size,
    shuffle=True,
    **kwargs)
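The constants passed to Normalize are the mean and standard deviation of the MNIST training pixels after ToTensor() has scaled them to [0, 1]. As a minimal sketch (my own addition, assuming the loaders above are already set up), they can be recomputed directly:

# Sketch: recompute the normalization constants from the raw training images.
check_loader = torch.utils.data.DataLoader(
    mnist_loader.MNIST('./mnist_data', train=True, transform=transforms.ToTensor()),
    batch_size=60000, shuffle=False)
images, _ = next(iter(check_loader))
print(images.mean().item(), images.std().item())  # roughly 0.1307 and 0.3081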
3. Building and training the LeNet model
# Build the network and set the hyperparameters
learning_rate = 0.01
momentum = 0.9
device = torch.device("cuda" if use_cuda else "cpu")
model = Build_LeNet_for_mnist.LeNet(1, 10).to(device)  # instantiate the model
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)  # choose the optimizer
# Create the csv log file
csvFile = open("log.csv", "a+")
writer = csv.writer(csvFile)  # csv writer object
last_epoch = 0
if os.path.exists("mnist_LeNet.pt"):  # checkpoint saved by the training step below
    print("load pretrain")
    model.load_state_dict(torch.load("mnist_LeNet.pt"))
    data = pd.read_csv('log.csv')
    e = data['epoch']
    last_epoch = e[len(e) - 1]
else:
    print("first train")
    # write the column names first
    writer.writerow(["epoch", "acc", "loss"])
# Training function
def train(model, device, train_loader, optimizer, last_epoch, epochs):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.
    print("Train from Epoch: {}".format(last_epoch + 1))
    model.train()  # switch to training mode
    for epoch in range(1 + last_epoch, epochs + 1 + last_epoch):
        correct = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)  # negative log-likelihood, pairs with log_softmax
            loss.backward()
            optimizer.step()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
        acc = 100. * correct / len(train_loader.dataset)
        print("Train Epoch: {} Accuracy: {:.2f}%\tLoss: {:.6f}".format(
            epoch, acc, loss.item()))
        if acc > best_acc:  # keep the weights of the best epoch
            best_acc = acc
            best_model_wts = copy.deepcopy(model.state_dict())
            # print(model.state_dict())
        writer.writerow([epoch, acc / 100, loss.item()])
    return best_model_wts
# Start training
epochs = 10
best_model_wts = train(model, device, train_loader, optimizer, last_epoch, epochs)
csvFile.close()
# Save the trained model
save_model = True
if save_model:
    torch.save(best_model_wts, "mnist_LeNet.pt")
    # dict format: model.state_dict() saves only the model parameters
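Because only the state_dict is saved, loading the checkpoint later requires instantiating the architecture first. A minimal sketch (assuming mnist_LeNet.pt was written by the step above):

# Sketch: restore the best weights for later inference.
model = Build_LeNet_for_mnist.LeNet(1, 10).to(device)
model.load_state_dict(torch.load("mnist_LeNet.pt", map_location=device))
model.eval()  # switch to evaluation mode before predicting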
4. Visualizing the accuracy curve
# Plot the accuracy curve
data = pd.read_csv('log.csv')
epoch = data['epoch']
acc = data['acc']
loss = data['loss']
fig = plt.gcf()
fig.set_size_inches(10, 4)
plt.title("Accuracy & Loss")
plt.xlabel("Training Epochs")
plt.ylabel("Value")
plt.plot(epoch, acc, label="Accuracy")
# plt.plot(epoch, loss, label="Loss")
plt.ylim((0, 1.))
plt.xticks(np.arange(1, len(epoch) + 1, 1.0))
plt.yticks(np.arange(0, 1.5, 0.2))
plt.legend()
plt.show()
5. Evaluating on the test set and visualizing predictions
def test(model, device, test_loader):
    model.eval()  # switch to evaluation mode
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            # keep the first 10 samples of the last batch for visualization
            data_record = data[0:10]
            pred_record = pred.view_as(target)[0:10].cpu().numpy()
            target_record = target[0:10].cpu().numpy()
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return data_record, pred_record, target_record

data_record, pred_record, target_record = test(model, device, test_loader)
# Visualize the classification results on the test set
# unloader = transforms.ToPILImage()
label_dict = {0: "0", 1: "1", 2: "2", 3: "3", 4: "4", 5: "5", 6: "6", 7: "7", 8: "8", 9: "9"}

def plot_images_labels_prediction(images, labels, prediction, idx, num=10):
    fig = plt.gcf()
    fig.set_size_inches(12, 6)
    if num > 10:
        num = 10
    for i in range(0, num):
        image = images[idx].cpu().clone()
        image = image.squeeze(0)
        # image = unloader(image)
        ax = plt.subplot(2, 5, 1 + i)
        ax.imshow(image, cmap="binary")
        title = label_dict[labels[idx]]  # true label
        if len(prediction) > 0:
            title += "=>" + label_dict[prediction[idx]]  # predicted label
        ax.set_title(title, fontsize=10)
        idx += 1
    plt.show()

plot_images_labels_prediction(data_record, target_record, pred_record, 0, 10)
6. Build_LeNet_for_mnist.py
import torch.nn as nn
import torch.nn.functional as F

# LeNet-style CNN, slightly modified from the original LeNet-5
class LeNet(nn.Module):
    def __init__(self, channel, classes):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(channel, 32, 5, 1)  # 1x28x28 -> 32x24x24
        self.conv2 = nn.Conv2d(32, 64, 5, 1)       # 32x12x12 -> 64x8x8
        self.fc1 = nn.Linear(4 * 4 * 64, 512)
        self.fc2 = nn.Linear(512, classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)   # 32x24x24 -> 32x12x12
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)   # 64x8x8 -> 64x4x4
        x = x.view(-1, 4 * 4 * 64)  # flatten
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)  # pairs with F.nll_loss in training
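The 4*4*64 input size of fc1 follows from the layer arithmetic on a 28x28 image: conv1 (kernel 5, stride 1) gives 24x24, pooling halves it to 12x12, conv2 gives 8x8, and the second pooling gives 4x4 with 64 channels. A quick sketch (my own addition) to confirm the shapes with a dummy batch:

# Sketch: trace the feature-map sizes through the network.
import torch
import torch.nn.functional as F
from Build_LeNet_for_mnist import LeNet

net = LeNet(1, 10)
x = torch.zeros(1, 1, 28, 28)
x = F.max_pool2d(F.relu(net.conv1(x)), 2, 2)
print(x.shape)  # torch.Size([1, 32, 12, 12])
x = F.max_pool2d(F.relu(net.conv2(x)), 2, 2)
print(x.shape)  # torch.Size([1, 64, 4, 4]) -> flatten to 4*4*64 = 1024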
7. mnist_loader.py
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import os
import os.path
import errno
import numpy as np
import torch
import codecs
class MNIST(data.Dataset):
    # Change these to the paths of the dataset on your own machine;
    # a correct file:// URL should open in a web browser.
    urls = [
        'file:///E:/PyCharmWorkSpace/Image_Set/mnist_data/train-images-idx3-ubyte.gz',
        'file:///E:/PyCharmWorkSpace/Image_Set/mnist_data/train-labels-idx1-ubyte.gz',
        'file:///E:/PyCharmWorkSpace/Image_Set/mnist_data/t10k-images-idx3-ubyte.gz',
        'file:///E:/PyCharmWorkSpace/Image_Set/mnist_data/t10k-labels-idx1-ubyte.gz',
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set
        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError('Dataset not found.' + ' You can use download=True to download it')
        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.test_file))

    def __getitem__(self, index):
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]
        img = Image.fromarray(img.numpy(), mode='L')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)

    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
            os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))
    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder already."""
        from six.moves import urllib
        import gzip
        if self._check_exists():
            return
        # download files
        try:
            os.makedirs(os.path.join(self.root, self.raw_folder))
            os.makedirs(os.path.join(self.root, self.processed_folder))
        except OSError as e:
            if e.errno == errno.EEXIST:
                pass
            else:
                raise
        for url in self.urls:
            print('Downloading ' + url)
            data = urllib.request.urlopen(url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.root, self.raw_folder, filename)
            with open(file_path, 'wb') as f:
                f.write(data.read())
            with open(file_path.replace('.gz', ''), 'wb') as out_f, \
                    gzip.GzipFile(file_path) as zip_f:
                out_f.write(zip_f.read())
            os.unlink(file_path)
        # process and save as torch files
        print('Processing...')
        training_set = (
            read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),
            read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'))
        )
        test_set = (
            read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')),
            read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte'))
        )
        with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)
        print('Done!')

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        tmp = 'train' if self.train is True else 'test'
        fmt_str += '    Split: {}\n'.format(tmp)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str
def get_int(b):
    return int(codecs.encode(b, 'hex'), 16)

def read_label_file(path):
    with open(path, 'rb') as f:
        data = f.read()
        assert get_int(data[:4]) == 2049  # magic number for label files
        length = get_int(data[4:8])
        parsed = np.frombuffer(data, dtype=np.uint8, offset=8)
        return torch.from_numpy(parsed).view(length).long()

def read_image_file(path):
    with open(path, 'rb') as f:
        data = f.read()
        assert get_int(data[:4]) == 2051  # magic number for image files
        length = get_int(data[4:8])
        num_rows = get_int(data[8:12])
        num_cols = get_int(data[12:16])
        parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
        return torch.from_numpy(parsed).view(length, num_rows, num_cols)
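get_int, read_label_file, and read_image_file parse the IDX binary format: a big-endian header of 32-bit integers (a magic number, 2051 for images and 2049 for labels, followed by the dimensions), then the raw uint8 data. A minimal sketch (my own addition, assuming the raw files have been unpacked into ./mnist_data/raw by download()) of inspecting a header directly:

# Sketch: read the 16-byte IDX header of the training-image file.
import struct

with open('./mnist_data/raw/train-images-idx3-ubyte', 'rb') as f:
    magic, n, rows, cols = struct.unpack('>IIII', f.read(16))
print(magic, n, rows, cols)  # expected: 2051 60000 28 28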
Results
[Figure: accuracy curve over the training epochs]
[Figure: test-set accuracy and sample prediction results]
Original article: https://blog.csdn.net/pengshunbetter/article/details/107109420