欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  科技

卷积神经网络实例及代码(MNIST数据集介绍、下载及基本操作)

程序员文章站 2022-03-05 09:37:59
1、MNIST数据集介绍MNIST数据集是一个由手写数字图片构成的数据集,数字由0~9组成,图片大小为28*28MNIST数据集包含训练集mnist.train和测试集mnist.test两部分训练集mnist.train包含60000张图片,其中55000张训练用,5000张验证用测试集mnist.test包含10000张图片,用于测试2、MNIST数据集下载第一种下载方式是直接去官网下载,Lecun MNIST数据集官方网址: ......

系列文章:

深度学习-卷积神经网络-实例及代码0.8—基于最小均方误差的线性判别函数参数拟合训练


1、MNIST数据集介绍

MNIST数据集是一个由手写数字图片构成的数据集,数字由0~9组成,图片大小为28*28

MNIST数据集包含训练集mnist.train和测试集mnist.test两部分

训练集mnist.train包含60000张图片,其中55000张训练用,5000张验证用

测试集mnist.test包含10000张图片,用于测试


2、MNIST数据集下载

第一种下载方式是直接去官网下载,Lecun MNIST数据集官方网址:

http://yann.lecun.com/exdb/mnist/

页面包含以下四个文件包:

train-images-idx3-ubyte.gz:  training set images (9912422 bytes),对应训练集图片
train-labels-idx1-ubyte.gz:  training set labels (28881 bytes),对应训练集分类标签
t10k-images-idx3-ubyte.gz:   test set images (1648877 bytes),对应测试集图片
t10k-labels-idx1-ubyte.gz:   test set labels (4542 bytes),对应测试集分类标签

第二种下载方式是通过mnist数据集操作的python包,这个操作包已经包含在Tensorflow框架中

相关的Github项目地址:

tensorfow操作mnist数据集python包-Github

程序中导入input_data.py

执行mnist=input_data.read_data_sets('../../MNIST_data/', one_hot=True)读取数据集

如果在../../MNIST_data/路径下没有相关数据集,程序会自动下载MNIST数据集,如果已经有就不再下载了而直接读取数据


3、MNIST数据集基本操作代码

(1)获取MNIST数据集

(2)获取训练集/测试集图片数据及分类标签数据

(3)查看某个训练样本数据

(4)利用matplotlib.pyplot包图形化显示训练样本数据

代码实例项目Github地址(如果对你有所帮助,欢迎关注点赞~):

https://github.com/firemonkeygit/DeepLearningTensorflowMNIST

参考MnistTest.py文件,经过调试可用

注意导入操作mnist数据集的python包和配置正确的数据集路径

输出结果为:

(55000, 784) (55000, 10)
(10000, 784) (10000, 10)
(5000, 784) (5000, 10)
[0.         0.         0.         0.         0.         0.
.....................(矩阵比较大,其他元素就不贴出来了) ]
[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]

输出显示图片:(源代码解析见文末所附代码)

卷积神经网络实例及代码(MNIST数据集介绍、下载及基本操作)

主要源代码文件MnistTest.py及说明如下:

from TensorflowTest.mnist import input_data
import tensorflow as tf
import matplotlib.pyplot as plt

# MNIST_data指的是存放数据的文件夹路径,one_hot=True 为采用one_hot的编码方式编码标签
mnist = input_data.read_data_sets('../../MNIST_data/', one_hot=True)

# load data
train_X = mnist.train.images
train_Y = mnist.train.labels
print(train_X.shape, train_Y.shape)   # 输出训练集样本和标签的大小
print(mnist.test.images.shape, mnist.test.labels.shape)
print(mnist.validation.images.shape, mnist.validation.labels.shape)

# 查看数据,例如训练集中第一个样本的内容和标签
print(train_X[0])       # 是一个包含784个元素且值在[0,1]之间的向量
print(train_Y[0])

# 图形化显示样本,输出训练集中前4个样本
fig, ax = plt.subplots(nrows=2, ncols=2, sharex='all', sharey='all')
ax = ax.flatten()
for i in range(4):
    img = train_X[i].reshape(28, 28)
    # ax[i].imshow(img,cmap='Greys')
    ax[i].imshow(img)
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()

本文地址:https://blog.csdn.net/firemonkeycs/article/details/108262991