Data Augmentation: FMix
程序员文章站
2024-03-19 23:14:58
...
FMix
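- FMix cuts an arbitrarily shaped region out of one randomly chosen image and pastes it onto another image;
- Unlike ordinary cut-and-paste (e.g. a rectangular crop), it needs a mask that defines which pixels come from which image;
- The mask is obtained by sampling a low-frequency grey-scale image from Fourier space and thresholding it, so that the brightest fraction λ of the pixels become 1 and the rest become 0.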
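To make the mask construction concrete, here is a minimal NumPy sketch. It is a simplified re-implementation for illustration only: the function names `low_freq_image` and `binarise` are hypothetical, and the official repository's `make_low_freq_image` and `binarise_mask` (used in the script below) differ in details. The sketch draws complex Gaussian noise in the frequency domain, scales each coefficient by 1/f^δ (δ being the decay power, so high frequencies are suppressed), inverts the FFT to get a smooth grey-scale image, and keeps the brightest λ·N pixels as 1.

```python
import numpy as np

def low_freq_image(decay_power, shape, rng):
    """Sample a grey-scale image whose spectrum decays like 1 / f**decay_power (sketch)."""
    h, w = shape
    # Frequency magnitude of every bin of a real 2-D FFT (last axis is half-size).
    fy = np.fft.fftfreq(h)[:, None]
    fx = np.fft.rfftfreq(w)[None, :]
    freqs = np.sqrt(fx ** 2 + fy ** 2)
    # Attenuate high frequencies; clamp the DC term to avoid dividing by zero.
    scale = 1.0 / np.maximum(freqs, 1.0 / max(h, w)) ** decay_power
    # Random complex Gaussian coefficients, scaled per frequency.
    noise = rng.standard_normal(freqs.shape) + 1j * rng.standard_normal(freqs.shape)
    image = np.fft.irfft2(scale * noise, s=shape)
    # Normalise to [0, 1].
    image = image - image.min()
    return image / image.max()

def binarise(image, lam):
    """Set the brightest lam fraction of pixels to 1 and all others to 0 (sketch)."""
    flat = image.reshape(-1)
    num_ones = int(round(lam * flat.size))
    order = np.argsort(flat)[::-1]  # pixel indices from brightest to darkest
    mask = np.zeros_like(flat)
    mask[order[:num_ones]] = 1.0
    return mask.reshape(image.shape)

rng = np.random.default_rng(0)
soft = low_freq_image(3.0, (32, 32), rng)
mask = binarise(soft, 0.5)
print(mask.shape, mask.mean())  # (32, 32) 0.5
```

Because the threshold is chosen by rank rather than by a fixed value, the mask always covers (up to rounding) exactly a fraction λ of the image, whatever shape the low-frequency blobs take.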
Original paper: https://arxiv.org/abs/2002.12047
Code and example output
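```python
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import KFold
from torchvision.utils import make_grid, save_image

# make_low_freq_image / binarise_mask come from fmix.py in the official repo
# (github.com/ecs-vlc/fmix), copied here into utils/fmix.py
from utils.fmix import make_low_freq_image, binarise_mask

# DATA_PATH, DataID and ImageData (a torch Dataset returning (image, target))
# are project-specific and must be defined elsewhere.
df = pd.read_csv(os.path.join(DATA_PATH, DataID, 'train.csv'))
# random_state only has an effect when shuffle=True (recent scikit-learn raises an error otherwise)
kf = KFold(n_splits=5, shuffle=True, random_state=42)
trainset, valset = next(iter(kf.split(df)))  # row indices of the first fold

DECAY_POWER = 3   # decay of the Fourier spectrum; larger values give smoother masks
SHAPE = 260       # must match the height/width of the images returned by ImageData
LAMBDA = 0.5      # fraction of pixels taken from batch1
NUM_IMAGES = 4

dataset = ImageData(df, trainset, mode='valid')
dataGen = torch.utils.data.DataLoader(dataset, batch_size=NUM_IMAGES * 2, shuffle=True, num_workers=0)
batch, target = next(iter(dataGen))
batch1 = batch[:NUM_IMAGES]   # images kept where the mask is 1
batch2 = batch[NUM_IMAGES:]   # images kept where the mask is 0

# Grey-scale low-frequency images, each of shape (1, SHAPE, SHAPE);
# stacked to (N, 1, H, W) and repeated to 3 channels.
soft_masks_np = [make_low_freq_image(DECAY_POWER, [SHAPE, SHAPE]) for _ in range(NUM_IMAGES)]
soft_masks = torch.from_numpy(np.stack(soft_masks_np, axis=0)).float().repeat(1, 3, 1, 1)
# Threshold each soft mask so that a LAMBDA fraction of pixels is 1.
# soft_masks is built first because binarise_mask may overwrite its input array.
masks_np = [binarise_mask(mask, LAMBDA, [SHAPE, SHAPE]) for mask in soft_masks_np]
masks = torch.from_numpy(np.stack(masks_np, axis=0)).float().repeat(1, 3, 1, 1)

# FMix: pixels from batch1 where the mask is 1, from batch2 elsewhere
mix = batch1 * masks + batch2 * (1 - masks)

# Grid rows: soft masks, binary masks, batch1, batch2, mixed images
image = torch.cat((soft_masks, masks, batch1, batch2, mix), 0)
save_image(image, 'fmix_example.png', nrow=NUM_IMAGES, pad_value=1)
plt.figure(figsize=(NUM_IMAGES, 5))
plt.imshow(make_grid(image, nrow=NUM_IMAGES, pad_value=1).permute(1, 2, 0).numpy())
plt.axis('off')
plt.show()
```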
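The script above only mixes the images. When FMix is used for training, the targets are mixed as well, weighted by the share of pixels each image contributes. The sketch below assumes NumPy arrays in NCHW layout and one-hot labels; `fmix_mix` is a hypothetical helper, not part of the fmix package. It computes λ per sample from the mask itself and mixes the labels accordingly. Since cross-entropy is linear in the target, training on these soft labels is equivalent to weighting the two losses by λ and 1 - λ.

```python
import numpy as np

def fmix_mix(images_a, images_b, labels_a, labels_b, masks):
    """Mix two NCHW batches with binary masks and mix one-hot labels by mask area.

    masks has shape (N, 1, H, W): 1 where the pixel comes from images_a.
    """
    mixed = images_a * masks + images_b * (1.0 - masks)   # broadcasts over channels
    lam = masks.reshape(masks.shape[0], -1).mean(axis=1)   # per-sample share of images_a
    mixed_labels = lam[:, None] * labels_a + (1.0 - lam[:, None]) * labels_b
    return mixed, lam, mixed_labels

# Toy batch: images_a all ones, images_b all zeros, two classes.
a = np.ones((2, 3, 4, 4))
b = np.zeros((2, 3, 4, 4))
masks = np.zeros((2, 1, 4, 4))
masks[0, :, :2, :] = 1.0   # sample 0: top half from a, so lam = 0.5
masks[1, :, :, :1] = 1.0   # sample 1: first column from a, so lam = 0.25
labels_a = np.array([[1.0, 0.0], [1.0, 0.0]])
labels_b = np.array([[0.0, 1.0], [0.0, 1.0]])

mixed, lam, soft_labels = fmix_mix(a, b, labels_a, labels_b, masks)
print(soft_labels)  # rows: [0.5, 0.5] and [0.25, 0.75]
```

Taking λ from the actual mask rather than from the requested `LAMBDA` keeps the label weights exact even when rounding makes the mask area differ slightly from the target.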
Reference
https://colab.research.google.com/github/ecs-vlc/fmix/blob/master/notebooks/example_masks.ipynb