欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

数据增强之Gridmask

程序员文章站 2024-03-19 23:06:46
...

paper: https://arxiv.org/abs/2001.04086
code: https://github.com/Jia-Research-Lab/GridMask

概述

作者首先回顾了数据增强(Data augmentation)方法,指出当前方法有三类:spatial transformation, color distortion, 以及 information dropping。本文提出的方法属于 information dropping,作者指出,对于此类方法,避免过度删除或保持连续区域是核心问题:一方面,过度删除区域将造成完整目标被删除或者上下文信息缺失,因此,剩下的区域不足以表达目标信息,会成为noisy data。另一方面,保留过多区域,将会使得目标不受影响(untouched),会影响网络的鲁棒性。

作者重点介绍了 Cutout 和 HaS 方法。Cutout方法只删除图像中的一块连续区域,因此,容易出现删除掉整个目标,或者一点目标也没有删除的情况;HaS方法把图像划分为若干小块的区域,然后随机删除,但仍然会出现和 Cutout 相同的问题。下图展示了 GridMask 方法与当前方法的对比。

方法

GridMask 通过生成一个和原图相同分辨率的mask,然后将该mask与原图相乘得到一个图像。下图中灰色区域的值为1,黑色区域的值为0。这样,就实现了特定区域的 information dropping,本质上可以理解为一种正则化方法。
数据增强之Gridmask
GridMask对应4个参数,为 [x,y,r,d] ,四个参数的设置如下图所示:
数据增强之Gridmask
从图中可以看出, [r] 代表了保留原图像信息的比例,有一个计算方法,具体可以阅读论文。 [d] 决定了一个dropped square的大小, 参数 [x] 和 [y] 的取值有一定随机性,细节可以阅读论文。

实验结果

在ImageNet-1K图像分类任务上,Cutout对ResNet50的提升为0.6%,HaS的提升为0.7%,AutoAugement提升为1.1%,相比而言,GridMask的提升为1.4%。作者还在CIFAR10数据集上进行了实验,这里不再详述。

在Ablation Study中,作者首先分析了参数 [公式] 。如下图所示,在ImageNet-1K数据集上,设置为0.6比较好;在CIFAR10数据集上,设置为0.4比较好。作者解释为,在复杂的数据集上应该保持更多的信息来避免under-fitting,在简单数据集上应该丢弃更多的信息来减少over-fitting,这和 common sense 是一致的。
数据增强之Gridmask

代码

import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import pdb


class Grid(object):
    def __init__(self, use_h, use_w, rotate=1, offset=True, ratio=0.005, mode=0, prob=1.):
        self.use_h = use_h
        self.use_w = use_w
        self.rotate = rotate
        self.offset = offset
        self.ratio = ratio
        self.mode = mode
        self.st_prob = prob
        self.prob = prob

    def set_prob(self, epoch, max_epoch):
        self.prob = self.st_prob * epoch / max_epoch

    def __call__(self, img):
        if np.random.rand() > self.prob:
            return img
        h = img.size(1)
        w = img.size(2)
        self.d1 = 2
        self.d2 = min(h, w)
        hh = int(1.5 * h)
        ww = int(1.5 * w)
        d = np.random.randint(self.d1, self.d2)
        # d = self.d
        #        self.l = int(d*self.ratio+0.5)
        if self.ratio == 1:
            self.l = np.random.randint(1, d)
        else:
            self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1)
        mask = np.ones((hh, ww), np.float32)
        st_h = np.random.randint(d)
        st_w = np.random.randint(d)
        if self.use_h:
            for i in range(hh // d):
                s = d * i + st_h
                t = min(s + self.l, hh)
                mask[s:t, :] *= 0
        if self.use_w:
            for i in range(ww // d):
                s = d * i + st_w
                t = min(s + self.l, ww)
                mask[:, s:t] *= 0

        r = np.random.randint(self.rotate)
        mask = Image.fromarray(np.uint8(mask))
        mask = mask.rotate(r)
        mask = np.asarray(mask)
        #        mask = 1*(np.random.randint(0,3,[hh,ww])>0)
        mask = mask[(hh - h) // 2:(hh - h) // 2 + h, (ww - w) // 2:(ww - w) // 2 + w]

        mask = torch.from_numpy(mask).float()
        if self.mode == 1:
            mask = 1 - mask

        mask = mask.expand_as(img)
        if self.offset:
            offset = torch.from_numpy(2 * (np.random.rand(h, w) - 0.5)).float()
            offset = (1 - mask) * offset
            img = img * mask + offset
        else:
            img = img * mask

        return img


class GridMask(nn.Module):
    def __init__(self, use_h, use_w, rotate=1, offset=False, ratio=0.5, mode=0, prob=1.):
        super(GridMask, self).__init__()
        self.use_h = use_h
        self.use_w = use_w
        self.rotate = rotate
        self.offset = offset
        self.ratio = ratio
        self.mode = mode
        self.st_prob = prob

    def set_prob(self, epoch, max_epoch):
        self.prob = self.st_prob * epoch / max_epoch  # + 1.#0.5

    def forward(self, x):
        if np.random.rand() > self.prob or not self.training:
            return x
        n, c, h, w = x.size()
        x = x.view(-1, h, w)
        hh = int(1.5 * h)
        ww = int(1.5 * w)
        d = np.random.randint(2, h)
        # d = self.d
        # self.l = int(d*self.ratio+0.5)
        self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1)
        mask = np.ones((hh, ww), np.float32)
        st_h = np.random.randint(d)
        st_w = np.random.randint(d)
        if self.use_h:
            for i in range(hh // d):
                s = d * i + st_h
                t = min(s + self.l, hh)
                mask[s:t, :] *= 0
        if self.use_w:
            for i in range(ww // d):
                s = d * i + st_w
                t = min(s + self.l, ww)
                mask[:, s:t] *= 0

        r = np.random.randint(self.rotate)
        mask = Image.fromarray(np.uint8(mask))
        mask = mask.rotate(r)
        mask = np.asarray(mask)
        #        mask = 1*(np.random.randint(0,3,[hh,ww])>0)
        mask = mask[(hh - h) // 2:(hh - h) // 2 + h, (ww - w) // 2:(ww - w) // 2 + w]

        mask = torch.from_numpy(mask).float().cuda()
        if self.mode == 1:
            mask = 1 - mask
        mask = mask.expand_as(x)
        if self.offset:
            offset = torch.from_numpy(2 * (np.random.rand(h, w) - 0.5)).float().cuda()
            x = x * mask + offset * (1 - mask)
        else:
            x = x * mask

        return x.view(n, c, h, w)

数据增强之Gridmask

相关标签: 数据增强