PyTorch学习笔记(10)transforms(4)
程序员文章站
2022-03-21 19:50:06
...
自定义transforms
自定义transforms要素
1. 仅接收一个参数,返回一个参数
2.注意上下游的输出与输入
通过类实现多参数传入:
class YourTransforms(object):
    """Template for a custom transform that needs extra configuration.

    A transform is called with exactly one argument (the image) and must
    return exactly one value, so any additional settings are passed to the
    constructor and kept on the instance.
    """

    def __init__(self, *args, **kwargs):
        # Store whatever configuration the transform needs on ``self`` here.
        ...

    def __call__(self, img):
        """Apply the transform to ``img`` and return the result."""
        # Transform ``img`` here using the settings stored in ``__init__``.
        ...
        return img
椒盐噪声
椒盐噪声又称为脉冲噪声,是一种随机出现的白点或黑点,白点称为盐噪声,黑点称为椒噪声
信噪比(Signal-to-Noise Ratio,SNR)是衡量噪声的比例,在图像中表示未被噪声污染的原始像素占全部像素的比例
class AddPepperNoise(object):
    """Outline of a salt-and-pepper noise transform.

    Only the structure is shown here: the constructor stores its two
    settings and ``__call__`` is a placeholder that returns the image
    unchanged.

    Args:
        snr: Signal-to-noise setting, stored as ``self.snr``.
        p: Probability setting, stored as ``self.p``.
    """

    def __init__(self, snr, p):
        self.snr = snr
        self.p = p

    def __call__(self, img):
        """Placeholder for the noise step.

        :param img: image to process
        :return: ``img`` itself, unmodified
        """
        return img
class Compose(object):
    """Excerpt showing how a chain of transforms is applied.

    Each transform is called with the output of the previous one, which is
    why every transform must accept one argument and return one value.
    ``self.transforms`` must be an iterable of callables; the constructor
    that sets it is not part of this excerpt.
    """

    def __call__(self, img):
        """Feed ``img`` through every transform in order and return the result."""
        for t in self.transforms:
            img = t(img)
        return img
# -*- coding: utf-8 -*-
import os
import numpy as np
import torch
import random
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from tools.my_dataset import RMBDataset
from tools.common_tools import transform_invert
def set_seed(seed=1):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed (int): Value passed to each generator's seeding function.
            Defaults to 1.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
set_seed(1) # seed the random number generators
# Parameter settings
MAX_EPOCH = 10
BATCH_SIZE = 1
LR = 0.01
log_interval = 10
val_interval = 1
# Maps a class-name string to its integer label.
rmb_label = {"1": 0, "100": 1}
class AddPepperNoise(object):
    """Randomly corrupt a PIL image with salt-and-pepper (impulse) noise.

    Args:
        snr (float): Fraction of pixels that keep their original value, in
            ``[0, 1]``. The remaining ``1 - snr`` is split evenly between
            salt (white, 255) and pepper (black, 0) pixels. For example,
            ``snr=0.9`` keeps about 90% of the pixels untouched.
        p (float): Probability, in ``[0, 1]``, that noise is applied on a
            given call. Defaults to 0.9.

    Raises:
        TypeError: If ``snr`` or ``p`` is not a real number.
        ValueError: If ``snr`` or ``p`` lies outside ``[0, 1]``.
    """

    def __init__(self, snr, p=0.9):
        # Validate both arguments explicitly; an ``assert`` would be stripped
        # under ``python -O`` and the old ``or`` let one bad argument through.
        for name, value in (("snr", snr), ("p", p)):
            if not isinstance(value, (int, float, np.integer, np.floating)):
                raise TypeError(
                    "%s must be a real number, got %s"
                    % (name, type(value).__name__)
                )
            if not 0 <= value <= 1:
                raise ValueError("%s must be in [0, 1], got %r" % (name, value))
        self.snr = snr
        self.p = p

    def __call__(self, img):
        """Apply the noise with probability ``p``.

        Args:
            img (PIL Image): Image to corrupt.

        Returns:
            PIL Image: The input itself when the noise is skipped; otherwise
            a new noisy image converted to RGB.
        """
        if random.uniform(0, 1) >= self.p:
            return img

        # ``np.array`` already returns a fresh, writable copy of the pixels.
        noisy = np.array(img)
        # Only the first two axes are needed, so single-channel (H, W)
        # images work as well as (H, W, C) ones.
        height, width = noisy.shape[:2]

        noise_pct = 1 - self.snr
        # One draw per pixel position: 0 keeps the pixel, 1 marks salt,
        # 2 marks pepper. A 2-D boolean mask indexes every channel of that
        # position, so all channels of a noisy pixel are set together.
        mask = np.random.choice(
            (0, 1, 2),
            size=(height, width),
            p=[self.snr, noise_pct / 2., noise_pct / 2.],
        )
        noisy[mask == 1] = 255  # salt: white
        noisy[mask == 2] = 0  # pepper: black
        return Image.fromarray(noisy.astype('uint8')).convert('RGB')
# ============================ step 1/5 data ============================
# Training and validation images are read from data/rmb_split/{train,valid}.
split_dir = os.path.join( "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")
# Per-channel mean and std handed to transforms.Normalize below.
# NOTE(review): these look like the usual ImageNet statistics - confirm.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
# Training pipeline: AddPepperNoise sits before ToTensor because it takes and
# returns a PIL image; it is applied to half of the samples (p=0.5).
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
AddPepperNoise(0.9, p=0.5),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
# Validation pipeline: same resize and normalisation, no noise.
valid_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std)
])
# Build the dataset instances
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)
# Build the DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)
# ============================ step 5/5 training ============================
# No model is trained here: the loop only displays the first image of every
# batch so the effect of the training transforms can be inspected visually.
for epoch in range(MAX_EPOCH):
    for i, data in enumerate(train_loader):
        inputs, labels = data # B C H W
        img_tensor = inputs[0, ...] # C H W
        # NOTE(review): transform_invert presumably undoes the tensor
        # conversion/normalisation of train_transform so the image can be
        # displayed - confirm against tools.common_tools.
        img = transform_invert(img_tensor, train_transform)
        plt.imshow(img)
        plt.show()
        plt.pause(0.5)
        plt.close()
如何制定数据增强策略
原则:让训练集与测试集更接近
空间位置:平移
色彩:灰度图,色彩抖动
形状:仿射变换
上下文场景:遮挡,填充