TensorFlow MNIST: Downloading and Exploring the Data

程序员文章站 (Programmer Article Site), 2024-03-14 21:52:23
...

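Once the TensorFlow environment is set up, the first thing you want to do is run some programs. MNIST, the handwritten-digit dataset used in Google's TensorFlow tutorials, is small, which makes it convenient for running and writing simple demos. Below I explain how to download the data and work with it.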

Download

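1. Import the input_data module: from tensorflow.examples.tutorials.mnist import input_data
2. Call input_data.read_data_sets("data/", one_hot=True), which downloads the files into data/ (if they are not already there) and loads them.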

# Requires TensorFlow 1.x: the tensorflow.examples.tutorials package was removed in TensorFlow 2.x
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt

print("Download and Extract MNIST dataset")
# First argument: the directory to download into; one_hot=True returns each label as a one-hot vector
minist = input_data.read_data_sets("data/", one_hot=True)
print("type of minist is %s" % (type(minist)))
print("number of train data is %d" % (minist.train.num_examples))
print("number of test data is %d" % (minist.test.num_examples))

"""
Output:
Download and Extract MNIST dataset
Extracting data/train-images-idx3-ubyte.gz
Extracting data/train-labels-idx1-ubyte.gz
Extracting data/t10k-images-idx3-ubyte.gz
Extracting data/t10k-labels-idx1-ubyte.gz
type of minist is <class 'tensorflow.contrib.learn.python.learn.datasets.base.Datasets'>
number of train data is 55000
number of test data is 10000
"""

Working with the data

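input_data already wraps the common data operations. First, let's see what data the returned object contains:
1. Training images: 55,000 samples, each a vector of 784 = 28 × 28 pixel values
2. Training labels: 10 classes
3. Test images: 10,000 samples
4. Test labels: 10 classes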
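MNIST itself ships 60,000 training images. read_data_sets holds the first 5,000 of them out as minist.validation (its validation_size parameter defaults to 5000), which is why minist.train reports 55,000. The sketch below reproduces that split on zero-filled placeholder arrays, assuming simple slicing mirrors the library's behavior; split_validation is a hypothetical name:

```python
import numpy as np

IMAGE_PIXELS = 28 * 28  # 784 values per flattened image
NUM_CLASSES = 10

def split_validation(images, labels, validation_size=5000):
    """Hypothetical helper: hold out the first `validation_size` samples."""
    train = (images[validation_size:], labels[validation_size:])
    validation = (images[:validation_size], labels[:validation_size])
    return train, validation

# Zero-filled stand-ins with the shapes of the full MNIST training set
images = np.zeros((60000, IMAGE_PIXELS), dtype=np.float32)
labels = np.zeros((60000, NUM_CLASSES), dtype=np.float32)

(train_x, train_y), (val_x, val_y) = split_validation(images, labels)
print(train_x.shape, train_y.shape)  # (55000, 784) (55000, 10)
print(val_x.shape, val_y.shape)      # (5000, 784) (5000, 10)
```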

# What does the MNIST data look like?

# Training images and labels, test images and labels
trainimg = minist.train.images
trainlabel = minist.train.labels
testimg = minist.test.images
testlabel = minist.test.labels

print("type of trainimg is %s" % (type(trainimg)))
print("type of trainlabel is %s" % (type(trainlabel)))
print("type of testimg is %s" % (type(testimg)))
print("type of testlabel is %s" % (type(testlabel)))

print(trainimg.shape)
# Without the trailing comma after .shape, the next line raises
# TypeError: not all arguments converted during string formatting,
# because % unpacks the 2-tuple shape into two arguments for a single %s
# print("shape of trainimg is %s" % (trainimg.shape))
print("shape of trainimg is %s" % (trainimg.shape,))
print("shape of trainlabel is %s" % (trainlabel.shape,))
print("shape of testimg is %s" % (testimg.shape,))
print("shape of testlabel is %s" % (testlabel.shape,))

# Print one label: the position of the 1 is the class (digit)
print(trainlabel[0, :])
"""
Output:
type of trainimg is <class 'numpy.ndarray'>
type of trainlabel is <class 'numpy.ndarray'>
type of testimg is <class 'numpy.ndarray'>
type of testlabel is <class 'numpy.ndarray'>
(55000, 784)
shape of trainimg is (55000, 784)  # 55000 training samples; 784 = 28 * 28 pixel values
shape of trainlabel is (55000, 10) # 10 classes
shape of testimg is (10000, 784)
shape of testlabel is (10000, 10)
[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]    # class 7
"""

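Next, let's look at some training images. The shapes show that each image is flattened into an array of 784 values, so it has to be reshaped back into a 28 × 28 image before it can be displayed: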

# Number of images to display
num_showimage = 5
# Pick 5 random indices in [0, trainimg.shape[0]), i.e. among the 55000 training samples
randidx = np.random.randint(trainimg.shape[0], size=num_showimage)

for i in randidx:
    # Restore the flat 784-value vector to a 28 x 28 image
    curr_img = np.reshape(trainimg[i, :], (28, 28))
    # The class is the index of the 1 in the one-hot label
    curr_label = np.argmax(trainlabel[i, :])
    plt.matshow(curr_img, cmap=plt.get_cmap('gray'))
    plt.title("Training sample %d, label is %d" % (i, curr_label))
    print("Training sample %d, label is %d" % (i, curr_label))
    plt.show()

Each digit is displayed as a 28 × 28 grayscale image; only one of them is shown here:
[Figure: one of the displayed MNIST training digits]
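The reshape in the display loop only changes how the same 784 values are indexed. With NumPy's default row-major (C) order, pixel (row, col) of the 28 × 28 image is element row * 28 + col of the flat vector. The sketch below checks this on a synthetic image rather than real MNIST data, and decodes a one-hot label with np.argmax the same way the loop does:

```python
import numpy as np

# Synthetic 28 x 28 "image" whose pixel values are simply their flat indices
flat = np.arange(28 * 28, dtype=np.float32)  # shape (784,), like trainimg[i, :]
img = np.reshape(flat, (28, 28))             # row-major (C order) by default

row, col = 3, 5
# The same value is reachable through both indexings
assert img[row, col] == flat[row * 28 + col]

# Flattening the image again restores the original vector
assert np.array_equal(img.reshape(-1), flat)

# A one-hot label decodes to its class index with argmax
label = np.array([0, 0, 0, 0, 0, 0, 0, 1, 0, 0], dtype=np.float32)
print(int(np.argmax(label)))  # 7
```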

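input_data wraps MNIST so that the data can be accessed, and some parameters set, conveniently. For example, next_batch fetches the next mini-batch of training data: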

batch_size = 100
# batch_xs has shape (100, 784) and batch_ys has shape (100, 10)
batch_xs, batch_ys = minist.train.next_batch(batch_size)
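In TensorFlow 1.x, next_batch returns the next batch_size images and labels, reshuffling the training set at the start of each epoch and beginning a new pass once the data runs out. To show the idea, here is a simplified, hypothetical SimpleBatcher in plain NumPy. It is an illustration under these assumptions, not the library's code: unlike the real implementation, it drops the leftover samples at the end of an epoch instead of combining them with the start of the next one.

```python
import numpy as np

class SimpleBatcher:
    """Hypothetical, simplified stand-in for the DataSet.next_batch idea."""

    def __init__(self, images, labels, seed=0):
        assert len(images) == len(labels)
        self.images, self.labels = images, labels
        self.rng = np.random.default_rng(seed)
        self.epochs_completed = 0
        self._start_epoch()

    def _start_epoch(self):
        # Shuffle the sample order at the start of every epoch
        self.order = self.rng.permutation(len(self.images))
        self.pos = 0

    def next_batch(self, batch_size):
        if batch_size > len(self.images):
            raise ValueError("batch_size is larger than the dataset")
        if self.pos + batch_size > len(self.images):
            # Not enough samples left: finish this epoch and reshuffle
            self.epochs_completed += 1
            self._start_epoch()
        idx = self.order[self.pos:self.pos + batch_size]
        self.pos += batch_size
        return self.images[idx], self.labels[idx]

# Tiny fake dataset: 250 flattened images with one-hot labels
images = np.random.rand(250, 784).astype(np.float32)
labels = np.eye(10, dtype=np.float32)[np.arange(250) % 10]

batcher = SimpleBatcher(images, labels)
for _ in range(3):
    batch_xs, batch_ys = batcher.next_batch(100)
    # Shapes are (100, 784) and (100, 10); the third call starts epoch 2
    print(batch_xs.shape, batch_ys.shape, batcher.epochs_completed)
```

Shuffling once per epoch means every sample is seen at most once per pass, while the batch order still differs from epoch to epoch.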