BP神经网络基于TensorFlow的mnist数据集分类
程序员文章站
2022-07-13 11:28:01
...
# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("mnist/", one_hot=True)
# 定义回归模型
X = tf.placeholder(tf.float32, [None, 784]) # 输入的X的值
Y = tf.placeholder(tf.float32, [None, 10]) # 输出真实值
def create_model():
w1 = tf.Variable(tf.random_uniform([784, 1024], -1, 1))
b1 = tf.Variable(tf.random_uniform([1024], -1, 1))
y1 = tf.sigmoid(tf.matmul(X, w1) + b1)
y1 = tf.nn.dropout(y1, keep_prob=1)
w2 = tf.Variable(tf.truncated_normal([1024, 512]))
b2 = tf.Variable(tf.truncated_normal([512]))
y2 = tf.sigmoid(tf.matmul(y1, w2) + b2)
y2 = tf.nn.dropout(y2, keep_prob=1)
w3 = tf.Variable(tf.random_uniform([512, 128], -1, 1))
b3 = tf.Variable(tf.random_uniform([128], -1, 1))
y3 = tf.sigmoid(tf.matmul(y2, w3) + b3)
y3 = tf.nn.dropout(y3, keep_prob=1)
w4 = tf.Variable(tf.truncated_normal([128, 10]))
b4 = tf.Variable(tf.truncated_normal([10]))
outputs = tf.matmul(y3, w4) + b4
# softmax loss
cross_entropy_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=outputs, labels=Y))
#train_op = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy_loss)
train_op = tf.train.AdamOptimizer(0.001).minimize(cross_entropy_loss)
#预测、准确率
pred = tf.equal(tf.argmax(outputs, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(pred, tf.float32))
return train_op, cross_entropy_loss, accuracy
with tf.Session() as sess:
train_op1, cross_entropy_loss1, accuracy1 = create_model()
sess.run(tf.global_variables_initializer())
for i in range(10000000):
xs, ys = mnist.train.next_batch(100)
_, loss_, acc_ = sess.run([train_op1, cross_entropy_loss1, accuracy1], feed_dict={X: xs, Y: ys})
if i % 100 != 0:
continue
# 测试当前模型在训练数据、测试数据、验证数据中的准确率,
# 数据较少,验证所有的数据
print("step:%s, loss: %s, train_acc:%s, test_acc:%s, valid_acc:%s" % (i, loss_,
sess.run(accuracy1, feed_dict={X: mnist.train.images, Y: mnist.train.labels}),
sess.run(accuracy1, feed_dict={X: mnist.test.images, Y: mnist.test.labels}),
sess.run(accuracy1, feed_dict={X: mnist.validation.images, Y: mnist.validation.labels})))
上一篇: DNN训练过程中的一些问题以及技巧
下一篇: 斯坦福大学CS231课程笔记1
推荐阅读
-
详解tensorflow训练自己的数据集实现CNN图像分类
-
深度学习 从零开始 —— 神经网络数学基础(一),学习Keras库的使用,神经网络简单流程,MNIST数据集使用
-
BP神经网络基于TensorFlow的mnist数据集分类
-
TensorFlow系列(4)——基于MNIST数据集的CNN实现
-
tensorflow 优化 MNIST数据集分类 Jupyter
-
基于神经网络的人脸识别tensorflow(数据的存储与加载)
-
基于python的BP神经网络算法对mnist数据集的识别--批量处理版
-
GCN实战深入浅出图神经网络第五章:基于Cora数据集的GCN节点分类 代码分析
-
TensorFlow2利用猫狗数据集(cats_and_dogs_filtered.zip)实现卷积神经网络完成分类任务
-
Keras : 利用卷积神经网络CNN对图像进行分类,以mnist数据集为例建立模型并预测