欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

三、pix2pixHD代码解析(dataset处理)

程序员文章站 2023-12-31 18:28:04
...

pix2pixHD代码解析

一、pix2pixHD代码解析(train.py + test.py)
二、pix2pixHD代码解析(options设置)
三、pix2pixHD代码解析(dataset处理)
四、pix2pixHD代码解析(models搭建)

三、pix2pixHD代码解析(dataset处理)

data_loader.py

##########################################################################
# 创建数据集加载主函数
##########################################################################
def CreateDataLoader(opt):
    from data.custom_dataset_data_loader import CustomDatasetDataLoader
    data_loader = CustomDatasetDataLoader()
    print(data_loader.name())                                             # 返回的名字为“CustomDatasetDataLoader”
    data_loader.initialize(opt)                                           # 初始化参数
    return data_loader

custom_dataset_data_loader.py

import torch.utils.data
from data.base_data_loader import BaseDataLoader


# 创建数据集
def CreateDataset(opt):
    dataset = None
    from data.aligned_dataset import AlignedDataset
    dataset = AlignedDataset()

    print("dataset [%s] was created" % (dataset.name()))               # 打印数据集名字为‘AlignedDataset’
    dataset.initialize(opt)                                            # 初始化数据集参数
    return dataset                                                     # 返回创建好的数据集


# 加载数据集
class CustomDatasetDataLoader(BaseDataLoader):
    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)                           # 初始化参数
        self.dataset = CreateDataset(opt)                              # 创建数据集
        self.dataloader = torch.utils.data.DataLoader(                 # 加载创建好的数据集,并自定义相关参数
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads))

    def load_data(self):
        return self.dataloader                                         # 返回数据集

    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)       # 返回加载的数据集长度和一个epoch容许的加载最大容量

aligned_dataset.py


#############################################################################
# 数据读取的方式
#############################################################################

import os.path
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
from data.image_folder import make_dataset
from PIL import Image


# 返回一个字典,里面由整理好的数据集:图片 + 类别
class AlignedDataset(BaseDataset):                                           # init里面都是些路径的设置
    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot    

        ### input A (label maps)                                             # 标签图的路径
        dir_A = '_A' if self.opt.label_nc == 0 else '_label'
        self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A)           # './geometry' + 'train' + '_label'
        ### sort 是应用在 list 上的方法,sorted 可以对所有可迭代的对象进行排序操作。
        # list 的 sort 方法返回的是对已经存在的列表进行操作;
        # 而内建函数 sorted 方法返回的是一个新的 list,而不是在原来的基础上进行的操作
        # (事实证明直接对string排序,与实际int值排序结果是不一样的,图片名并不是按照从小到大的顺序)
        self.A_paths = sorted(make_dataset(self.dir_A))                      # 返回self.dir_A下的图片路径列表

        ### input B (real images)                                            # 真实图的路径
        if opt.isTrain or opt.use_encoded_image:
            dir_B = '_B' if self.opt.label_nc == 0 else '_img'
            self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)  
            self.B_paths = sorted(make_dataset(self.dir_B))
            # self.B_paths = self.A_paths

        ### instance maps                                                    # 实例图的路径
        if not opt.no_instance:                                              # 如果no_instance为true,则不添加实例图
            self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst')
            self.inst_paths = sorted(make_dataset(self.dir_inst))
            # self.inst_paths = self.A_paths

        ### load precomputed instance-wise encoded features
        if opt.load_features:                              
            self.dir_feat = os.path.join(opt.dataroot, opt.phase + '_feat')
            print('----------- loading features from %s ----------' % self.dir_feat)
            self.feat_paths = sorted(make_dataset(self.dir_feat))            # 本文中没有train_feat图片

        self.dataset_size = len(self.A_paths) 
      
    def __getitem__(self, index):                                            # getitem里是具体的操作,是这个类的重点操作
        ### input A (label maps)                                             # 读取标签图A
        A_path = self.A_paths[index]                                         # 获得图片路径
        # A = Image.open(self.dir_A + '/' + A_path)                                               # 先读取一张图片
        A = Image.open(A_path)
        params = get_params(self.opt, A.size)                                # 根据输入的opt和size,返回随机参数
        if self.opt.label_nc == 0:
            transform_A = get_transform(self.opt, params)
            A_tensor = transform_A(A.convert('RGB'))
        else:
            transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)  # 图像变换
            A_tensor = transform_A(A) * 255.0                                # 对数据预处理,有经过to_tensor操作,再乘255

        B_tensor = inst_tensor = feat_tensor = 0
        ### input B (real images)                                            # 接着读入真实图像B
        if self.opt.isTrain or self.opt.use_encoded_image:
            B_path = self.B_paths[index]
            # B = Image.open(self.dir_B + '/' + B_path).convert('RGB')
            B = Image.open(B_path).convert('RGB')
            transform_B = get_transform(self.opt, params)      
            B_tensor = transform_B(B)

        ### if using instance maps                                           # 接着读入instance,后续还会处理成边缘图,和论文中描述一致。
        if not self.opt.no_instance:                                         # no_instance默认值为true
            inst_path = self.inst_paths[index]
            # inst = Image.open(self.dir_inst + '/' + inst_path)
            inst = Image.open(inst_path)
            inst_tensor = transform_A(inst)                                  # 和semantic的处理方式一样  0-1

            if self.opt.load_features:                                       # 注意self.opt.load_features的作用是是否读取每个类别的预先计算的特征,论文中有10类,由聚类形成的。但默认是不执行的。我本人看论文对这一部分也是一知半解,以后有需求之后再研究。
                feat_path = self.feat_paths[index]            
                feat = Image.open(feat_path).convert('RGB')
                norm = normalize()
                feat_tensor = norm(transform_A(feat))

        input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor, 
                      'feat': feat_tensor, 'path': A_path}

        return input_dict                                                    # 返回一个字典,记录了上述读取并经过处理的数据集。


    def __len__(self):
        return len(self.A_paths) // self.opt.batchSize * self.opt.batchSize

    def name(self):
        return 'AlignedDataset'

image_folder.py

###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
# 获得指定目录下的图片路径 + 加载路径图片
###############################################################################
import torch.utils.data as data
from PIL import Image
import os

# 本程序支持的图片扩展名
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
]


def is_image_file(filename):
    ### any()函数用于判断给定的可迭代参数iterable是否全部为False,则返回False,如果有一个为True,则返回True。
    # 元素除了是0、空、FALSE外都算TRUE。
    # 函数等价于:
    # def any(iterable):
    #     for element in iterable:
    #         if element:
    #             return True
    #     return False
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


# 制作数据集:获得数据集的图片路径列表
def make_dataset(dir):                                                 # dir为数据集文件夹路径
    images = []                                                        # 创建空列表
    assert os.path.isdir(dir), '%s is not a valid directory' % dir     # 确认路径存在

    ### os.walk() 方法是一个简单易用的文件、目录遍历器,可以帮助我们高效的处理文件、目录方面的事情
    # top -- 是你所要遍历的目录的地址, 返回的是一个三元组(root,dirs,files)。
    # root 所指的是当前正在遍历的这个文件夹的本身的地址,和输入的os.walk(dir)种的dir一致
    # dirs 是一个 list ,内容是该文件夹中所有的 目录 的名字(不包括子目录),若无则为[]
    # files 同样是 list , 内容是该文件夹中所有的 文件 的名字(不包括子目录),若无则为[]
    for root, _, fnames in sorted(os.walk(dir)):                       # fnames为文件中读取的照片文件
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)                       # 将文件夹路径dir 和 图片名称fname 结合起来
                images.append(path)                                    # 将图片路径存放到image列表里
                # temp = fname
                # images.append(temp)
    return images                                                      # 返回图片路径列表


def default_loader(path):
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)                                      # imgs为root目录下图片路径列表
        if len(imgs) == 0:                                             # 图片数量 = 0 报错
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]                                        # 获取指定图片路径
        img = self.loader(path)                                        # 加载图片
        if self.transform is not None:
            img = self.transform(img)                                  # 图片进行变换
        if self.return_paths:
            return img, path                                           # 返回图片和路径
        else:
            return img                                                 # 仅返回图片

    def __len__(self):
        return len(self.imgs)                                          # 返回指定目录下图片数量

base_dataset.py

import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random

class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    def initialize(self, opt):
        pass


# 这个函数是根据用户指定的方式resize或者crop出合适大小的输入尺寸。
# size:输入图片的尺寸
def get_params(opt, size):
    w, h = size
    new_h = h
    new_w = w
    if opt.resize_or_crop == 'resize_and_crop':
        # opt.loadSize为自己输入的尺寸,将图像缩放到这个大小
        new_h = new_w = opt.loadSize                                     # 将宽和高设置为同样大小
    elif opt.resize_or_crop == 'scale_width_and_crop':                   # 我已在opt处设置为‘scale_width_and_crop’
        new_w = opt.loadSize
        new_h = opt.loadSize * h // w                                    # 高度按照原图宽高比计算

    x = random.randint(0, np.maximum(0, new_w - opt.fineSize))           # ???不明白此处的随机数什么意思
    y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
    
    flip = random.random() > 0.5                                         # 随机数是否大于0.5,flip是bool型变量,此行代码意思为随机生成True或者False
    return {'crop_pos': (x, y), 'flip': flip}                            # 最终的返回值,在data.aligned_dataset 45行,当作params传入了下方get_transform()函数


# 图像变换
def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
    transform_list = []
    if 'resize' in opt.resize_or_crop:                                   # 若opt.resize_or_crop中有'resize'
        osize = [opt.loadSize, opt.loadSize]
        transform_list.append(transforms.Scale(osize, method))   
    elif 'scale_width' in opt.resize_or_crop:
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))

    ### lambda函数也叫匿名函数,即,函数没有具体的名称。先来看一个最简单例子:
    # def f(x):
    #   return x**2
    # print f(4)
    #
    # Python中使用lambda的话,写成这样:
    # g = lambda x : x**2
    # print g(4)

    if 'crop' in opt.resize_or_crop:
        # 使用transforms.Lambda封装其为transforms策略
        transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))

    if opt.resize_or_crop == 'none':
        base = float(2 ** opt.n_downsample_global)
        if opt.netG == 'local':
            base *= (2 ** opt.n_local_enhancers)
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    transform_list += [transforms.ToTensor()]

    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5),         # mean和std均为0.5
                                                (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)

def normalize():    
    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

def __make_power_2(img, base, method=Image.BICUBIC):
    ow, oh = img.size        
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if (h == oh) and (w == ow):
        return img
    return img.resize((w, h), method)

def __scale_width(img, target_width, method=Image.BICUBIC):
    ow, oh = img.size
    if (ow == target_width):
        return img    
    w = target_width
    h = int(target_width * oh / ow)    
    return img.resize((w, h), method)

# 随机平移滑动裁剪
def __crop(img, pos, size):
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size                                                       # 输入的尺寸 opt.fineSize
    if (ow > tw or oh > th):        
        return img.crop((x1, y1, x1 + tw, y1 + th))                      # 随机裁剪,因为虽然每次裁剪测大小一样,但是起始点位置不一样
    return img

def __flip(img, flip):
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img

相关标签: GAN PyTorch

上一篇:

下一篇: