欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Tensorflow学习-MNIST数据集CNN

程序员文章站 2024-03-07 21:49:39
...

Tensorflow学习-MNIST数据集CNN

CNN

①数据集导入,keras自带的下载或者从某盘提取点击获取数据集,提取码:45yf

#加载MNIST数据集
from keras.datasets import mnist
(x_train,y_train),(x_test,y_test)=mnist.load_data('E:/TensorFlow_mnist/MNIST_data/mnist.npz')

print(x_train.shape,type(x_train))  #60000张28*28的图片
print(y_train.shape,type(y_train))  #60000个标签

②图像和数据类型的转化
这里使用通道的方式进行数据类型转化,channels_first(batch,channels,height,width)

#数据处理:规范化
from keras import backend as K

img_rows,img_cols=28,28;

#channels_first(batch,channels,height,width)

if K.image_data_format()=='channels_first':
    x_train = x_train.reshape(x_train.shape[0],1,img_rows,img_cols)
    x_test = x_test.reshape(x_test.shap[0],1,img_rows,img_cols)
    input_shape = (1,img_rows,img_cols)
else:
    x_train =x_train.reshape(x_train.shape[0],img_rows,img_cols,1)
    x_test=x_test.reshape(x_test.shape[0],img_rows,img_cols,1)
    input_shape=(img_rows,img_cols,1)

print(x_train.shape,type(x_train))
print(x_test.shape,type(x_test))

#将数据类型转换为float32
X_train = x_train.astype('float32')
X_test = x_test.astype('float32')
#数据归一化
X_train /=255
X_test /=255

print(X_train.shape[0],'train samples')
print(X_test.shape[0],'test samples')

③统计训练数据集中各标签的数量并可视化展示

# 统计训练数据中的各标签数量
import numpy as np
import matplotlib.pyplot as plt

label, count = np.unique(y_train, return_counts=True)
print(label, count)
# lable的可视化输出
fig = plt.figure()
plt.bar(label, count, width=0.7, align='center')
plt.title("Label Distribution")
plt.xlabel("Label")
plt.ylabel("Count")
plt.xticks(label)
plt.ylim(0, 7500)
for a, b in zip(label, count):
    plt.text(a, b, '%d' % b, ha='center', va='bottom', fontsize=10)

plt.show()

Tensorflow学习-MNIST数据集CNN
④标签编码
one-hot编码的实现

#one-hot编码
from keras.utils import  np_utils

n_classes=10
print("Shape before one-hot encoding: ",y_train.shape)
Y_train=np_utils.to_categorical(y_train,n_classes)
print("Shape after one-hot encoding: ",Y_train.shape)
Y_test = np_utils.to_categorical(y_test,n_classes)

print(y_train[1])
print(Y_train[1])

结果显示:

0  #编码前的标签
[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] #编码后0的标签

⑤使用Kears sequential model定义MNIST CNN网络

model=Sequential()
##Feature Extraction 特征提取
# 第1层卷积,32个3x3的卷积层,**函数使用relu
model.add(Conv2D(filters=32,kernel_size=(3,3),activation='relu',input_shape=input_shape))
# 第2层卷积,64个3x3的卷积层,**函数使用relu
model.add(Conv2D(filters=64,kernel_size=(3,3),activation='relu'))
#最大池化层,池化窗口2x2
model.add(MaxPooling2D(pool_size=(2,2)))
#Dropout 25% 的输入神经元
model.add(Dropout(0.25))

#Pooled featuremap 摊平后输入全连接网络
model.add(Flatten())

##Classification
#全连接层
model.add(Dense(128,activation='relu'))

#Dropout 50% 的输入神经元
model.add(Dropout(0.5))

#使用softmax**函数做多分类,输出各数字的概率
model.add(Dense(n_classes,activation='softmax'))
#查看模型结构
model.summary()

查看MNIST CNN网络结构:
Tensorflow学习-MNIST数据集CNN
每个网络层输出的形状

for layer in model.layers:
    print(layer.get_output_at(0).get_shape().as_list())

结果:
Tensorflow学习-MNIST数据集CNN
⑥编译模型

#编译模型
model.compile(loss='categorical_crossentropy',metrics=['accuracy'],optimizer='adam')

#训练模型,并将指标保存到history中
history = model.fit(X_train,
                    Y_train,
                    batch_size=128,
                    epochs=5,
                    verbose=2,
                    validation_data=(X_test,Y_test))

训练结果:
Tensorflow学习-MNIST数据集CNN
⑥保存模型并加载

#保存模型
import os
import tensorflow.gfile as gfile

save_dir='./mnist/model/'

if gfile.Exists(save_dir):
    gfile.DeleteRecursively(save_dir)
gfile.MakeDirs(save_dir)

model_name = 'keras_mnist.h5'
model_path = os.path.join(save_dir,model_name)
model.save(model_path)
print('Saved trained model at %s' % model_path)

#加载模型
from keras.models import load_model
mnist_model = load_model(model_path)

#统计模型在测试集上的分类结果
loss_and_metrics = mnist_model.evaluate(X_test,Y_test,verbose=2)

print("Test Loss:{}".format(loss_and_metrics[0]))
print("Test Accuracy:{}%".format(loss_and_metrics[1]*100))

predicted_classes=mnist_model.predict_classes(X_test)

correct_indices=np.nonzero(predicted_classes==y_test)[0]
incorrect_indices = np.nonzero(predicted_classes!=y_test)[0]
print("Classified correctly count:{}".format(len(correct_indices)))
print("Classified incorrectly count:{}".format(len(incorrect_indices)))

测试结果:
Tensorflow学习-MNIST数据集CNN

相关标签: TensorFlow学习