针对cifar-10训练数据集的数据增强操作(附代码)
程序员文章站
2022-05-18 17:11:32
...
在深度学习中,扩充数据集也叫数据增强(Data Augmentation)
数据增强主要有两种方法
离线扩充:从根本上对数据集进行扩充,与训练模型代码分开操作
在线增强:在小批量数据集进入训练模型之前,进行图像变换操作,与训练模型代码同时操作
离线扩充的完整代码
import tensorflow as tf
from keras import datasets
from keras.preprocessing import image
import matplotlib.pyplot as plt
import numpy as np
# 在线数据集的加载
(x, y), (x_test, y_test) = datasets.cifar10.load_data()#(x,y)保存训练集数据,(x_test,y_test)保存测试集
print("训练集数据的图像和标签:",x.shape,y.shape)#查看Keras官方cifar-10的训练集的图像和标签:(50000,32,32,3) (10000,1)
print('x的4个维度:',x.shape[0],x.shape[1],x.shape[2],x.shape[3])#输出x的4个维度:50000,32,32,3
print("第一张图frog的储存格式: ",x[0].shape)#第一张图像:(32,32,3)
print("第二张图truck的储存格式: ",x[1].shape)#第二张图像:(32,32,3)
label_names=['airplane','automobile','brid','cat','deer','dog','frog','horse','ship','truck']
y = tf.squeeze(y,axis=1) #把1的那一维度去掉,很关键
#(x,y)是原始数据集
# (x_new,y_new)是新数据集
(x_new,y_new) = (x,y)
#选取开头的10张图像进行随机变换
for i in range(0,10):
print("第 %r 张图:%r"%(i,label_names[y[i]]))
x_image = image.array_to_img(x[i])#为了对称,这个没啥用,可以直接 plt.imshow(x[i])
#显示原始图像
plt.imshow(x_image)
plt.figure(i)
plt.show()
#对图像进行随机变换(调用tensorflow.image内置的图像变换函数)
xx_image=tf.image.random_flip_left_right(x[i]) #左右翻转
xx_image=tf.image.random_brightness(xx_image, max_delta=0.7)#调亮度
xx_image=tf.image.random_contrast(xx_image, lower=0.2, upper=1.8)#调对比度
#显示随机变换后得到图像
plt.imshow(xx_image)
plt.figure(i+1)#防止图像覆盖
plt.show()
#将变换的图像插入原始训练数据集
x_new=np.insert(x_new,0,xx_image,axis=0)#向原始训练集的 开头 插入随机变换后的图像(0表示在开头插入,1表示在第二个位置插入)
print(x_new.shape)
y_new=np.insert(y_new,0,y[i],axis=0)#向原始训练集的标签插入随机变换后的图像的标签(标签不变,但是插入位置相同)
print(y_new.shape)
print('---------------------------------------------------------')
print('------------------------数据集扩充完成--------------------')
np.savez("data_augment_new",images=x_new,labels=y_new)#保存多个数组到文件中,格式为npz(np.save()只能保存二维以下的数组)
new_data=np.load('data_augment_new.npz')#加载文件
print('新数据集数据存储格式: ',new_data.files)#'images' 'labels'
print("新数据集的图像的储存格式: ",new_data['images'].shape)#新数据集的图像
print("新数据及的标签的储存格式: ",new_data['labels'].shape)#新数据及的标签
xx = new_data['images']
yy = new_data['labels']
#把新的数据集的开头前10张打印出来
for k in range(0,10):
plt.imshow(xx[k])
plt.figure()
plt.show()
代码结果演示
原始数据集图像变换
新数据集新增图像
代码思路
导入原始数据集
在原始数据集上进行图像变换(代码里面选择数据集开头的 10 张图像操作)
把变换后的图像插入原始数据集(代码里面把新图像插入原始数据集的开头)
根据需要循环上两步的操作
保存更新后的数据集(np.savez)
读取保存的新数据集(np.load)
查看新数据集中插入的图像
大概就是这样,代码注释很详细了,respect!!!
在线增强简单举例
for epoch in range(500):#迭代500次
total_train_correct = 0
for step, (x, y) in enumerate(train_db):#enumerate()函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列(数据下标和数据)
#在线增强举例
x = tf.image.random_flip_left_right(x)
x = tf.image.random_hue(x, max_delta=0.2)
x = tf.image.random_brightness(x, max_delta=0.7)
x = tf.image.random_contrast(x, lower=0.2, upper=1.8)
with tf.GradientTape() as tape:
logits = model(x)#输入模型
y_onehot = tf.one_hot(y, depth=10)#生成10个向量
loss = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)#使用交叉熵损失函数(from_logits=False,output为经过softmax输出的概率值。from_logits=True,output为经过网络直接输出的 logits张量。
loss = tf.reduce_mean(loss)#求损失的平均值
这里训练模型代码的节选部分:在每个小批量数据中进行图像随机变换
提示
以上只是举例想要说明数据增强的两种方法的两种简单的实现
这里的离线扩充选择的是随机图像变换,所以会出现图像重复的情况,可能对训练模型没有太大帮助
我个人测试后发现还是在线增强的方法有用一些,可以提升模型的测试精度
离线扩充对于小容量的数据集的数据增强比较有优势
在线增强对于一些大容量的数据集的数据增强比较有优势
在实际训练操作的时候要根据自己要求进行参数设置和更多图像变换的选择
有什么地方说的不对的欢迎批评指正,respect!