卷积神经网络实例及代码(MNIST数据集介绍、下载及基本操作)
系列文章:
深度学习-卷积神经网络-实例及代码0.8—基于最小均方误差的线性判别函数参数拟合训练
1、MNIST数据集介绍
MNIST数据集是一个由手写数字图片构成的数据集,数字由0~9组成,图片大小为28*28
MNIST数据集包含训练集mnist.train和测试集mnist.test两部分
训练集mnist.train包含60000张图片,其中55000张训练用,5000张验证用
测试集mnist.test包含10000张图片,用于测试
2、MNIST数据集下载
第一种下载方式是直接去官网下载,Lecun MNIST数据集官方网址:
http://yann.lecun.com/exdb/mnist/
页面包含以下四个文件包:
train-images-idx3-ubyte.gz: training set images (9912422 bytes),对应训练集图片
train-labels-idx1-ubyte.gz: training set labels (28881 bytes),对应训练集分类标签
t10k-images-idx3-ubyte.gz: test set images (1648877 bytes),对应测试集图片
t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes),对应测试集分类标签
第二种下载方式是通过mnist数据集操作的python包,这个操作包已经包含在Tensorflow框架中
相关的Github项目地址:
tensorfow操作mnist数据集python包-Github
程序中导入input_data.py
执行mnist=input_data.read_data_sets('../../MNIST_data/', one_hot=True)读取数据集
如果在../../MNIST_data/路径下没有相关数据集,程序会自动下载MNIST数据集,如果已经有就不再下载了而直接读取数据
3、MNIST数据集基本操作代码
(1)获取MNIST数据集
(2)获取训练集/测试集图片数据及分类标签数据
(3)查看某个训练样本数据
(4)利用matplotlib.pyplot包图形化显示训练样本数据
代码实例项目Github地址(如果对你有所帮助,欢迎关注点赞~):
https://github.com/firemonkeygit/DeepLearningTensorflowMNIST
参考MnistTest.py文件,经过调试可用
注意导入操作mnist数据集的python包和配置正确的数据集路径
输出结果为:
(55000, 784) (55000, 10)
(10000, 784) (10000, 10)
(5000, 784) (5000, 10)
[0. 0. 0. 0. 0. 0.
.....................(矩阵比较大,其他元素就不贴出来了) ]
[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
输出显示图片:(源代码解析见文末所附代码)
主要源代码文件MnistTest.py及说明如下:
from TensorflowTest.mnist import input_data
import tensorflow as tf
import matplotlib.pyplot as plt
# MNIST_data指的是存放数据的文件夹路径,one_hot=True 为采用one_hot的编码方式编码标签
mnist = input_data.read_data_sets('../../MNIST_data/', one_hot=True)
# load data
train_X = mnist.train.images
train_Y = mnist.train.labels
print(train_X.shape, train_Y.shape) # 输出训练集样本和标签的大小
print(mnist.test.images.shape, mnist.test.labels.shape)
print(mnist.validation.images.shape, mnist.validation.labels.shape)
# 查看数据,例如训练集中第一个样本的内容和标签
print(train_X[0]) # 是一个包含784个元素且值在[0,1]之间的向量
print(train_Y[0])
# 图形化显示样本,输出训练集中前4个样本
fig, ax = plt.subplots(nrows=2, ncols=2, sharex='all', sharey='all')
ax = ax.flatten()
for i in range(4):
img = train_X[i].reshape(28, 28)
# ax[i].imshow(img,cmap='Greys')
ax[i].imshow(img)
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()
本文地址:https://blog.csdn.net/firemonkeycs/article/details/108262991
上一篇: Android集成Bugly热修复