数据增强之RandAugment
paper: https://arxiv.org/pdf/1909.13719.pdf
code: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
摘要
近期研究表明,数据增广可以显著提高深度学习的范化性能,尤其是在图像分类和目标检测方面均取得了不错的成果。尽管这些策略主要目的是为了提升精度,与此同时,在半监督机制下因为对原有数据集进行了扩充,从而增加了数据集的鲁棒性。
常见的图像识别任务中,增广的过程一般都是作为预处理阶段的任务之一。往往由于数据集过大而造成极大的计算损耗障碍。此外,由于所处的阶段不同,这些方法无法像模型算法一样随意调整正则化强度(尽管数据增广的效果直接取决于模型和数据集的大小)。
传统自动数据增广的策略通常是基于小数据集在轻量模型上训练后再应用于更大的模型。这就在策略上造成了一定的限制约束。本文则解决了这两大限制。RandAugment可以将数据增广所产生的增量样本空间大大缩小,从而使其可与模型训练过程捆绑在一起完成,避免将其作为独立的预处理任务来完成。此外,本文设置了增广强度的正则化参数,可以根据不同的模型和数据集大小进行调整。RandAugment方法可以作为外置工具作用于不同的图像处理任务、数据集工作中。
在CIFAR-10/100、SVHN和ImageNet数据集上能持平甚至优于先前的自动数据增广方法性能。在ImageNet数据集上,Baseline采用EfficientNet-B7结构的精度为84%,而AutoAugment+Baseline的精度为84.4%,本文的RandAugment+Baseline则达到了85.0%的准确率,分别提升了1和0.6个百分点。在目标检测方面,Baseline采用ResNet结构,添加RandAugment的效果较Baseline和其他增广方法提高了1.0~1.3个百分点。在COCO数据集上的表现也有0~0.3%mAP的效果提升。最后,由于本文超参数的可解释性,RandAugment可以用来研究数据作用与模型、数据集大小之间的关系。
RandAugment
考虑到以往数据增强方法都包含30多个参数,团队也将关注点转移到了如何大幅减少数据增强的参数空间。
为了减少参数空间的同时保持数据(图像)的多样性,研究人员用无参数过程替代了学习的策略和概率。
这些策略和概率适用于每次变换(transformation),该过程始终选择均匀概率为1/k的变换。
也就是说,给定训练图像的N个变换,RandAugment就能表示KN个潜在策略。
最后,需要考虑到的一组参数是每个增强失真(augmentation distortion)的大小。
研究人员采用线性标度来表示每个转换的强度。简单来说,就是每次变换都在0到10的整数范围内,其中,10表示给定变换的最大范围。
并假设一个单一的全局失真M(global distortion M)可能就足以对所有转换进行参数化。
这样,生成的算法便包含两个参数N和M,还可以用两行Python代码简单表示:
因为这两个参数都是可人为解释的,所以N和M的值越大,正则化强度就越大。
可以使用标准方法高效地进行超参数优化,但是考虑到极小的搜索空间,研究人员发现朴素网格搜索(naive grid search)是非常有效的。
实验结果
代码
import cv2
import numpy as np
import cv2
## aug functions
def identity_func(img):
return img
def autocontrast_func(img, cutoff=2):
'''
same output as PIL.ImageOps.autocontrast
'''
n_bins = 256
def tune_channel(ch):
n = ch.size
cut = cutoff * n // 100
if cut == 0:
high, low = ch.max(), ch.min()
else:
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
low = np.argwhere(np.cumsum(hist) > cut)
low = 0 if low.shape[0] == 0 else low[0]
high = np.argwhere(np.cumsum(hist[::-1]) > cut)
high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
if high <= low:
table = np.arange(n_bins)
else:
scale = (n_bins - 1) / (high - low)
offset = -low * scale
table = np.arange(n_bins) * scale + offset
table[table < 0] = 0
table[table > n_bins - 1] = n_bins - 1
table = table.clip(0, 255).astype(np.uint8)
return table[ch]
channels = [tune_channel(ch) for ch in cv2.split(img)]
out = cv2.merge(channels)
return out
def equalize_func(img):
'''
same output as PIL.ImageOps.equalize
PIL's implementation is different from cv2.equalize
'''
n_bins = 256
def tune_channel(ch):
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
non_zero_hist = hist[hist != 0].reshape(-1)
step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
if step == 0: return ch
n = np.empty_like(hist)
n[0] = step // 2
n[1:] = hist[:-1]
table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
return table[ch]
channels = [tune_channel(ch) for ch in cv2.split(img)]
out = cv2.merge(channels)
return out
def rotate_func(img, degree, fill=(0, 0, 0)):
'''
like PIL, rotate by degree, not radians
'''
H, W = img.shape[0], img.shape[1]
center = W / 2, H / 2
M = cv2.getRotationMatrix2D(center, degree, 1)
out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
return out
def solarize_func(img, thresh=128):
'''
same output as PIL.ImageOps.posterize
'''
table = np.array([el if el < thresh else 255 - el for el in range(256)])
table = table.clip(0, 255).astype(np.uint8)
out = table[img]
return out
def color_func(img, factor=5):
'''
same output as PIL.ImageEnhance.Color
'''
## implementation according to PIL definition, quite slow
# degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
# out = blend(degenerate, img, factor)
# M = (
# np.eye(3) * factor
# + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
# )[np.newaxis, np.newaxis, :]
M = (
np.float32([
[0.886, -0.114, -0.114],
[-0.587, 0.413, -0.587],
[-0.299, -0.299, 0.701]]) * factor
+ np.float32([[0.114], [0.587], [0.299]])
)
out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
return out
def contrast_func(img, factor=5):
"""
same output as PIL.ImageEnhance.Contrast
"""
mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
table = np.array([(
el - mean) * factor + mean
for el in range(256)
]).clip(0, 255).astype(np.uint8)
out = table[img]
return out
def brightness_func(img, factor=2):
'''
same output as PIL.ImageEnhance.Contrast
'''
table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
out = table[img]
return out
def sharpness_func(img, factor=2):
'''
The differences the this result and PIL are all on the 4 boundaries, the center
areas are same
'''
kernel = np.ones((3, 3), dtype=np.float32)
kernel[1][1] = 5
kernel /= 13
degenerate = cv2.filter2D(img, -1, kernel)
if factor == 0.0:
out = degenerate
elif factor == 1.0:
out = img
else:
out = img.astype(np.float32)
degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
out = out.astype(np.uint8)
return out
def shear_x_func(img, factor, fill=(0, 0, 0)):
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, factor, 0], [0, 1, 0]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def translate_x_func(img, offset=10, fill=(0, 0, 0)):
'''
same output as PIL.Image.transform
'''
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, -offset], [0, 1, 0]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def translate_y_func(img, offset, fill=(0, 0, 0)):
'''
same output as PIL.Image.transform
'''
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, 0], [0, 1, -offset]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def posterize_func(img, bits):
'''
same output as PIL.ImageOps.posterize
'''
out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
return out
def shear_y_func(img, factor, fill=(0, 0, 0)):
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, 0], [factor, 1, 0]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def cutout_func(img, pad_size, replace=(0, 0, 0)):
replace = np.array(replace, dtype=np.uint8)
H, W = img.shape[0], img.shape[1]
rh, rw = np.random.random(2)
pad_size = pad_size // 2
ch, cw = int(rh * H), int(rw * W)
x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
out = img.copy()
out[x1:x2, y1:y2, :] = replace
return out
### level to args
def enhance_level_to_args(MAX_LEVEL):
def level_to_args(level):
return ((level / MAX_LEVEL) * 1.8 + 0.1,)
return level_to_args
def shear_level_to_args(MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * 0.3
if np.random.random() > 0.5: level = -level
return (level, replace_value)
return level_to_args
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * float(translate_const)
if np.random.random() > 0.5: level = -level
return (level, replace_value)
return level_to_args
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
def level_to_args(level):
level = int((level / MAX_LEVEL) * cutout_const)
return (level, replace_value)
return level_to_args
def solarize_level_to_args(MAX_LEVEL):
def level_to_args(level):
level = int((level / MAX_LEVEL) * 256)
return (level, )
return level_to_args
def none_level_to_args(level):
return ()
def posterize_level_to_args(MAX_LEVEL):
def level_to_args(level):
level = int((level / MAX_LEVEL) * 4)
return (level, )
return level_to_args
def rotate_level_to_args(MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * 30
if np.random.random() < 0.5:
level = -level
return (level, replace_value)
return level_to_args
func_dict = {
'Identity': identity_func,
'AutoContrast': autocontrast_func,
'Equalize': equalize_func,
# 'Rotate': rotate_func,
'Solarize': solarize_func,
'Color': color_func,
'Contrast': contrast_func,
'Brightness': brightness_func,
'Sharpness': sharpness_func,
# 'ShearX': shear_x_func,
# 'TranslateX': translate_x_func,
# 'TranslateY': translate_y_func,
'Posterize': posterize_func,
# 'ShearY': shear_y_func,
}
translate_const = 10
MAX_LEVEL = 10
replace_value = (128, 128, 128)
arg_dict = {
'Identity': none_level_to_args,
'AutoContrast': none_level_to_args,
'Equalize': none_level_to_args,
# 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
'Solarize': solarize_level_to_args(MAX_LEVEL),
'Color': enhance_level_to_args(MAX_LEVEL),
'Contrast': enhance_level_to_args(MAX_LEVEL),
'Brightness': enhance_level_to_args(MAX_LEVEL),
'Sharpness': enhance_level_to_args(MAX_LEVEL),
# 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
# 'TranslateX': translate_level_to_args(
# translate_const, MAX_LEVEL, replace_value
# ),
# 'TranslateY': translate_level_to_args(
# translate_const, MAX_LEVEL, replace_value
# ),
'Posterize': posterize_level_to_args(MAX_LEVEL),
# 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
}
class RandomAugment(object):
def __init__(self, N=2, M=10):
self.N = N
self.M = M
def get_random_ops(self):
sampled_ops = np.random.choice(list(func_dict.keys()), self.N)
return [(op, 1., self.M) for op in sampled_ops]
def __call__(self, img):
ops = self.get_random_ops()
for name, prob, level in ops:
if np.random.random() > prob:
continue
args = arg_dict[name](level)
img = func_dict[name](img, *args)
# img = cutout_func(img, 16, replace_value)
return img
if __name__ == '__main__':
import matplotlib.pyplot as plt
a = RandomAugment()
# img = np.random.randn(32, 32, 3)
# a(img)
imgPath = 'Subset1_img__3.png'
img = cv2.imread(imgPath)
img2 = a(img)
plt.subplot(121)
plt.imshow(img)
plt.subplot(122)
plt.imshow(img2)
plt.show()
Reference
上一篇: 用python实现排序二叉树(中序)
推荐阅读
-
opencv_python实现批量图片颜色变换,可用于数据增强
-
数据增强之RandAugment
-
用于医学图像分割的数据增强方法 —— 标准 imgaug 库的使用方法
-
imgaug数据增强库的环境配置
-
【数据库之SQL复杂查询】SQL复杂查询基本语法
-
数据结构之查找(二)--斐波那契查找
-
学习Python之路之Python中常见的数据结构一——列表
-
学习Python之路之Python中常见的数据结构三——字典
-
Hadoop学习总结之四:Map-Reduce的过程解析 博客分类: Hadoop学习总结 HadoopJVM数据结构Blog
-
Hadoop学习总结之四:Map-Reduce的过程解析 博客分类: Hadoop学习总结 HadoopJVM数据结构Blog