TensorFlow 2.0 Tutorial 15: Adding a BN Layer to a CNN
程序员文章站
2023-12-21 14:11:10
...
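What a BN (BatchNormalization) layer does:
1. Speeds up convergence;
2. Has a regularizing effect that helps control overfitting, so Dropout can often be reduced or left out;
3. Makes the network less sensitive to the choice of initial weights;
4. Allows larger learning rates;
5. Helps mitigate vanishing and exploding gradients.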
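For reference, these effects come from the standard BN transform. For one channel and a mini-batch $\mathcal{B} = \{x_1, \dots, x_m\}$, with $\epsilon$ a small constant for numerical stability and $\gamma$, $\beta$ the learnable scale and shift, training-mode BN computes the following. The second line is the running-mean update used at inference time, written with the Keras `momentum` parameter $\alpha$ (default 0.99); an analogous running estimate is kept for the variance.

```latex
\begin{aligned}
\mu_\mathcal{B} &= \frac{1}{m}\sum_{i=1}^{m} x_i, &
\sigma_\mathcal{B}^2 &= \frac{1}{m}\sum_{i=1}^{m} \left(x_i - \mu_\mathcal{B}\right)^2, \\
\hat{x}_i &= \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}}, &
y_i &= \gamma\,\hat{x}_i + \beta, \\
\mu_{\text{run}} &\leftarrow \alpha\,\mu_{\text{run}} + (1 - \alpha)\,\mu_\mathcal{B} .
\end{aligned}
```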
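How γ and β are stored:
- 1. γ (scale) and β (shift) are learnable parameters, updated by backpropagation like any other weight. They are shared by all mini-batches rather than created anew for each one, so the layer does not keep (number of images / batch_size) sets of them. What does change from batch to batch is the batch mean and variance; the layer keeps a moving average of these (moving_mean and moving_variance) and uses it instead of the batch statistics at inference time.
- 2. When a convolution uses several kernels, each kernel produces one feature map (channel), and BN normalizes each channel separately. There is therefore one γ and one β per channel: 32 kernels give 32 γ values and 32 β values.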
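A minimal NumPy sketch of points 1 and 2, assuming NHWC tensors (the Keras channels-last default) and Keras' default initial values and hyperparameters (γ = 1, β = 0, momentum = 0.99, epsilon = 1e-3). The function names `batch_norm_train` and `batch_norm_infer` are hypothetical; this is an illustration, not TensorFlow's implementation, and it simplifies the moving-variance update by using the batch variance directly.

```python
import numpy as np

def batch_norm_train(x, gamma, beta, eps=1e-3):
    """Training-mode BN on an NHWC tensor: statistics per channel over axes (0, 1, 2)."""
    mean = x.mean(axis=(0, 1, 2))             # shape (C,)
    var = x.var(axis=(0, 1, 2))               # biased batch variance, shape (C,)
    x_hat = (x - mean) / np.sqrt(var + eps)   # broadcasts along the channel axis
    return gamma * x_hat + beta, mean, var

def batch_norm_infer(x, gamma, beta, moving_mean, moving_var, eps=1e-3):
    """Inference-mode BN: running statistics replace the batch statistics."""
    return gamma * (x - moving_mean) / np.sqrt(moving_var + eps) + beta

channels = 32                                   # e.g. the 32 kernels of conv1
gamma = np.ones(channels, dtype=np.float32)     # Keras' default initial scale
beta = np.zeros(channels, dtype=np.float32)     # Keras' default initial shift
moving_mean = np.zeros(channels, dtype=np.float32)
moving_var = np.ones(channels, dtype=np.float32)
momentum = 0.99                                 # Keras' default momentum

rng = np.random.default_rng(0)
for step in range(3):  # three synthetic mini-batches standing in for conv1 output
    batch = rng.normal(5.0, 2.0, size=(8, 10, 10, channels)).astype(np.float32)
    y, mean, var = batch_norm_train(batch, gamma, beta)
    # The same gamma/beta serve every batch (a real optimizer would also update
    # them from gradients; omitted here). Only the running statistics change.
    moving_mean = momentum * moving_mean + (1 - momentum) * mean
    moving_var = momentum * moving_var + (1 - momentum) * var

print(gamma.shape, beta.shape)   # (32,) (32,)
```

After three batches the running mean has only moved about 3% of the way toward the true mean of 5, which is why BN layers need many training steps before their inference-time statistics are reliable.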
Exercise: flower recognition with a CNN
Packages to install:
```
pip install opencv-python
```
```python
import cv2
import os
import tensorflow as tf
import numpy as np
```
1. Load and preprocess the data
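Dataset path (change it to match your own directory):
```python
path = r'flower_photos/'   # keep the trailing slash: read_img builds paths as path + folder name
```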
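The images in the dataset come in different sizes, so they are all resized to a common 100x100x3:
```python
w = 100   # width
h = 100   # height
c = 3     # color channels
```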
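Read the images and attach labels: `data` holds the images and `label` the class indices. Each label is the index of the image's class folder in the directory listing; the author's run produced roses 0, daisy 1, sunflowers 2, tulips 3, dandelion 4. Because os.listdir returns entries in arbitrary order, this mapping can differ between machines unless the folder names are sorted.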
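```python
def read_img(path):
    imgs = []
    labels = []
    # class sub-folders only (plain files are skipped); sorted for a reproducible label order
    cate = sorted(path + x for x in os.listdir(path) if os.path.isdir(path + x))
    for idx, folder in enumerate(cate):
        for j in os.listdir(folder):
            im = cv2.imread(os.path.join(folder, j))
            if im is None:                        # not a readable image: skip it
                continue
            img = cv2.resize(im, (w, h)) / 255.   # resize to w x h and scale to [0, 1]
            imgs.append(img)
            labels.append(idx)                    # label = index of the class folder
    return np.asarray(imgs, np.float32), np.asarray(labels, np.int32)

data, label = read_img(path)
```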
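The label rule can be checked without any images. The sketch below builds a hypothetical folder layout that mimics flower_photos/ (five class folders plus a plain file such as a license file) in a temporary directory and applies the same "label = folder index" rule with sorting; `class_index` is a hypothetical helper, not part of the tutorial.

```python
import os
import tempfile

def class_index(path):
    """Map each class sub-folder of `path` to an integer label (sorted folder order)."""
    names = sorted(x for x in os.listdir(path) if os.path.isdir(os.path.join(path, x)))
    return {name: idx for idx, name in enumerate(names)}

with tempfile.TemporaryDirectory() as root:
    for name in ["tulips", "roses", "daisy", "sunflowers", "dandelion"]:
        os.mkdir(os.path.join(root, name))
    open(os.path.join(root, "LICENSE.txt"), "w").close()  # files are skipped by isdir
    mapping = class_index(root)

print(mapping)
# {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
```

With sorting, the mapping is alphabetical and identical on every machine, whatever order the file system happens to list the folders in.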
Shuffle the dataset:
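```python
num_example = data.shape[0]   # data.shape is (3670, 100, 100, 3)
arr = np.arange(num_example)  # index array 0, 1, ..., 3669
np.random.shuffle(arr)        # shuffle the indices in place
data = data[arr]              # apply the same permutation to the images ...
label = label[arr]            # ... and to the labels, so each pair stays aligned
print(label)
```
```
[1 4 4 ... 4 2 3]
```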
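Indexing both arrays with the same permutation `arr` is what keeps every image paired with its label. A toy check with hypothetical stand-in arrays (a seeded `np.random.default_rng` generator replaces the global `np.random.shuffle` only so the run is reproducible):

```python
import numpy as np

# Toy stand-ins for `data` and `label`: sample i has pixel value i and label i % 5
data = np.arange(10, dtype=np.float32).reshape(10, 1, 1, 1)
label = np.arange(10, dtype=np.int32) % 5

arr = np.arange(data.shape[0])
np.random.default_rng(42).shuffle(arr)   # one permutation ...
data_s, label_s = data[arr], label[arr]  # ... applied to both arrays

# every (image, label) pair survives the shuffle intact
pairs_ok = all(int(d) % 5 == l for d, l in zip(data_s.ravel(), label_s))
print(pairs_ok)   # True
```

Shuffling `data` and `label` with two separate `np.random.shuffle` calls would use two different permutations and silently break this pairing.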
One-hot encode the labels:
```python
def to_one_hot(labels):
    l = len(labels)
    res = np.zeros((l, 5), dtype=np.float32)  # one row per sample, one column per class
    for i in range(l):
        res[i][labels[i]] = 1.                # set the column of the true class
    return res

label_oh = to_one_hot(label)
```
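The loop can be replaced by a single indexing operation: row k of an identity matrix is the one-hot vector for class k. A sketch comparing both versions on a few hypothetical labels (Keras also offers `tf.keras.utils.to_categorical(label, 5)` for the same purpose):

```python
import numpy as np

def to_one_hot(labels, num_classes=5):
    """Loop version from the tutorial, with the class count made a parameter."""
    res = np.zeros((len(labels), num_classes), dtype=np.float32)
    for i in range(len(labels)):
        res[i][labels[i]] = 1.
    return res

def to_one_hot_vectorized(labels, num_classes=5):
    """Pick row labels[i] of the identity matrix for every sample at once."""
    return np.eye(num_classes, dtype=np.float32)[labels]

labels = np.array([1, 4, 4, 0, 2, 3], dtype=np.int32)
print(to_one_hot_vectorized(labels)[:2])
# [[0. 1. 0. 0. 0.]
#  [0. 0. 0. 0. 1.]]
```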
Split the data into a training set (80%) and a test set (20%):
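```python
ratio = 0.8
s = int(num_example * ratio)   # np.int was removed in NumPy 1.24; the built-in int does the same
x_train = data[:s]
y_train = label_oh[:s]
x_test = data[s:]
y_test = label_oh[s:]
x_train.shape
```
```
(2936, 100, 100, 3)
```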
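The split sizes follow directly from the total count. The total of 3670 images used below is inferred from this tutorial's own outputs (2936 training and 734 test samples), not read from the dataset:

```python
num_example = 3670               # 2936 training + 734 test images
ratio = 0.8
s = int(num_example * ratio)     # index where the test set starts
print(s, num_example - s)        # 2936 734
```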
2. Build the network
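Add the BN layer with tf.keras.layers.BatchNormalization(). In the model below it sits between the convolution and the ReLU activation, which is why conv1 has no activation of its own.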
```python
class CNN(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(
            filters=32,          # number of convolution kernels (output channels)
            kernel_size=[3, 3],  # receptive field size
            padding='same',      # padding strategy ('valid' or 'same')
        )                        # no activation here: ReLU is applied after BN
        # BN layer: one gamma/beta pair per output channel of conv1
        self.bn = tf.keras.layers.BatchNormalization()
        self.pool1 = tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2)
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(units=128, activation="relu")
        self.dense2 = tf.keras.layers.Dense(units=5)

    def call(self, inputs, training=None):
        x = self.conv1(inputs)
        # batch statistics during training, moving averages in evaluate/predict
        x = self.bn(x, training=training)
        x = tf.nn.relu(x)
        x = self.pool1(x)
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dense2(x)
        output = tf.nn.softmax(x)
        return output
```
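Counting parameters by hand shows what the BN layer adds: four values per channel, of which only γ and β are trained. The arithmetic below assumes the 100x100x3 input used in this tutorial; `model.summary()` on a built model (for example after `model.build((None, 100, 100, 3))`) should report the same totals if that assumption holds.

```python
h, w, c = 100, 100, 3   # input size used in this tutorial
filters = 32

conv1 = 3 * 3 * c * filters + filters    # kernel weights + biases
bn_trainable = 2 * filters               # gamma and beta, one per channel
bn_non_trainable = 2 * filters           # moving_mean and moving_variance
flat = (h // 2) * (w // 2) * filters     # 'same' keeps 100x100; the 2x2 pooling halves it
dense1 = flat * 128 + 128
dense2 = 128 * 5 + 5

total = conv1 + bn_trainable + bn_non_trainable + dense1 + dense2
trainable = total - bn_non_trainable
print(conv1, flat, dense1, dense2)   # 896 80000 10240128 645
print(total, trainable)              # 10241797 10241733
```

BN itself is cheap (128 values); more than 99% of the parameters sit in dense1, because the flattened 50x50x32 feature map is connected densely to 128 units.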
```python
model = CNN()
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=['accuracy'])
history = model.fit(x_train, y_train, batch_size=64, epochs=10, validation_split=0.2)
```
```
Train on 2348 samples, validate on 588 samples
Epoch 1/10
2348/2348 [==============================] - 16s 7ms/sample - loss: 14.9535 - accuracy: 0.3271 - val_loss: 1.6025 - val_accuracy: 0.2041
Epoch 2/10
2348/2348 [==============================] - 17s 7ms/sample - loss: 1.4817 - accuracy: 0.3296 - val_loss: 1.5745 - val_accuracy: 0.2483
Epoch 3/10
2348/2348 [==============================] - 17s 7ms/sample - loss: 1.3477 - accuracy: 0.4404 - val_loss: 1.5366 - val_accuracy: 0.2993
Epoch 4/10
2348/2348 [==============================] - 17s 7ms/sample - loss: 1.2345 - accuracy: 0.4791 - val_loss: 1.4772 - val_accuracy: 0.3452
Epoch 5/10
2348/2348 [==============================] - 17s 7ms/sample - loss: 1.1583 - accuracy: 0.4911 - val_loss: 1.4653 - val_accuracy: 0.3520
Epoch 6/10
2348/2348 [==============================] - 16s 7ms/sample - loss: 1.0837 - accuracy: 0.5213 - val_loss: 1.4188 - val_accuracy: 0.4014
Epoch 7/10
2348/2348 [==============================] - 16s 7ms/sample - loss: 1.0094 - accuracy: 0.5503 - val_loss: 1.3899 - val_accuracy: 0.4082
Epoch 8/10
2348/2348 [==============================] - 17s 7ms/sample - loss: 0.9346 - accuracy: 0.5839 - val_loss: 1.4172 - val_accuracy: 0.4354
Epoch 9/10
2348/2348 [==============================] - 17s 7ms/sample - loss: 0.8827 - accuracy: 0.6060 - val_loss: 1.2974 - val_accuracy: 0.4830
Epoch 10/10
2348/2348 [==============================] - 17s 7ms/sample - loss: 0.8194 - accuracy: 0.6495 - val_loss: 1.2015 - val_accuracy: 0.5068
```
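The "Train on 2348 samples, validate on 588 samples" line follows from `validation_split=0.2`. The Keras documentation states that the validation data is taken from the last samples of `x` and `y`, before any shuffling, which is why the data was shuffled earlier. The sketch reproduces the counts, assuming the split point is the floor of n·(1 − validation_split):

```python
import math

n_train_total = 2936       # x_train.shape[0]
validation_split = 0.2

# Keras holds out the *last* fraction of the arrays as validation data
split_at = int(math.floor(n_train_total * (1.0 - validation_split)))
print(split_at, n_train_total - split_at)   # 2348 588
```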
```python
model.evaluate(x_test, y_test, verbose=2)
```
```
734/1 - 1s - loss: 1.1530 - accuracy: 0.5245
[1.1842144834897823, 0.52452314]
```
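`model.evaluate` returns `[loss, accuracy]` as configured in `compile`. With a categorical cross-entropy loss and one-hot labels, Keras resolves the string `'accuracy'` to categorical accuracy: a prediction counts as correct when its argmax matches the argmax of the label. A sketch of both numbers on hypothetical softmax outputs for four images (Keras additionally clips probabilities away from 0 before taking the log, which this sketch omits):

```python
import numpy as np

# Hypothetical softmax outputs for 4 test images and their one-hot labels
probs = np.array([[0.70, 0.10, 0.10, 0.05, 0.05],
                  [0.20, 0.50, 0.10, 0.10, 0.10],
                  [0.10, 0.10, 0.20, 0.30, 0.30],
                  [0.05, 0.05, 0.10, 0.10, 0.70]], dtype=np.float32)
y_true = np.eye(5, dtype=np.float32)[[0, 1, 2, 4]]

pred = probs.argmax(axis=1)   # predicted class per image (ties go to the lowest index)
true = y_true.argmax(axis=1)
accuracy = float((pred == true).mean())

# categorical cross-entropy: mean of -log(probability assigned to the true class)
loss = float(-np.log((probs * y_true).sum(axis=1)).mean())
print(accuracy)   # 0.75
```

The third image is misclassified (class 3 is predicted, the label is class 2), so three out of four are correct.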