欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

(三)Tensorflow学习——mnist数据集简介

程序员文章站 2024-03-07 21:45:21
...

导入相关包

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

# tensorflow自带的一些数据集
from tensorflow.examples.tutorials.mnist import input_data

加载数据集

在该目录下,建立一个空文件夹data,加载mnist数据集时,会自动从网上下载

print('Download and Extract MNIST dataset')
mnist = input_data.read_data_sets('data/', one_hot=True)
print('type of "mnist" is %s' % (type(mnist)))
print('number of train data is %d' % (mnist.train.num_examples))
print('number of test data is %d' % (mnist.test.num_examples))

(三)Tensorflow学习——mnist数据集简介
(三)Tensorflow学习——mnist数据集简介

mnist数据集的描述信息

trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels
print('type of "trainimg" is %s' % (type(trainimg)))
print('type of "trainlabel" is %s' % (type(trainlabel)))
print('type of "testimg" is %s' % (type(testimg)))
print('type of "testlabel" is %s' % (type(testlabel)))
print('shape of "trainimg" is %s' % (trainimg.shape,))
print('shape of "trainlabel" is %s' % (trainlabel.shape,))
print('shape of "testimg" is %s' % (testimg.shape,))
print('shape of "testlabel" is %s' % (testlabel.shape,))

输出结果:
(三)Tensorflow学习——mnist数据集简介

打印原数据集

nsample = 5
randidx = np.random.randint(trainimg.shape[0], size=nsample)

for i in randidx:
    curr_img = np.reshape(trainimg[i, :], (28, 28))
    curr_label = np.argmax(trainlabel[i, :])
    plt.matshow(curr_img, cmap=plt.get_cmap('gray'))
    plt.title('' + str(i) + 'th Training Data'
              + 'Label is ' + str(curr_label))
    print('' + str(i) + 'th Training Data'
              + 'Label is ' + str(curr_label))

这里只展示一张图片:
(三)Tensorflow学习——mnist数据集简介