TensorFlow MNIST: downloading and exploring the data

1. Import the input module: from tensorflow.examples.tutorials.mnist import input_data
2. Load the data with input_data.read_data_sets("data/", one_hot=True)

# Note: tensorflow.examples.tutorials.mnist only ships with TensorFlow 1.x
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt

print("Download and Extract MNIST dataset")
# The first argument is the download/cache directory; one_hot=True returns
# each label as a 10-element one-hot vector instead of an integer
minist = input_data.read_data_sets("data/", one_hot=True)
print("type of minist is %s" % (type(minist)))
print("number of train data is %d" % (minist.train.num_examples))
print("number of test data is %d" % (minist.test.num_examples))

Download and Extract MNIST dataset
Extracting data/train-images-idx3-ubyte.gz
Extracting data/train-labels-idx1-ubyte.gz
Extracting data/t10k-images-idx3-ubyte.gz
Extracting data/t10k-labels-idx1-ubyte.gz
type of minist is <class 'tensorflow.contrib.learn.python.learn.datasets.base.Datasets'>
number of train data is 55000
number of test data is 10000
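With one_hot=True every label is a length-10 vector with a single 1 at the index of its digit; with the default one_hot=False the label is just the integer. The NumPy sketch below illustrates that encoding and how np.argmax reverses it. The helper names to_one_hot and from_one_hot are hypothetical, and this is not the code input_data uses internally.

```python
import numpy as np

def to_one_hot(labels, num_classes=10):
    """Convert integer labels (shape [N]) to one-hot rows (shape [N, num_classes])."""
    labels = np.asarray(labels, dtype=np.int64)
    # Row k of the identity matrix has a single 1 at column k
    return np.eye(num_classes, dtype=np.float32)[labels]

def from_one_hot(one_hot):
    """Recover integer labels: the column holding the 1 is the class."""
    return np.argmax(one_hot, axis=1)

labels = np.array([7, 3, 0])
encoded = to_one_hot(labels)
print(encoded[0])              # [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
print(from_one_hot(encoded))   # [7 3 0]
```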


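input_data already wraps the common data operations. First, let's see what data it contains:
1. Training images: 55000 samples, each a row of 784 = 28 * 28 pixel values
2. Training labels: 10 classes
3. Test images: 10000 samples
4. Test labels: 10 classes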
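The 55000 count is not the full MNIST training set: the original has 60000 training images, and read_data_sets holds out the first validation_size=5000 of them (its default) as minist.validation. With its default reshape=True, each 28 × 28 image is also flattened into 784 values. The sketch below reproduces that layout on small synthetic arrays; split_and_flatten is a hypothetical helper, not part of input_data.

```python
import numpy as np

def split_and_flatten(images, labels, validation_size):
    """Hypothetical helper: hold out the first `validation_size` samples
    and flatten each HxW image into a 1-D row of H*W values."""
    n = images.shape[0]
    flat = images.reshape(n, -1)  # (n, 28, 28) -> (n, 784)
    val = (flat[:validation_size], labels[:validation_size])
    train = (flat[validation_size:], labels[validation_size:])
    return train, val

# Small synthetic stand-in for the 60000 original training images
images = np.zeros((60, 28, 28), dtype=np.float32)
labels = np.zeros((60, 10), dtype=np.float32)
(train_x, train_y), (val_x, val_y) = split_and_flatten(images, labels, validation_size=5)
print(train_x.shape, val_x.shape)  # (55, 784) (5, 784)
print(28 * 28, 60000 - 5000)       # 784 55000
```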

# What does the MNIST data look like?
# print("What does the MNIST data look like?")

# Training images, training labels, test images, test labels
trainimg = minist.train.images
trainlabel = minist.train.labels
testimg = minist.test.images
testlable = minist.test.labels

print("type of trainimg is %s" % (type(trainimg)))
print("type of trainlabel is %s" % (type(trainlabel)))
print("type of testimg is %s" % (type(testimg)))
print("type of testlable is %s" % (type(testlable)))

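# Without the trailing comma after shape this raises TypeError: not all arguments converted during string formatting,
# because shape is a tuple and % treats a tuple as its whole argument list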
# print("shape of trainimg is %s" % (trainimg.shape))
print("shape of trainimg is %s" % (trainimg.shape,))
print("shape of trainlabel is %s" % (trainlabel.shape,))
print("shape of testimg is %s" % (testimg.shape,))
print("shape of testlable is %s" % (testlable.shape,))

# Print one label: the position of the 1 is the class
print(trainlabel[0, :])
type of trainimg is <class 'numpy.ndarray'>
type of testimg is <class 'numpy.ndarray'>
type of testlable is <class 'numpy.ndarray'>
(55000, 784)
shape of trainimg is (55000, 784)  # 55000 training samples, each 784 = 28 * 28 pixel values
shape of trainlabel is (55000, 10) # 10 classes
shape of testimg is (10000, 784)
shape of testlable is (10000, 10)
[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
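The trailing-comma fix above works because (trainimg.shape) is only trainimg.shape in parentheses, not a tuple, so % receives the 2-tuple (55000, 784) as two arguments for a single %s. A minimal standalone reproduction, with str.format and f-strings shown as alternatives that avoid the ambiguity:

```python
shape = (55000, 784)

# "%" treats a tuple on its right-hand side as the full argument list,
# so a 2-tuple supplies two values for one %s placeholder
try:
    bad = "shape is %s" % shape
except TypeError as err:
    bad = "TypeError: %s" % err
print(bad)  # TypeError: not all arguments converted during string formatting

# Wrapping the shape in a 1-tuple passes it as a single argument
good = "shape is %s" % (shape,)
print(good)                        # shape is (55000, 784)
print("shape is {}".format(shape))  # shape is (55000, 784)
print(f"shape is {shape}")          # shape is (55000, 784)
```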

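Now let's look at a few images. The shapes show that each image has been flattened into an array of size 784, so it must be reshaped back to 28 × 28 before it can be displayed: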

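# Number of images to display
num_showimage = 5
# Pick 5 random indices out of shape[0] = 55000 (randint may repeat an index)
randidx = np.random.randint(trainimg.shape[0], size=num_showimage)

for i in randidx:
    # Each row is a flattened 784-vector; reshape it back to 28 x 28
    curr_img = np.reshape(trainimg[i, :], (28, 28))
    # Decode the one-hot label: the index of the 1 is the digit
    curr_label = np.argmax(trainlabel[i, :])
    plt.matshow(curr_img, cmap=plt.get_cmap('gray'))
    plt.title("Training sample %d, label is %d" % (i, curr_label))
    print("Training sample %d, label is %d" % (i, curr_label))
plt.show()  # needed to open the figures when running as a script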
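np.random.randint draws with replacement, so the same image can be picked twice. The sketch below assumes NumPy 1.17 or later and uses Generator.choice with replace=False to get distinct indices, then performs the same reshape and argmax steps on synthetic data without plotting. sample_images is a hypothetical helper; each returned img could be passed to plt.matshow as in the loop above.

```python
import numpy as np

def sample_images(images, labels, num, seed=None):
    """Pick `num` distinct rows, reshape each 784-vector to 28x28 and
    decode its one-hot label with argmax."""
    rng = np.random.default_rng(seed)
    idx = rng.choice(images.shape[0], size=num, replace=False)  # no repeats
    out = []
    for i in idx:
        img = images[i].reshape(28, 28)
        label = int(np.argmax(labels[i]))
        out.append((int(i), img, label))
    return out

# Synthetic stand-in: 20 flattened "images" whose label is index % 10
fake_images = np.random.default_rng(0).random((20, 784))
fake_labels = np.eye(10, dtype=np.float32)[np.arange(20) % 10]
for i, img, label in sample_images(fake_images, fake_labels, num=5, seed=42):
    print(i, img.shape, label)
```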

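Each digit is displayed as a 28 × 28 grayscale image (only one is shown here).

Finally, next_batch returns the next mini-batch of training images and labels:
batch_size = 100
# With one_hot=True, batch_xs has shape (100, 784) and batch_ys has shape (100, 10)
batch_xs, batch_ys = minist.train.next_batch(batch_size)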
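To show roughly what next_batch does, here is a simplified, hypothetical analogue (MiniBatcher is not a TensorFlow class): it shuffles the data once per epoch and hands out consecutive slices. One assumed difference: this sketch drops the leftover samples at the end of an epoch, whereas TensorFlow 1.x's DataSet.next_batch joins them with the start of the next epoch.

```python
import numpy as np

class MiniBatcher:
    """Hypothetical, simplified analogue of DataSet.next_batch:
    shuffles once per epoch and returns consecutive slices."""

    def __init__(self, images, labels, seed=None):
        self.images, self.labels = images, labels
        self.n = images.shape[0]
        self.rng = np.random.default_rng(seed)
        self._perm = self.rng.permutation(self.n)
        self._pos = 0

    def next_batch(self, batch_size):
        if self._pos + batch_size > self.n:  # epoch exhausted: reshuffle
            self._perm = self.rng.permutation(self.n)
            self._pos = 0
        idx = self._perm[self._pos:self._pos + batch_size]
        self._pos += batch_size
        return self.images[idx], self.labels[idx]

# Synthetic data: row k has label k, so batches are easy to check
images = np.arange(10 * 784, dtype=np.float32).reshape(10, 784)
labels = np.eye(10, dtype=np.float32)
batcher = MiniBatcher(images, labels, seed=0)
xs, ys = batcher.next_batch(4)
print(xs.shape, ys.shape)  # (4, 784) (4, 10)
```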
