
PyTorch Study Notes (10): transforms (4)

程序员文章站 2022-03-21 19:50:06
...

Custom transforms

Requirements for a custom transform

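1. Accept exactly one argument (the image) and return exactly one value.
2. Mind the types between stages: a transform's input must match what the upstream transform outputs, and its output must match what the downstream transform expects.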
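The two rules can be illustrated with plain functions. This is a minimal sketch using hypothetical helpers to_unit_float and invert on NumPy arrays (they are not part of torchvision). Each step takes one argument and returns one value, so steps can be chained. The chain only works, however, when each step's output matches the input the next step expects.

```python
import numpy as np

def to_unit_float(img):
    """Hypothetical step: uint8 HWC array -> float32 array in [0, 1]."""
    if img.dtype != np.uint8:
        raise TypeError(f"to_unit_float expects uint8 input, got {img.dtype}")
    return img.astype(np.float32) / 255.0

def invert(img):
    """Hypothetical step: float32 array in [0, 1] -> 1 - img."""
    if img.dtype != np.float32:
        raise TypeError(f"invert expects float32 input, got {img.dtype}")
    return 1.0 - img

def run_chain(steps, img):
    # One argument in, one value out: each output becomes the next input.
    for step in steps:
        img = step(img)
    return img

img = np.full((2, 2, 3), 255, dtype=np.uint8)

# Compatible order: uint8 -> float32 -> float32
out = run_chain([to_unit_float, invert], img)
print(out.dtype, out.max())  # float32 0.0

# Incompatible order: invert receives the raw uint8 image and rejects it
try:
    run_chain([invert, to_unit_float], img)
except TypeError as err:
    print(err)  # invert expects float32 input, got uint8
```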

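Multiple parameters are passed in through a class. Store them in __init__, and keep __call__ limited to the single image argument: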

class YourTransforms(object):
    def __init__(self, *args, **kwargs):
        # store the extra parameters on self here
        ...

    def __call__(self, img):
        # transform the single input img here
        ...
        return img
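Here is the template filled in with a concrete two-parameter transform. ClipPixels is a hypothetical example chosen only for illustration, and it works on NumPy arrays rather than PIL images. The point is the split: the constructor holds the parameters, while __call__ still takes only the image.

```python
import numpy as np

class ClipPixels(object):
    """Hypothetical transform: clamp pixel values into [low, high].

    The extra parameters go to __init__; __call__ still takes only the image.
    """
    def __init__(self, low, high):
        assert low <= high, "low must not exceed high"
        self.low = low
        self.high = high

    def __call__(self, img):
        # img: NumPy array; one input in, one output out
        return np.clip(img, self.low, self.high)

clip = ClipPixels(50, 200)          # parameters are fixed once, at construction
img = np.array([[0, 100, 255]], dtype=np.uint8)
print(clip(img))                    # [[ 50 100 200]]
```

Because the parameters live on the instance, the same object can be placed in a Compose list and called once per image.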

Salt-and-pepper noise

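Salt-and-pepper noise, also called impulse noise, consists of randomly scattered white or black pixels. White pixels are called salt noise and black pixels are called pepper noise.
The signal-to-noise ratio (SNR) measures the proportion of signal relative to noise. For an image, it is taken here as the fraction of pixels that keep their original values, so the remaining 1 - SNR of the pixels become noise.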
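As a worked example of this definition, take SNR = 0.9 on a 224×224 image (the size used later in this note). About 10% of the 50,176 pixel positions become noise, split evenly between salt and pepper. The sketch below computes the expected counts and compares them with one random draw. It assumes that the noise type of each pixel is sampled independently, which is also how the implementation later uses np.random.choice. The helper name expected_noise_counts is hypothetical.

```python
import numpy as np

def expected_noise_counts(h, w, snr):
    """Expected number of (noisy, salt, pepper) pixel positions for a given SNR."""
    total = h * w
    noisy = (1 - snr) * total
    return noisy, noisy / 2, noisy / 2

noisy, salt, pepper = expected_noise_counts(224, 224, snr=0.9)
print(round(noisy, 1), round(salt, 1), round(pepper, 1))  # 5017.6 2508.8 2508.8

# One random draw: 0 = keep original, 1 = salt (white), 2 = pepper (black)
rng = np.random.default_rng(0)
mask = rng.choice([0, 1, 2], size=(224, 224), p=[0.9, 0.05, 0.05])
print((mask != 0).mean())  # close to 0.1; the exact value depends on the seed
```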

class AddPepperNoise(object):
    def __init__(self, snr, p):
        self.snr = snr
        self.p = p

    def __call__(self, img):
        """
        Add salt-and-pepper noise to img (the full implementation follows below).
        :param img:
        :return:
        """
        return img
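
class Compose(object):
    # Simplified version of transforms.Compose: apply each transform in order
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)  # the output of one transform is the input of the next
        return img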
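
The complete script below adds salt-and-pepper noise to the RMB training images and displays each result:

# -*- coding: utf-8 -*-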
import os
import numpy as np
import torch
import random
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from tools.my_dataset import RMBDataset
from tools.common_tools import transform_invert


def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


set_seed(1)  # set the random seed

# Parameter settings
MAX_EPOCH = 10
BATCH_SIZE = 1
LR = 0.01
log_interval = 10
val_interval = 1
rmb_label = {"1": 0, "100": 1}


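class AddPepperNoise(object):
    """Add salt-and-pepper noise.
    Args:
        snr (float): Signal-to-Noise Ratio, the fraction of pixels kept unchanged
        p (float): probability of applying this operation
    """
    # e.g. snr=0.9 keeps 90% of the pixels as the original image;
    # p defaults to 0.9, i.e. noise is added to about 90% of the images
    def __init__(self, snr, p=0.9):
        # both snr and p must be floats
        assert isinstance(snr, float) and isinstance(p, float)
        self.snr = snr
        self.p = p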

    def __call__(self, img):
        """
        Args:
            img (PIL Image): PIL Image
        Returns:
            PIL Image: PIL image.
        """
        if random.uniform(0, 1) < self.p:
            img_ = np.array(img).copy()
            h, w, c = img_.shape
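            # signal percentage, i.e. the SNR
            signal_pct = self.snr
            # noise percentage
            noise_pct = (1 - self.snr)
            # per-pixel mask: 0 = original pixel, 1 = salt noise, 2 = pepper noise
            mask = np.random.choice((0, 1, 2), size=(h, w, 1), p=[signal_pct, noise_pct/2., noise_pct/2.])
            # copy the mask to every channel so a noisy pixel is pure white or pure black
            mask = np.repeat(mask, c, axis=2)
            img_[mask == 1] = 255   # salt noise: white
            img_[mask == 2] = 0     # pepper noise: black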
            return Image.fromarray(img_.astype('uint8')).convert('RGB')
        else:
            return img


# ============================ step 1/5 data ============================
split_dir = os.path.join("data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]


train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    AddPepperNoise(0.9, p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std)
])

# Build RMBDataset instances
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# Build DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)


# ============================ step 5/5 training ============================
for epoch in range(MAX_EPOCH):
    for i, data in enumerate(train_loader):

        inputs, labels = data   # B C H W

        img_tensor = inputs[0, ...]     # C H W
        img = transform_invert(img_tensor, train_transform)
        plt.imshow(img)
        plt.show()
        plt.pause(0.5)
        plt.close()
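You can check what AddPepperNoise.__call__ does without the RMB data, the plotting loop, or PIL by reproducing its masking step on a plain NumPy array. In this sketch, add_pepper_noise_np is a hypothetical helper. It draws from a local np.random.default_rng generator instead of the global np.random state that set_seed seeds. Apart from that, the mask logic matches the class above.

```python
import numpy as np

def add_pepper_noise_np(img, snr, rng):
    """Hypothetical NumPy-only version of the masking step in AddPepperNoise.

    img: uint8 array of shape (H, W, C). Returns a noisy copy.
    """
    out = img.copy()
    h, w, c = out.shape
    noise_pct = 1 - snr
    # One draw per pixel position, shared by all channels
    mask = rng.choice((0, 1, 2), size=(h, w, 1), p=[snr, noise_pct / 2, noise_pct / 2])
    mask = np.repeat(mask, c, axis=2)
    out[mask == 1] = 255  # salt: white
    out[mask == 2] = 0    # pepper: black
    return out

rng = np.random.default_rng(1)
img = np.full((64, 64, 3), 128, dtype=np.uint8)  # flat gray test image
noisy = add_pepper_noise_np(img, snr=0.9, rng=rng)

changed = (noisy != img).any(axis=2)   # pixel positions that were altered
print(round(changed.mean(), 2))        # roughly 0.1
# Every altered pixel is pure white or pure black in all three channels
vals = noisy[changed]
print(np.all((vals == 0).all(axis=1) | (vals == 255).all(axis=1)))  # True
```

A flat gray image is used so that every noisy pixel differs from the original, whether it becomes white or black.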

How to design a data augmentation strategy

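Principle: make the training set as close as possible to the test set.
Spatial position: translation
Color: grayscale, color jitter
Shape: affine transformation
Context/scene: occlusion, padding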