欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

简洁明了的tensorflow2.0教程——用keras实现mnist数字识别

程序员文章站 2024-03-07 22:16:03
...

通过本文你可以快速学会使用keras搭建神经网络,只需40行代码构建神经网络实现mnist数据集手写数字的识别。MNIST数据集(Mixed National Institute of Standards and Technology database)是美国国家标准与技术研究院收集整理的大型手写数字数据库,包含60,000个示例的训练集以及10,000个示例的测试集.,图片大小为28*28。完整代码在我的github,链接:https://github.com/JohnLeek/Tensorflow-study,仓库中mnist.npz就是数据集,day3_mnist_reg.py和day3_mnist_train_ex4.py为完整代码,觉得不错的github给个star吧。

一、keras介绍

Keras 是一个用 Python 编写的高级神经网络 API,它能够以 TensorFlowCNTK, 或者 Theano 作为后端运行。Keras 的开发重点是支持快速的实验。能够以最小的时延把你的想法转换为实验结果,是做好研究的关键。

如果你在以下情况下需要深度学习库,请使用 Keras:

  1. 允许简单而快速的原型设计(由于用户友好,高度模块化,可扩展性)。
  2. 同时支持卷积神经网络和循环神经网络,以及两者的组合。
  3. 在 CPU 和 GPU 上无缝运行。

查看相关说明,请访问https://keras-zh.readthedocs.io/

Keras 兼容的 Python 版本: Python 2.7-3.6

目前tensorflow已经将keras最为标准的后端库,这一特点在tf2.0中尤为明显,以下我们要讲的keras默认为tensorflow中的keras模块。

二、用keras搭建神经网络技巧

我总结了下边的几个关键字:

D:Data,加载数据集

S:Sequential,搭建我们的神经网络

C:compile,在运行我们的模型前设置优化器,损失函数等

F:fit,设置神经网络,传入训练集测试集,指定训练次数

S:summary,运行神经网络

三、代码实现

1、首先我们加载数据集(D)

tensorflow提供了数据集,但是需要我们下载,有个问题就是下载速度太慢,这里我整理好了数据集,放到了我的github,下载好数据集之后,放在C盘,user/.keras/datasets,文件夹下即可,如图

简洁明了的tensorflow2.0教程——用keras实现mnist数字识别

然后我们开始加载数据集

mnist = tf.keras.datasets.mnist

(x_train,y_train),(x_test,y_test) = mnist.load_data()

然后调整数据集,进行归一化操作,加快神经网络收敛速度

x_train,x_test = x_train/255.,x_test/255.

2、搭建我们的神经网络(S)

model = tf.keras.models.Sequential([

    tf.keras.layers.Flatten(),

    tf.keras.layers.Dense(128,activation = "relu"),

    tf.keras.layers.Dense(10,activation = "softmax")

])

这里我们首先拉着了神经网络,利用到了Flatten函数,然后搭建了一个128个输入节点,**函数为relu的输入层,然后因为我们要实现0~9数字分类,我们输出层为10个神经元,采用softmax使输出符合概率分布。

3、设置神经网络参数(C)

model.compile(optimizer = "adam",

              loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = False),

              metrics = ["sparse_categorical_accuracy"])

这里我们指点了优化器(optimizer),损失函数(loss),神经网络准确率评估标准(metrics),要注意因为我们是执行分类任务并且采用了独热码,所以我们交叉熵损失函数。

4、设置神经网络数据集相关参数(F)

model.fit(x_train,y_train,batch_size = 32,epochs = 5,validation_data = (x_test,y_test),

                    validation_freq = 1)

这里我们指定训练集x_train,y_train,训练接一次喂入神经网络数据集大小为32,训练次数为5次,测试集为(x_test,y_test),每隔一轮验证准确率。

5、运行我们的模型(S)

model.summary()

6、结果展示

简洁明了的tensorflow2.0教程——用keras实现mnist数字识别

控制台打印出了我们神经网络模型,参数数量,准确率(98.47%)

四、进阶操作保存我们的模型,加入断点续训,保存模型参数

1、为了方便我们在不同的设备上训练我们已经训练好的模型,我们引入了断点续训的功能,帮组我们更好的优化神经网络。我们只需要在代码中添加这两句代码即可。

checkpoint_save_path = "./checkpoint/mnist.ckpt"

if os.path.exists(checkpoint_save_path+".index"):

    print("-----------------load Data---------------")

    model.load_weights(checkpoint_save_path)


cp_callback = tf.keras.callbacks.ModelCheckpoint(

    filepath = checkpoint_save_path,

    save_weights_only = True,

    save_best_only = True

)

这里我们指定了模型保存路径: "./checkpoint/mnist.ckpt",如果模型存在我们就在原有的基础上继续训练我们的模型,如果没有我们在训练模型的时候保存参数,调用ModelCheckpoint函数,指定我们要保存的参数,这里我保存了权重,和最优结果。

到这里还没结束我们需要对model.fit做一定更改,加一个回调函数,如下:

history = model.fit(x_train,y_train,batch_size = 32,epochs = 5,validation_data = (x_test,y_test),

                    validation_freq = 1,callbacks = [cp_callback])

好了断点续训保存模型就完成了,接下来我们保存下神经网络可训练参数:

file = open("./weights_variables.txt","w")

for v in model.trainable_variables:

    file.write(str(v.name)+"\n")

    file.write(str(v.shape)+"\n")

    file.write(str(v.numpy())+"\n")

file.close()

2、结果展示

这个就是我们保存好的模型。

简洁明了的tensorflow2.0教程——用keras实现mnist数字识别

这个就是我们神经网络所有可训练参数的值。

简洁明了的tensorflow2.0教程——用keras实现mnist数字识别

我们看一下第一次训练结果。

简洁明了的tensorflow2.0教程——用keras实现mnist数字识别

现在我们再运行下我们的代码。看看是不是在上一次训练的基础上继续训练。

简洁明了的tensorflow2.0教程——用keras实现mnist数字识别

可以看到加载了我们保存好的模型,准确率不断提高

简洁明了的tensorflow2.0教程——用keras实现mnist数字识别