深度学习(一)——简单神经网络识别手写数字
程序员文章站
2022-06-28 11:11:58
...
MNIST数据集相当于深度学习中的“Hello World”,用于开始做测试用的简单的视觉数据集,由几万张28*28的手写数字组成,只包含灰度信息,分为十类0~9。
1.主要步骤
1)选择softmax regression模型
2)定义算是函数,这里选择交叉熵
3)选择优化算法梯度下降法
4)迭代进行数据训练
5)进行验证和准确率评测
2,代码
#加载数据
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
#训练集5.5w,测试机1w,验证集0.5w,每个样本有一个label
mnist=input_data.read_data_sets("MNIST_data/",one_hot=True)
#label是一个10维的向量,[1,0,0,0,0,0,0,0,0,0],其代表数字为0
#placeholder数据输入的地方,第一个参数数据类型,第二个参数是数据尺寸大小
sess=tf.InteractiveSession()
x=tf.placeholder(tf.float32,[None,784])
#定义权重和截距b
W=tf.Variable(tf.zeros([784,10]))
b=tf.Variable(tf.zeros([10]))
#softmax regression算法公式
y=tf.nn.softmax(tf.matmul(x,W)+b)
#定义一个损失函数,多分类问题多用cross_entropy
y_=tf.placeholder(tf.float32,[None,10])
cross_entropy=tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1]))
#定义优化算法和训练速率,用梯度下降法和0.5的速率
train_step=tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
#定义全局参数初始化器
tf.global_variables_initializer().run()
#进行训练
for i in range(1000):
batch_xs,batch_ys=mnist.train.next_batch(100)
train_step.run({x:batch_xs,y_:batch_ys})
correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
print(accuracy.eval({x:mnist.test.images,y_:mnist.test.labels}))
准确率为:0.9156
参考:
《Tensorflow实战》
推荐阅读
-
深度学习 从零开始 —— 神经网络数学基础(一),学习Keras库的使用,神经网络简单流程,MNIST数据集使用
-
第一个机器学习算法:K-近邻算法实现手写数字识别系统
-
Tensorflow学习:循环(递归/记忆)神经网络RNN(手写数字识别:MNIST数据集分类)
-
pytorch 深度学习入门代码 (四)多层全连接神经网络实现 MNIST 手写数字分类
-
深度学习(一)——简单神经网络识别手写数字
-
python深度学习第三讲——用python写神经网络梯度下降(手写字符识别mnist)
-
利用BP神经网络 设计一个三层神经网络解决手写数字的识别问题
-
Python——numpy实现简单BP神经网络识别手写数字
-
【Tensorflow与深度学习笔记day07】5.2. ANN网络分析+Mnist手写数字识别+one-hot编码+SoftMax回归+损失计算-交叉熵损失+实现神经网络模型+模型正确率评估
-
简单python代码实现三层神经网络识别手写数字