【TensorFlow】MNIST手写数字识别
程序员文章站
2024-01-22 11:55:10
...
MNIST
MNIST是一个非常简单的机器视觉数据集。如图,它由几万字28像素×28像素的手写数字组成,这些图片只包含灰度值信息。我们的任务是对这些手写数字的图片进行分类,转成0~9一共10类。
首先对MNIST数据进行加载,然后查看mnist这个数据集的情况。
# 输入程序
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
print(mnist.train.images.shape,mnist.train.labels.shape)
print(mnist.test.images.shape,mnist.test.labels.shape)
print(mnist.validation.images.shape,mnist.validation.labels.shape)
# 运行结果
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
(55000, 784) (55000, 10)
(10000, 784) (10000, 10)
(5000, 784) (5000, 10)
one_hot 是指一个在大多数维度上为0的向量,在一维中是1。在这种情况下,第n个数字将被表示为一个在第n维中1的向量。如数字5对应[0,0,0,0,0,1,0,0,0,0]。
可以看到,训练集有55000个样本,测试集有10000个样本,同时验证集有5000个样本。每个样本都有对应的标签label。
实现原理
训练数据的特征是一个55000×784的Tensor,第一个维度是图片的编号,第二个维度是图片中像素点的编号,同时训练的数据Label是一个55000×10的Tensor,使用one-hot编码。
Softmax Regression:Softmax Regression是Logistic回归的推广,处理多分类问题。
损失函数(交叉熵)
其中,y是预测的概率分布,y’是真实的概率分布,通常用来判断模型对真实概率分布估计的准确程度。
程序
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
#下载并加载数据
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
#数据与标签的占位
x = tf.placeholder(tf.float32,shape = [None,784])
y_actual = tf.placeholder(tf.float32,shape=[None,10])
#初始化权重和偏置
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
#softmax回归,得到预测概率
y_predict = tf.nn.softmax(tf.matmul(x,W) + b)
#求交叉熵得到残差
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_actual*tf.log(y_predict),reduction_indices=1))
#梯度下降法使得残差最小,学习速率为0.01
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
#测试阶段,测试准确度计算
correct_prediction = tf.equal(tf.argmax(y_predict,1),tf.argmax(y_actual,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,'float'))#多个批次的准确度均值
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
#训练,迭代1000次
for i in range(1000):
batch_xs,batch_ys = mnist.train.next_batch(100)#按批次训练,每批100行数据
sess.run(train_step,feed_dict={x:batch_xs,y_actual:batch_ys})#执行训练
if(i%100==0):#每训练100次,测试一次
print("accuracy:",sess.run(accuracy,feed_dict={x: mnist.test.images, y_actual: mnist.test.labels}))
运行结果
上一篇: ASP常用的系统配置函数