TensorFlow MNIST: downloading and exploring the data
2024-03-14 21:52:23
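With the TensorFlow environment set up, the natural next step is to run some code. The MNIST dataset used in Google's TensorFlow tutorials is small, which makes it convenient for running and writing simple demos. This post walks through downloading the data and working with it.
Download
1. Import the input module (part of the TensorFlow 1.x tutorials package): from tensorflow.examples.tutorials.mnist import input_data
2. Call input_data.read_data_sets("data/", one_hot=True)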
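Before or after the call in step 2, it can be handy to see which archives are already in the download directory. The sketch below is only an illustration: `missing_archives` is a hypothetical helper, not part of `input_data`, and the four file names are the ones that appear in the extraction log printed by `read_data_sets` later in this post.

```python
from pathlib import Path

# Archive names as they appear in the extraction log printed by read_data_sets.
MNIST_ARCHIVES = [
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
]


def missing_archives(data_dir):
    """Hypothetical helper: list the MNIST archives not yet present in data_dir."""
    root = Path(data_dir)
    return [name for name in MNIST_ARCHIVES if not (root / name).is_file()]


# An empty list means all four archives are already in place.
print(missing_archives("data/"))
```

The helper only checks that the files exist; it does not verify that they are complete or uncorrupted.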
Usage
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
print("Download and Extract MNIST dataset")
# The first argument is the download directory; one_hot=True returns each label as a one-hot vector
minist = input_data.read_data_sets("data/", one_hot= True)
print("type of minist is %s"%(type(minist)))
print("number of train data is %d"%(minist.train.num_examples))
print("number of test data is %d"%(minist.test.num_examples))
"""
Output:
Download and Extract MNIST dataset
Extracting data/train-images-idx3-ubyte.gz
Extracting data/train-labels-idx1-ubyte.gz
Extracting data/t10k-images-idx3-ubyte.gz
Extracting data/t10k-labels-idx1-ubyte.gz
type of minist is <class 'tensorflow.contrib.learn.python.learn.datasets.base.Datasets'>
number of train data is 55000
number of test data is 10000
"""
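Working with the data
input_data already wraps a number of operations on the data. First, here is what it contains:
1. Training data: 55,000 samples, each with 784 values (the pixels of a 28 × 28 image)
2. Training labels: 10 classes
3. Test data: 10,000 samples
4. Test labels: 10 classes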
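Because the data is loaded with `one_hot=True`, each of the 10 classes is represented as a length-10 vector with a single 1 at the index of the digit. A minimal NumPy sketch of that encoding follows; `to_one_hot` is a hypothetical helper written for illustration, not a function provided by `input_data`.

```python
import numpy as np


def to_one_hot(labels, num_classes=10):
    """Hypothetical helper: turn integer class labels into one-hot rows."""
    labels = np.asarray(labels, dtype=np.int64)
    # Row k of the identity matrix is the one-hot vector for class k.
    return np.eye(num_classes, dtype=np.float64)[labels]


encoded = to_one_hot([7, 3, 0])
print(encoded.shape)  # (3, 10)
print(encoded[0])     # [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
```

Going the other way, `np.argmax` on a one-hot row recovers the integer class.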
# What does the MNIST data look like?
# print("What does the MNIST data look like?")
# The dataset holds training images and labels plus test images and labels
trainimg = minist.train.images
trainlabel = minist.train.labels
testimg = minist.test.images
testlabel = minist.test.labels
print("type of trainimg is %s" % (type(trainimg)))
print("type of trainlabel is %s" % (type(trainlabel)))
print("type of testimg is %s" % (type(testimg)))
print("type of testlabel is %s" % (type(testlabel)))
print(trainimg.shape)
# Without the trailing comma after shape, % unpacks the tuple into two arguments and raises:
# TypeError: not all arguments converted during string formatting
# print("shape of trainimg is %s" % (trainimg.shape))
print("shape of trainimg is %s" % (trainimg.shape,))
print("shape of trainlabel is %s" % (trainlabel.shape,))
print("shape of testimg is %s" % (testimg.shape,))
print("shape of testlabel is %s" % (testlabel.shape,))
# Print one label; the position of the 1 gives the class
print(trainlabel[0, :])
"""
Output:
type of trainimg is <class 'numpy.ndarray'>
type of trainlabel is <class 'numpy.ndarray'>
type of testimg is <class 'numpy.ndarray'>
type of testlabel is <class 'numpy.ndarray'>
(55000, 784)
shape of trainimg is (55000, 784) # 55,000 training samples; 784 is the 28 * 28 pixel values
shape of trainlabel is (55000, 10) # 10 classes
shape of testimg is (10000, 784)
shape of testlabel is (10000, 10)
[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
"""
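Now look at a few of the training images. The shapes show that each image is flattened into an array of 784 values, so it has to be reshaped to 28 × 28 before it can be displayed: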
# Number of images to show
num_showimage = 5
# Pick 5 random indices out of shape[0] = 55000
randidx = np.random.randint(trainimg.shape[0], size=num_showimage)
for i in randidx:
    # Restore the flattened 784-value row to a 28 x 28 image
    curr_img = np.reshape(trainimg[i, :], (28, 28))
    # The index of the 1 in the one-hot label is the digit class
    curr_label = np.argmax(trainlabel[i, :])
    plt.matshow(curr_img, cmap=plt.get_cmap('gray'))
    plt.title("Training sample " + str(i) + ", label is " + str(curr_label))
    print("Training sample " + str(i) + ", label is " + str(curr_label))
    plt.show()
Each digit is displayed as a 28 × 28 image. Only one of them is shown here: [image of a sample digit]
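The reshape and `argmax` steps used for display can be checked without downloading anything. The sketch below uses a synthetic flattened image and a hand-written one-hot label, both made up for illustration rather than taken from MNIST, and relies on the row-major order that `np.reshape` uses by default.

```python
import numpy as np

side = 28
# Stand-in for one row of trainimg: 784 distinct values so positions can be traced.
flat_img = np.arange(side * side, dtype=np.float32)
# Stand-in for one row of trainlabel: a one-hot vector for class 7.
one_hot_label = np.zeros(10)
one_hot_label[7] = 1.0

img = np.reshape(flat_img, (side, side))
digit = int(np.argmax(one_hot_label))

# With row-major order, pixel (row, col) comes from flat index row * 28 + col.
row, col = 3, 5
print(img.shape)                                    # (28, 28)
print(img[row, col] == flat_img[row * side + col])  # True
print(digit)                                        # 7
```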
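input_data wraps the MNIST data, so accessing it and setting a few parameters is straightforward. For example, to fetch a mini-batch of 100 samples: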
batch_size = 100
batch_xs, batch_ys = minist.train.next_batch(batch_size)
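`next_batch` returns the next `batch_size` images together with their labels. The sketch below imitates that behaviour on synthetic arrays with a hypothetical `iterate_batches` generator. It is a simplified illustration of mini-batching under the assumption of one shuffled pass per epoch, not the implementation used by `input_data`.

```python
import numpy as np


def iterate_batches(images, labels, batch_size, seed=0):
    """Hypothetical helper: yield shuffled (images, labels) mini-batches for one epoch."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(images.shape[0])
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        # The same indices select images and labels, so pairs stay aligned.
        yield images[idx], labels[idx]


# Synthetic stand-ins with the same column layout as the MNIST arrays.
fake_images = np.zeros((250, 784), dtype=np.float32)
fake_labels = np.eye(10)[np.arange(250) % 10]

batches = list(iterate_batches(fake_images, fake_labels, batch_size=100))
batch_xs, batch_ys = batches[0]
print(batch_xs.shape, batch_ys.shape)   # (100, 784) (100, 10)
print([len(xs) for xs, _ in batches])   # [100, 100, 50]
```

With the real data and `one_hot=True`, a batch of 100 has the same shapes: (100, 784) for the images and (100, 10) for the labels.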