简洁明了的tensorflow2.0教程——用keras实现mnist数字识别
通过本文你可以快速学会使用keras搭建神经网络,只需40行代码构建神经网络实现mnist数据集手写数字的识别。MNIST数据集(Mixed National Institute of Standards and Technology database)是美国国家标准与技术研究院收集整理的大型手写数字数据库,包含60,000个示例的训练集以及10,000个示例的测试集.,图片大小为28*28。完整代码在我的github,链接:https://github.com/JohnLeek/Tensorflow-study,仓库中mnist.npz就是数据集,day3_mnist_reg.py和day3_mnist_train_ex4.py为完整代码,觉得不错的github给个star吧。
Keras 是一个用 Python 编写的高级神经网络 API,它能够以 TensorFlow, CNTK, 或者 Theano 作为后端运行。Keras 的开发重点是支持快速的实验。能够以最小的时延把你的想法转换为实验结果,是做好研究的关键。
查看相关说明,请访问https://keras-zh.readthedocs.io/。
Keras 兼容的 Python 版本: Python 2.7-3.6。
目前tensorflow已经将keras最为标准的后端库,这一特点在tf2.0中尤为明显,以下我们要讲的keras默认为tensorflow中的keras模块。
C:compile,在运行我们的模型前设置优化器,损失函数等
tensorflow提供了数据集,但是需要我们下载,有个问题就是下载速度太慢,这里我整理好了数据集,放到了我的github,下载好数据集之后,放在C盘,user/.keras/datasets,文件夹下即可,如图
mnist = tf.keras.datasets.mnist
(x_train,y_train),(x_test,y_test) = mnist.load_data()
x_train,x_test = x_train/255.,x_test/255.
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128,activation = "relu"),
tf.keras.layers.Dense(10,activation = "softmax")
])
这里我们首先拉着了神经网络,利用到了Flatten函数,然后搭建了一个128个输入节点,**函数为relu的输入层,然后因为我们要实现0~9数字分类,我们输出层为10个神经元,采用softmax使输出符合概率分布。
model.compile(optimizer = "adam",
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = False),
metrics = ["sparse_categorical_accuracy"])
这里我们指点了优化器(optimizer),损失函数(loss),神经网络准确率评估标准(metrics),要注意因为我们是执行分类任务并且采用了独热码,所以我们交叉熵损失函数。
model.fit(x_train,y_train,batch_size = 32,epochs = 5,validation_data = (x_test,y_test),
validation_freq = 1)
这里我们指定训练集x_train,y_train,训练接一次喂入神经网络数据集大小为32,训练次数为5次,测试集为(x_test,y_test),每隔一轮验证准确率。
model.summary()
控制台打印出了我们神经网络模型,参数数量,准确率(98.47%)
1、为了方便我们在不同的设备上训练我们已经训练好的模型,我们引入了断点续训的功能,帮组我们更好的优化神经网络。我们只需要在代码中添加这两句代码即可。
checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path+".index"):
print("-----------------load Data---------------")
model.load_weights(checkpoint_save_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath = checkpoint_save_path,
save_weights_only = True,
save_best_only = True
)
这里我们指定了模型保存路径: "./checkpoint/mnist.ckpt",如果模型存在我们就在原有的基础上继续训练我们的模型,如果没有我们在训练模型的时候保存参数,调用ModelCheckpoint函数,指定我们要保存的参数,这里我保存了权重,和最优结果。
到这里还没结束我们需要对model.fit做一定更改,加一个回调函数,如下:
history = model.fit(x_train,y_train,batch_size = 32,epochs = 5,validation_data = (x_test,y_test),
validation_freq = 1,callbacks = [cp_callback])
好了断点续训保存模型就完成了,接下来我们保存下神经网络可训练参数:
file = open("./weights_variables.txt","w")
for v in model.trainable_variables:
file.write(str(v.name)+"\n")
file.write(str(v.shape)+"\n")
file.write(str(v.numpy())+"\n")
file.close()