数据增强之CutMix
程序员文章站
2024-03-19 23:02:04
...
关于CutMix
- CutMix是将随机图像的一个矩形部分剪切下来,然后将其粘贴到相关图像的相同位置;
- lambda决定了矩形的大小,其由参数为alpha的对称分布产生;
- 一个随机的(x, y)坐标是由均匀分布产生的,高度和宽度都有较大的限制。这个坐标就是要切割的矩形部分的中心;
- 然后,通过在中心"x"坐标上减去和加上长度/2,并在中心“y”减去和加上宽度/2,得到边界坐标。因此有四个坐标,即(bbx1,cy),(bbx2,cy), (bby1, cx), (bby2, cx),如此,便产生了一个要切割的矩形部分。
python代码
def rand_bbox(size, lam):
W = size[2]
H = size[3]
cut_rat = np.sqrt(1. - lam)
cut_w = np.int(W * cut_rat)
cut_h = np.int(H * cut_rat)
# uniform
cx = np.random.randint(W)
cy = np.random.randint(H)
bbx1 = np.clip(cx - cut_w // 2, 0, W)
bby1 = np.clip(cy - cut_h // 2, 0, H)
bbx2 = np.clip(cx + cut_w // 2, 0, W)
bby2 = np.clip(cy + cut_h // 2, 0, H)
return bbx1, bby1, bbx2, bby2
def cutmix(data, target, alpha):
indices = torch.randperm(data.size(0))
shuffled_data = data[indices]
shuffled_target = target[indices]
lam = np.clip(np.random.beta(alpha, alpha),0.3,0.4)
bbx1, bby1, bbx2, bby2 = rand_bbox(data.size(), lam)
new_data = data.clone()
new_data[:, :, bby1:bby2, bbx1:bbx2] = data[indices, :, bby1:bby2, bbx1:bbx2]
# adjust lambda to exactly match pixel ratio
lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (data.size()[-1] * data.size()[-2]))
targets = (target, shuffled_target, lam)
return new_data, targets