
BP Neural Network Classification of the MNIST Dataset with TensorFlow

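This post builds a plain fully-connected (back-propagation, BP) network for MNIST digit classification with TensorFlow 1.x: three sigmoid hidden layers of 1024, 512 and 128 units on top of the 784-pixel input, followed by a 10-way output layer trained with softmax cross-entropy and the Adam optimizer. The complete script is below.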
# -*- coding: utf-8 -*-

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("mnist/", one_hot=True)  # downloads MNIST to ./mnist/ on first run; labels are one-hot

# Define the model inputs
X = tf.placeholder(tf.float32, [None, 784])  # input images, flattened to 784 pixels
Y = tf.placeholder(tf.float32, [None, 10])   # ground-truth labels, one-hot encoded


def create_model():
    # Hidden layer 1: 784 -> 1024 units, sigmoid activation
    w1 = tf.Variable(tf.random_uniform([784, 1024], -1, 1))
    b1 = tf.Variable(tf.random_uniform([1024], -1, 1))
    y1 = tf.sigmoid(tf.matmul(X, w1) + b1)
    y1 = tf.nn.dropout(y1, keep_prob=1)  # keep_prob=1 disables dropout; lower it to add regularization

    # Hidden layer 2: 1024 -> 512 units, sigmoid activation
    w2 = tf.Variable(tf.truncated_normal([1024, 512]))
    b2 = tf.Variable(tf.truncated_normal([512]))
    y2 = tf.sigmoid(tf.matmul(y1, w2) + b2)
    y2 = tf.nn.dropout(y2, keep_prob=1)

    # Hidden layer 3: 512 -> 128 units, sigmoid activation
    w3 = tf.Variable(tf.random_uniform([512, 128], -1, 1))
    b3 = tf.Variable(tf.random_uniform([128], -1, 1))
    y3 = tf.sigmoid(tf.matmul(y2, w3) + b3)
    y3 = tf.nn.dropout(y3, keep_prob=1)

    # Output layer: 128 -> 10 logits (softmax is applied inside the loss below)
    w4 = tf.Variable(tf.truncated_normal([128, 10]))
    b4 = tf.Variable(tf.truncated_normal([10]))
    outputs = tf.matmul(y3, w4) + b4

    # Softmax cross-entropy loss and optimizer
    cross_entropy_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=outputs, labels=Y))
    # train_op = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy_loss)
    train_op = tf.train.AdamOptimizer(0.001).minimize(cross_entropy_loss)

    # Prediction correctness and accuracy
    pred = tf.equal(tf.argmax(outputs, 1), tf.argmax(Y, 1))
    accuracy = tf.reduce_mean(tf.cast(pred, tf.float32))

    return train_op, cross_entropy_loss, accuracy


with tf.Session() as sess:
    train_op1, cross_entropy_loss1, accuracy1 = create_model()

    sess.run(tf.global_variables_initializer())

    for i in range(10000000):  # effectively trains until interrupted; stop once accuracy is good enough
        xs, ys = mnist.train.next_batch(100)
        _, loss_, acc_ = sess.run([train_op1, cross_entropy_loss1, accuracy1], feed_dict={X: xs, Y: ys})
        if i % 100 != 0:
            continue
        # Evaluate the current model's accuracy on the training, test and validation sets;
        # the datasets are small, so all of the samples are evaluated each time
        print("step:%s, loss: %s, train_acc:%s, test_acc:%s, valid_acc:%s" % (i, loss_,
                sess.run(accuracy1, feed_dict={X: mnist.train.images, Y: mnist.train.labels}),
                sess.run(accuracy1, feed_dict={X: mnist.test.images, Y: mnist.test.labels}),
                sess.run(accuracy1, feed_dict={X: mnist.validation.images, Y: mnist.validation.labels})))
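Note that the script above targets the TensorFlow 1.x API (tf.placeholder, tf.Session and the tensorflow.examples.tutorials.mnist loader, which no longer ships with TensorFlow 2.x). If it has to run under TensorFlow 2.x, something along the lines of the following preamble would be needed. This is only a sketch under that assumption: the Keras-based loading and the manual one-hot encoding are not part of the original post, and mnist.train.next_batch would also have to be replaced with manual batching over these arrays.

# Rough sketch only, assuming TensorFlow 2.x; not part of the original script.
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow import keras  # used here only to download MNIST

tf.disable_v2_behavior()  # restore graph mode so the placeholder/Session code above keeps working

# tensorflow.examples.tutorials.mnist was removed in TF 2.x, so load the data via Keras instead
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0  # flatten 28x28 images, scale to [0, 1]
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0
y_train = np.eye(10, dtype="float32")[y_train]  # one-hot encode labels to match the Y placeholder
y_test = np.eye(10, dtype="float32")[y_test]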