欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

机器学习——TensorFlow之数字体识别流程

程序员文章站 2022-05-21 22:20:12
...
import tensorflow as tf
# 导入mnist数据集
# 分析mnist样本特点以及定义变量
# 构建模型
# 训练模型并输出中间状态参数
# 测试模型
# 保存模型
# 读取模型


# 导入mnist数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets('MNIST_data/',one_hot=True)

# 分析图片的特点,定义变量
x=tf.placeholder(tf.float32,shape=[None,784])
y=tf.placeholder(tf.float32,shape=[None,10])

# 构建模型
W=tf.Variable(tf.zeros([784,10]))

b=tf.Variable(tf.zeros([10]))

# z表示证据
z=tf.matmul(x,W)+b
# pred表示是每个数字的可能
pred=tf.nn.softmax(z)
# 损失函数,交叉熵,定义反向传播的结构
loss=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))

learn_rate=0.01

# 优化器,梯度下降法
optimizer=tf.train.GradientDescentOptimizer(learn_rate).minimize(loss)

# 训练次数
epochs=25

# 批次大小
batch_size=100

# 把中间具体信息显示出来
display_step=1

with tf.Session() as sess:
    # 初始化全局变量
    sess.run(tf.global_variables_initializer())
    # 开始训练
    for epoch in range(epochs):
        # 取值大小
        avg_loss=0
        total_loss=0
        total_batch=int(mnist.train.images.shape[0]/batch_size)
        for i in range(total_batch):
            # 从数据集中按照batch_size大小取值
            batch_xs,batch_ys=mnist.train.next_batch(batch_size)
            # 运行优化器
            _,c=sess.run([optimizer,loss],feed_dict={x:batch_xs,y:batch_ys})
            # 计算损失值得平均值
            total_loss+=c
        avg_loss=total_loss/total_batch
        if((epoch+1)%display_step==0):
            print('Epoch:','%04d'%(epoch+1),'cost=','{:.9f}'.format(avg_loss))
    print('########################Finished!#############################\n')
    # 测试模型
    print('########################Begin Test############################\n')
    correct_predict=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
    accuracy=tf.reduce_mean(tf.cast(correct_predict,tf.float32))
    print('Accuracy:',sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))
    print('########################Save Model############################\n')
    saver= tf.train.Saver()
    save_path='log/'
    saver.save(sess,save_path)
    print('saved Successfully at :',save_path)
# 保存模型


机器学习——TensorFlow之数字体识别流程
机器学习——TensorFlow之数字体识别流程

相关标签: Machine Learning