欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

数据增强之CutMix

程序员文章站 2024-03-19 23:02:04
...

关于CutMix

  1. CutMix是将随机图像的一个矩形部分剪切下来,然后将其粘贴到相关图像的相同位置;
  2. lambda决定了矩形的大小,其由参数为alpha的对称分布产生;
  3. 一个随机的(x, y)坐标是由均匀分布产生的,高度和宽度都有较大的限制。这个坐标就是要切割的矩形部分的中心;
  4. 然后,通过在中心"x"坐标上减去和加上长度/2,并在中心“y”减去和加上宽度/2,得到边界坐标。因此有四个坐标,即(bbx1,cy),(bbx2,cy), (bby1, cx), (bby2, cx),如此,便产生了一个要切割的矩形部分。

python代码

def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = np.int(W * cut_rat)
    cut_h = np.int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2

def cutmix(data, target, alpha):
    indices = torch.randperm(data.size(0))
    shuffled_data = data[indices]
    shuffled_target = target[indices]

    lam = np.clip(np.random.beta(alpha, alpha),0.3,0.4)
    bbx1, bby1, bbx2, bby2 = rand_bbox(data.size(), lam)
    new_data = data.clone()
    new_data[:, :, bby1:bby2, bbx1:bbx2] = data[indices, :, bby1:bby2, bbx1:bbx2]
    # adjust lambda to exactly match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (data.size()[-1] * data.size()[-2]))
    targets = (target, shuffled_target, lam)

    return new_data, targets

数据增强之CutMix

Reference

https://arxiv.org/abs/1905.04899