Data Augmentation: FMix

Posted on 程序员文章站, 2024-03-19 23:14:58

FMix

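  1. FMix cuts an arbitrarily shaped region out of a randomly chosen image and pastes it onto the image being augmented;
  2. Unlike ordinary cut-and-paste, it needs a mask that defines which parts of each image are used;
  3. The mask is obtained by thresholding a low-frequency image sampled from Fourier space.
    Original paper: https://arxiv.org/abs/2002.12047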
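
To make the three points above concrete, here is a minimal NumPy sketch of the mask generation. It is a simplified illustration, not the official implementation: the helper names `low_freq_image` and `binarise`, the 64x64 size, and the fixed seed are all assumptions made for this example.

```python
import numpy as np


def low_freq_image(decay_power, shape, rng):
    """Sample a smooth grayscale image in [0, 1] from filtered Fourier-space noise."""
    h, w = shape
    # Frequency magnitude of every bin in the half-spectrum expected by irfft2.
    fy = np.fft.fftfreq(h)[:, None]
    fx = np.fft.rfftfreq(w)[None, :]
    freq = np.sqrt(fx ** 2 + fy ** 2)
    # Clamp the zero frequency so the 1 / f**decay_power scaling stays finite.
    freq = np.maximum(freq, 1.0 / max(h, w))
    # Complex Gaussian noise, attenuated more strongly at high frequencies.
    noise = rng.standard_normal(freq.shape) + 1j * rng.standard_normal(freq.shape)
    spectrum = noise / freq ** decay_power
    image = np.fft.irfft2(spectrum, s=(h, w))
    # Rescale to the range [0, 1].
    image = image - image.min()
    return image / image.max()


def binarise(image, lam):
    """Set the brightest `lam` fraction of pixels to 1 and the rest to 0."""
    flat = image.ravel()
    num_ones = int(round(lam * flat.size))
    order = np.argsort(-flat)  # indices of the brightest pixels first
    mask = np.zeros(flat.size, dtype=np.float32)
    mask[order[:num_ones]] = 1.0
    return mask.reshape(image.shape)


rng = np.random.default_rng(0)  # fixed seed, only for reproducibility
soft = low_freq_image(decay_power=3, shape=(64, 64), rng=rng)
mask = binarise(soft, lam=0.5)
print(soft.shape, mask.mean())  # the mask mean equals lam by construction
```

Thresholding by rank rather than by a fixed grey level means the share of pixels taken from the first image is exactly `lam`, whatever the sampled image looks like.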

Code and results

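    import os

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import torch
    from sklearn.model_selection import KFold
    from torchvision.utils import make_grid, save_image

    from utils.fmix import make_low_freq_image, binarise_mask

    # DATA_PATH, DataID and ImageData are not defined in this snippet;
    # they are assumed to come from the surrounding project.
    df = pd.read_csv(os.path.join(DATA_PATH, DataID, 'train.csv'))
    # random_state only has an effect when shuffle=True; recent scikit-learn
    # versions raise a ValueError if it is set together with shuffle=False.
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    trainset, valset = next(iter(kf.split(df)))

    DECAY_POWER = 3   # larger values give smoother, lower-frequency masks
    SHAPE = 260       # mask height and width; assumed to match the image size from ImageData
    LAMBDA = 0.5      # fraction of mask pixels set to 1
    NUM_IMAGES = 4

    # Load one batch of 2 * NUM_IMAGES images and split it into two halves to mix.
    dataset = ImageData(df, trainset, mode='valid')
    dataGen = torch.utils.data.DataLoader(dataset, batch_size=NUM_IMAGES*2, shuffle=True, num_workers=0)
    dataIter = iter(dataGen)
    batch, target = next(dataIter)
    batch1 = batch[:NUM_IMAGES]
    batch2 = batch[NUM_IMAGES:]

    # Low-frequency greyscale images sampled from Fourier space, repeated over 3 channels.
    soft_masks_np = [make_low_freq_image(DECAY_POWER, [SHAPE, SHAPE]) for _ in range(NUM_IMAGES)]
    soft_masks = torch.from_numpy(np.stack(soft_masks_np, axis=0)).float().repeat(1, 3, 1, 1)

    # Threshold each soft mask into a binary mask.
    masks_np = [binarise_mask(mask, LAMBDA, [SHAPE, SHAPE]) for mask in soft_masks_np]
    masks = torch.from_numpy(np.stack(masks_np, axis=0)).float().repeat(1, 3, 1, 1)

    # Take batch1 where the mask is 1 and batch2 where it is 0.
    mix = batch1 * masks + batch2 * (1 - masks)

    # One grid row each: soft masks, binary masks, batch1, batch2, mixed images.
    image = torch.cat((soft_masks, masks, batch1, batch2, mix), 0)
    save_image(image, 'fmix_example.png', nrow=NUM_IMAGES, pad_value=1)

    plt.figure(figsize=(NUM_IMAGES, 5))
    # pad_value=1 keeps the padding inside the [0, 1] range that imshow expects for floats.
    plt.imshow(make_grid(image, nrow=NUM_IMAGES, pad_value=1).permute(1, 2, 0).numpy())
    plt.show()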
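
The listing above depends on a project-specific dataset class, so it cannot be run in isolation. The following NumPy sketch isolates the mixing step on random arrays. Everything in it is a stand-in chosen for illustration: the array sizes, the half-image placeholder mask (a real FMix mask would come from `binarise_mask`), and the example labels. The target weighting at the end is an assumption about how the mixed images would typically be used in training; it is not part of the listing above.

```python
import numpy as np

rng = np.random.default_rng(0)
num_images, channels, size, num_classes = 4, 3, 32, 4

# Stand-ins for two image batches with values in [0, 1].
batch1 = rng.random((num_images, channels, size, size), dtype=np.float32)
batch2 = rng.random((num_images, channels, size, size), dtype=np.float32)

# Placeholder binary masks: left half 1, right half 0.
# The single channel axis broadcasts over the 3 colour channels,
# so no explicit repeat over channels is needed.
masks = np.zeros((num_images, 1, size, size), dtype=np.float32)
masks[..., : size // 2] = 1.0

# Same expression as in the listing: batch1 where the mask is 1, batch2 elsewhere.
mix = batch1 * masks + batch2 * (1.0 - masks)

# Fraction of each mixed image that comes from batch1.
lam = masks.mean(axis=(1, 2, 3))

# Hypothetical labels for the two batches, mixed in the same proportion.
labels1 = np.array([0, 1, 2, 3])
labels2 = np.array([3, 2, 1, 0])
one_hot = np.eye(num_classes, dtype=np.float32)
soft_targets = lam[:, None] * one_hot[labels1] + (1.0 - lam)[:, None] * one_hot[labels2]

print(mix.shape, lam, soft_targets.sum(axis=1))
```

Because the mask is binary, every pixel of `mix` is copied unchanged from exactly one of the two source images; only the targets are blended.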


Reference

https://colab.research.google.com/github/ecs-vlc/fmix/blob/master/notebooks/example_masks.ipynb