
Implementing a Convolutional Neural Network on the CIFAR-10 Dataset with TensorFlow 2


1. Import the required libraries

import tensorflow as tf
import matplotlib.pyplot as plt

for i in [tf]:
    print(i.__name__,": ",i.__version__,sep="")

Output:

tensorflow: 2.2.0

2. Download and load the CIFAR-10 dataset

# Load CIFAR-10 and scale the pixel values from [0, 255] to [0, 1]
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()
train_images, test_images = train_images / 255.0, test_images / 255.0

for i in [train_images, train_labels, test_images, test_labels]:
    print(i.shape)

Output:

(50000, 32, 32, 3)
(50000, 1)
(10000, 32, 32, 3)
(10000, 1)
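
These NumPy arrays can be passed to model.fit directly, as is done below. If an explicit input pipeline is preferred, a minimal tf.data sketch would look roughly like this (the batch size of 32 and the shuffle buffer size are assumptions, chosen to match Keras's defaults):

train_ds = (tf.data.Dataset.from_tensor_slices((train_images, train_labels))
            .shuffle(10000)                              # shuffle the 50,000 training examples
            .batch(32)                                   # 32 matches model.fit's default batch size
            .prefetch(tf.data.experimental.AUTOTUNE))
test_ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(32)

# model.fit(train_ds, epochs=20, validation_data=test_ds) would then replace
# the array-based call used in step 5.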

3. Display the data

class_names = ["airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck"]

plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i],cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[i][0]])
plt.show()

Output:

[Figure: 5x5 grid of the first 25 training images labelled with their class names]
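
Each label is an integer from 0 to 9 that indexes into class_names, which is why train_labels[i][0] is used above. A quick sanity check of the label distribution (numpy is assumed to be available; CIFAR-10's training split is balanced, 5,000 images per class):

import numpy as np

print(np.unique(train_labels))               # [0 1 2 3 4 5 6 7 8 9]
print(np.bincount(train_labels.flatten()))   # 5000 for each of the 10 classes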

4. Build the network model

Stacking a number of convolutional and pooling layers and then adding fully connected layers at the end gives the classic convolutional neural network structure.

Note: the convolutional network expects input in the format [batch_size, height, width, channels]; the CIFAR-10 images loaded above are already (32, 32, 3), so they can be used without reshaping.

model = tf.keras.models.Sequential()

# Convolutional base: three 3x3 Conv2D layers with ReLU, downsampled by 2x2 max pooling
model.add(tf.keras.layers.Conv2D(32, (3, 3), activation="relu", input_shape=(32, 32, 3)))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation="relu"))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation="relu"))

# Classification head: flatten the feature maps, then two Dense layers;
# the final layer outputs 10 raw logits (no softmax here)
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(64, activation="relu"))
model.add(tf.keras.layers.Dense(10))

model.summary()

Output:

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 30, 30, 32)        896       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 15, 15, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 13, 13, 64)        18496     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 6, 6, 64)          0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 4, 4, 64)          36928     
_________________________________________________________________
flatten (Flatten)            (None, 1024)              0         
_________________________________________________________________
dense (Dense)                (None, 64)                65600     
_________________________________________________________________
dense_1 (Dense)              (None, 10)                650       
=================================================================
Total params: 122,570
Trainable params: 122,570
Non-trainable params: 0
_________________________________________________________________
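
The shapes and parameter counts in the summary can be checked by hand. With no padding, a 3x3 convolution shrinks each spatial dimension by 2 (output_size = input_size - 3 + 1), and its parameter count is (kernel_height * kernel_width * input_channels + 1) * filters, where the +1 is the bias. For the first layer: 32 - 3 + 1 = 30 and (3*3*3 + 1) * 32 = 896; for the second: 15 - 3 + 1 = 13 and (3*3*32 + 1) * 64 = 18,496; and the first Dense layer has 1024 * 64 + 64 = 65,600 parameters, all matching the rows above.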

5. Compile and train the model

# The final Dense layer outputs raw logits, so from_logits=True;
# labels are integer class indices (not one-hot), so the "sparse" cross-entropy is used
model.compile(optimizer="adam",
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=["accuracy"])

history = model.fit(train_images, train_labels, epochs=20,
                    validation_data=(test_images, test_labels))

Output:

Epoch 1/20
1563/1563 [==============================] - 8s 5ms/step - loss: 0.5518 - accuracy: 0.8050 - val_loss: 0.9361 - val_accuracy: 0.7016
Epoch 2/20
1563/1563 [==============================] - 8s 5ms/step - loss: 0.5153 - accuracy: 0.8189 - val_loss: 0.9369 - val_accuracy: 0.7082
Epoch 3/20
1563/1563 [==============================] - 8s 5ms/step - loss: 0.4857 - accuracy: 0.8282 - val_loss: 0.9817 - val_accuracy: 0.7086
Epoch 4/20
1563/1563 [==============================] - 8s 5ms/step - loss: 0.4478 - accuracy: 0.8412 - val_loss: 1.0160 - val_accuracy: 0.7084
Epoch 5/20
1563/1563 [==============================] - 8s 5ms/step - loss: 0.4194 - accuracy: 0.8496 - val_loss: 1.0763 - val_accuracy: 0.6965
Epoch 6/20
1563/1563 [==============================] - 8s 5ms/step - loss: 0.3917 - accuracy: 0.8594 - val_loss: 1.1418 - val_accuracy: 0.6911
Epoch 7/20
1563/1563 [==============================] - 8s 5ms/step - loss: 0.3643 - accuracy: 0.8701 - val_loss: 1.1781 - val_accuracy: 0.6978
Epoch 8/20
1563/1563 [==============================] - 8s 5ms/step - loss: 0.3469 - accuracy: 0.8761 - val_loss: 1.2179 - val_accuracy: 0.6875
Epoch 9/20
1563/1563 [==============================] - 8s 5ms/step - loss: 0.3228 - accuracy: 0.8845 - val_loss: 1.2916 - val_accuracy: 0.6938
Epoch 10/20
1563/1563 [==============================] - 8s 5ms/step - loss: 0.3012 - accuracy: 0.8906 - val_loss: 1.3111 - val_accuracy: 0.6903
Epoch 11/20
1563/1563 [==============================] - 7s 5ms/step - loss: 0.2833 - accuracy: 0.8979 - val_loss: 1.3923 - val_accuracy: 0.6916
Epoch 12/20
1563/1563 [==============================] - 8s 5ms/step - loss: 0.2600 - accuracy: 0.9057 - val_loss: 1.4629 - val_accuracy: 0.6849
Epoch 13/20
1563/1563 [==============================] - 8s 5ms/step - loss: 0.2488 - accuracy: 0.9084 - val_loss: 1.5491 - val_accuracy: 0.6860
Epoch 14/20
1563/1563 [==============================] - 8s 5ms/step - loss: 0.2367 - accuracy: 0.9132 - val_loss: 1.5958 - val_accuracy: 0.6779
Epoch 15/20
1563/1563 [==============================] - 8s 5ms/step - loss: 0.2197 - accuracy: 0.9207 - val_loss: 1.5827 - val_accuracy: 0.6848
Epoch 16/20
1563/1563 [==============================] - 8s 5ms/step - loss: 0.2120 - accuracy: 0.9240 - val_loss: 1.7106 - val_accuracy: 0.6784
Epoch 17/20
1563/1563 [==============================] - 8s 5ms/step - loss: 0.2015 - accuracy: 0.9275 - val_loss: 1.7326 - val_accuracy: 0.6799
Epoch 18/20
1563/1563 [==============================] - 8s 5ms/step - loss: 0.1852 - accuracy: 0.9324 - val_loss: 1.8658 - val_accuracy: 0.6809
Epoch 19/20
1563/1563 [==============================] - 8s 5ms/step - loss: 0.1849 - accuracy: 0.9322 - val_loss: 2.0047 - val_accuracy: 0.6689
Epoch 20/20
1563/1563 [==============================] - 8s 5ms/step - loss: 0.1734 - accuracy: 0.9380 - val_loss: 1.9625 - val_accuracy: 0.6820
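
model.fit returns a History object whose history attribute holds one list per tracked metric, with one entry per epoch; this is what the plotting code in the next step reads. A quick look (the key names follow from the metrics and validation_data arguments used above):

print(history.history.keys())
# typically: dict_keys(['loss', 'accuracy', 'val_loss', 'val_accuracy'])
print(len(history.history["val_accuracy"]))   # 20, one value per epoch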

6. Evaluate the model

plt.plot(history.history["accuracy"], label="accuracy")
plt.plot(history.history["val_accuracy"], label="val_accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.ylim([0.0, 1.0])
plt.legend()
plt.show()

test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)

Output:

313/313 - 1s - loss: 1.9625 - accuracy: 0.6820

[Figure: training accuracy and validation accuracy over the 20 epochs]

print(test_acc)

Output:

0.6865000128746033
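
Because the last Dense layer outputs raw logits, class probabilities for new images are obtained by applying a softmax at inference time. A minimal sketch (predicting the first ten test images is just an illustrative choice):

import numpy as np

# Wrap the trained model with a Softmax layer so predictions are probabilities
probability_model = tf.keras.Sequential([model, tf.keras.layers.Softmax()])

predictions = probability_model.predict(test_images[:10])
predicted_classes = np.argmax(predictions, axis=1)

for idx, cls in enumerate(predicted_classes):
    print("image", idx, "predicted:", class_names[cls],
          "true:", class_names[test_labels[idx][0]])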