欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

ESPCN 论文复现、代码注解、超级详细

程序员文章站 2022-06-28 16:59:12
ESPCN 论文复现、代码注解、超级详细环境文章解释的代码顺序即为代码阅读顺序;环境:ubuntu16.04pytorch 1.2torchvision 0.40cuda 9.2python 3.6github: [https://github.com/leftthomas/ESPCN](https://github.com/leftthomas/ESPCN)上面github应该是官方代码,但是用的pytorch0.4以前的版本应该是,cuda是8.0,python是2.7...

ESPCN 论文复现、代码注解、超级详细

环境

文章解释的代码顺序即为代码阅读顺序;
环境:
	ubuntu16.04
	pytorch 1.2
	torchvision 0.40
	cuda 9.2
	python 3.6
	github: [https://github.com/leftthomas/ESPCN](https://github.com/leftthomas/ESPCN)
	上面github应该是官方代码,但是用的pytorch0.4以前的版本应该是,cuda是8.0,python是2.7
	而本文是在官网上的基础上改动了几个函数并翻译,供和我一样的萌新参考。如发现有错误请提醒一下大家共同学习

data_utils.py

import argparse
import os
from os import listdir
from os.path import join

from PIL import Image
from torch.utils.data.dataset import Dataset
"""
    torch.utils.data.Dataset是代表自定义数据集方法的抽象类,你可以自己定义你的数据类继承这个抽象类,非常简单,只需要定义__len__和__getitem__这两个方法就可以。

    通过继承torch.utils.data.Dataset的这个抽象类,我们可以定义好我们需要的数据类。当我们通过迭代的方式来取得每一个数据,但是这样很难实现取batch,shuffle或者多线程读取数据,
    所以pytorch还提供了一个简单的方法来做这件事情,通过torch.utils.data.DataLoader类来定义一个新的迭代器,
    用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入。
    总之,通过torch.utils.data.Dataset和torch.utils.data.DataLoader这两个类,使数据的读取变得非常简单,快捷。

    __len__(self) 定义当被len()函数调用时的行为(返回容器中元素的个数)

    __getitem__(self)定义获取容器中指定元素的行为,相当于self[key],即允许类对象可以有索引操作

"""
from torchvision.transforms import Compose, CenterCrop, Scale
from tqdm import tqdm


def is_image_file(filename):         #图片文件夹,以及打开格式 
    return any(filename.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG'])
    """
    endswith():方法用于判断字符串是否以指定后缀结尾,如果以指定后缀结尾返回True,否则返回False。可选参数"start"与"end"为检索字符串的开始与结束位置
    extension: 类扩展
    any():python doc中得说明,意思就是当传入空可迭代对象时返回False,当可迭代对象中有任意一个不为False,则返回True
    即有后面任意一种扩展名都返回  True
    """


def is_video_file(filename):         #视屏文件夹,以及打开格式
    return any(filename.endswith(extension) for extension in ['.mp4', '.avi', '.mpg', '.mkv', '.wmv', '.flv'])


def calculate_valid_crop_size(crop_size, upscale_factor):                   #计算有效切割尺寸
    return crop_size - (crop_size % upscale_factor)


def input_transform(crop_size, upscale_factor):                              #输入图片切割
    return Compose([                                                        #Conpose 用来管理transfrom(数据扩大操作)中的各种类
        CenterCrop(crop_size),                                              #将图片进行中心切割得到 给定size大小图片并用scale缩放
        Scale(crop_size // upscale_factor, interpolation=Image.BICUBIC)     #scale 是缩放   // 代表整除    interpolation是插入类型
    ])


def target_transform(crop_size):                                        #标签图片切割
    return Compose([
        CenterCrop(crop_size)
    ])


class DatasetFromFolder(Dataset):                                         #从文件夹获取数据
    def __init__(self, dataset_dir, upscale_factor, input_transform=None, target_transform=None):
        super(DatasetFromFolder, self).__init__()
        self.image_dir = dataset_dir + '/SRF_' + str(upscale_factor) + '/data'             #路径是 D:\论文代码\ESPCN网络框架\ESPCN-master\data\train\SRF_3\data
        self.target_dir = dataset_dir + '/SRF_' + str(upscale_factor) + '/target'          #路径是D:\论文代码\ESPCN网络框架\ESPCN-master\data\train\SRF_3\target
        #join  连接两个或多个路径名组件,根据需要插入'/'。如果任何组件是绝对路径,则之前的所有路径组件将被丢弃。最后部分为空将生成以分隔符结束的路径
        #listdir == os
        #join   == os.path
        self.image_filenames = [join(self.image_dir, x) for x in listdir(self.image_dir) if is_image_file(x)]   #遍历image_dir文件夹下的文件 , 符合is_image_file(x)定义格式文件,则把该文件和路径连接
        self.target_filenames = [join(self.target_dir, x) for x in listdir(self.target_dir) if is_image_file(x)]   #target_dir遍历文件夹下的文件 , 符合is_image_file(x)定义格式文件,则把该文件和路径连接
        self.input_transform = input_transform      #两个空张量,
        self.target_transform = target_transform

    def __getitem__(self, index):
        image, _, _ = Image.open(self.image_filenames[index]).convert('YCbCr').split()  #image.open()打开图片
        #self.image_filenames([index])是从( dataset_dir + '/SRF_' + str(upscale_factor) + '/data' )
        #中索引一张图片用图像处理库 PIL 实现转换其格式为'YCbCr',并赋值给image

        #PIL的九种不同模式:1(黑白图,二值图),L(灰度图),P(八色彩图),RGB,RGBA,CMYK(印刷四分色图),YCbCr(模式“YCbCr”为24位彩色图像,Y亮度Cb蓝色色度Cr红色色度),I,F 

        target, _, _ = Image.open(self.target_filenames[index]).convert('YCbCr').split()  #转换数据图像的格式

        if self.input_transform:
            image = self.input_transform(image)          #将转换后的image 图片切割至给定大小
        if self.target_transform:
            target = self.target_transform(target)       #将转换后的target 图片切割至给定大小

        return image, target

    def __len__(self):
        return len(self.image_filenames)


def generate_dataset(data_type, upscale_factor):
    images_name = [x for x in listdir('data/VOC2012/' + data_type) if is_image_file(x)]  
    #listdir('data/VOC2012/' + data_type)里有扩展名为 ['.mp4', '.avi', '.mpg', '.mkv', '.wmv', '.flv'] 的文件则返回文件并赋值到 image_name 中

    crop_size = calculate_valid_crop_size(256, upscale_factor)  #返回  crop_size = crop_size - (crop_size % upscale_factor)  256 - (256 % 3)= 255   %:求余数
    lr_transform = input_transform(crop_size, upscale_factor)   #把图片中心剪切:crop_size = 255 大小 ; 缩放 crop_size // upscale_factor = 255 // 3 = 0 
    hr_transform = target_transform(crop_size)                  #target 和 input 一样操作

    root = 'data/' + data_type              #data_type是 'data/' 下待创建的: train ; val ; test ;
    if not os.path.exists(root):            #os.makedirs :判断路径存不存在,不i存在创建它
        os.makedirs(root)
    path = root + '/SRF_' + str(upscale_factor)   #合并路径
    if not os.path.exists(path):
        os.makedirs(path)
    image_path = path + '/data'
    if not os.path.exists(image_path):
        os.makedirs(image_path)
    target_path = path + '/target'
    if not os.path.exists(target_path):
        os.makedirs(target_path)
    """
    这一串一直检擦路径存在于否,不存在就创建
    直到创建:D:\论文代码\ESPCN网络框架\ESPCN-master\data        \data_type\SRF_2\target(标签)
           :D:\论文代码\ESPCN网络框架\ESPCN-master\data        \data_type\SRF_2\data(训练集)
           : data_type是 'data/' 下待创建的: train ; val ; test ;
    """

    for image_name in tqdm(images_name, desc='generate ' + data_type + ' dataset with upscale factor = '+ str(upscale_factor) + ' from VOC2012'):
        #from VOC2012 generate data_type(train,val,test)  dataset with upscale factor = upscale_factor;
        """
        tqdm:是一个快速,可扩展的Python进度条,可以在Python长循环中添加一个进度提示信息,用户只需要封装任意的迭代器tqdm(iterator)即可完成进度条
        #images_name = [x for x in listdir('data/VOC2012/' + data_type) if is_image_file(x)]
        #listdir('data/VOC2012/' + data_type)里有扩展名为 ['.mp4', '.avi', '.mpg', '.mkv', '.wmv', '.flv'] 的文件名则返回文件并赋值到 image_name 中
        """
        image = Image.open('data/VOC2012/' + data_type + '/' + image_name) #将image_name 存的文件名的文件赋值给image
        target = image.copy()
        image = lr_transform(image)             #将VOC2012里面的图片中心切割缩放后放到image 中
        target = hr_transform(target)           #将VOC2012里面的图片中心切割后放到target 中

        image.save(image_path + '/' + image_name)      #把处理后的image  放到  image_path + '/' + image_name  中
        #image_path = \data\data_type\SRF_2\data(训练集)
        target.save(target_path + '/' + image_name)    #把处理后的target  放到 target_path + '/' + image_name  中
        #target_path = \data\data_type\SRF_2\data(训练集)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Generate Super Resolution Dataset')  #用argparse包中ArgumentParser类生成一个参数解析器 parser ;description是描述参数解析器的作用:生成超分辨率数据集
    parser.add_argument('--upscale_factor', default=3, type=int, help='super resolution upscale factor')  #增加一个参数 upscale_factor ,后面是该参数的介绍
    opt = parser.parse_args()                #用对象parse_args()获取参数解析器的参数
    UPSCALE_FACTOR = opt.upscale_factor      #将解析器的参数赋值给 UPSCALE_FACTOR

    generate_dataset(data_type='train', upscale_factor=UPSCALE_FACTOR)    #调用generate_dataset()函数对 VOC2012里面的数据经过转换、剪切、分别制作训练集和验证集
    generate_dataset(data_type='val', upscale_factor=UPSCALE_FACTOR)
    

model.py

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self, upscale_factor):
        super(Net, self).__init__()

        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(32, 1 * (upscale_factor ** 2), (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        x = F.tanh(self.conv1(x))
        x = F.tanh(self.conv2(x))
        x = F.sigmoid(self.pixel_shuffle(self.conv3(x)))
        return x


if __name__ == "__main__":
    model = Net(upscale_factor=3)
    print(model)

psnrmeter.py

from math import log10

import torch
from torchnet.meter import meter
"""
meter:提供了一种标准化的方法来测量一系列不同的测量值,这使得测量模型的各种属性变得很容易
三个方法:
    add() 这给 meter 增加了一个观测值
    value() 它将所有的观测结果都考虑在内,返回 meter 的值
    reset() 这样就删除了之前添加的所有观测结果,重新设置了 meter 
"""


class PSNRMeter(meter.Meter):
    def __init__(self):                         #每次调用这个类都刷新一次
        super(PSNRMeter, self).__init__()
        self.reset()

    def reset(self):                            #这样就删除了之前添加的所有观测结果,重新设置了 meter
        self.n = 0
        self.sesum = 0.0

    def add(self, output, target):              #这给 meter 增加了一个观测值
        if not torch.is_tensor(output) and not torch.is_tensor(target):        #确保out target 是张量
            output = torch.from_numpy(output)
            target = torch.from_numpy(target)
        output = output.cpu()                   #cpu计算 变量 output 、 target
        target = target.cpu()
        self.n += output.numel()                #n = n + output.nume1()   n应该是标签数目
        self.sesum += torch.sum((output - target) ** 2)

    def value(self):                            #它将所有的观测结果都考虑在内,返回 meter 的值
        mse = self.sesum / max(1, self.n)
        psnr = 10 * log10(1 / mse)
        return psnr                             #psnr是评价标准: 峰值信噪比

train.py

import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import torchnet as tnt
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from torchnet.engine import Engine
"""
torchnet 是用于 torch 的代码复用和模块化编程的框架
四种方法:
    Dataset : 各种不同的方式处理数据
    Engine: 各种机器学习算法
    Meter: 性能度量评估
    Log: 计算对数
Engine:
    它将训练过程和测试过程进行包装,抽象成一个类,提供train和test方法和一个hooks.(这部分文档是问题的)文档中的描述应该是torch.tensor中的hook,原理一致,
    只不过tensor中的hook是在变量forward或者backward的时候执行(两种hook)
    hooks包括on_start, on_sample, on_forward, on_update,  on_end_epoch,  on_end,可以自己制定函数,在开始,load数据,forward,更新还有epoch结束以及训练结束时执行。
    一般是用开查看和保存模型训练过程的一些结果
"""
from torchnet.logger import VisdomPlotLogger
#用于记录一些评估结果和可视化
from tqdm import tqdm

from data_utils import DatasetFromFolder
from model import Net
from psnrmeter import PSNRMeter


def processor(sample):                   #定义处理器
    """
    processor()函数:
    输入 :sample是输入的数据集包含训练集数据(data)和标签(target)
    输出 :预测值 y 和损失值 loss
    """
    data, target, training = sample
    data = Variable(data)               #from torch.autograd import Variable   就是将变量 grad_requirs = True
    target = Variable(target)
    if torch.cuda.is_available():       #检测  GPU  是佛可用
        data = data.cuda()              #将变量放入cuda中
        target = target.cuda()          

    output = model(data)                #调用model  这个data应该是放大因子upscale_factor   output 是输出预测值
    """
        #import torch.nn as nn
        #import torch.nn.functional as F


        #class Net(nn.Module):  
               
          ## torch.nn是专门为神经网络设计的模块化接口。nn构建于autograd之上,可以用来定义和运行神经网络。nn.Module是nn中十分重要的类,包含网络各层的定义及forward方法。
          ##定义自已的网络:
                需要继承nn.Module类,并实现forward方法。
                一般把网络中具有可学习参数的层放在构造函数__init__()中,
                不具有可学习参数的层(如ReLU)可放在构造函数中,也可不放在构造函数中(而在forward中使用nn.functional来代替)
    
                只要在nn.Module的子类中定义了forward函数,backward函数就会被自动实现(利用Autograd)。
                在forward函数中可以使用任何Variable支持的函数,毕竟在整个pytorch构建的图中,是Variable在流动。还可以使用
                if,for,print,log等python语法.
    
                注:Pytorch基于nn.Module构建的模型中,只支持mini-batch的Variable输入方式,
                比如,只有一张输入图片,也需要变成 N x C x H x W 的形式:
    
                input_image = torch.FloatTensor(1, 28, 28)
                input_image = Variable(input_image)
                input_image = input_image.unsqueeze(0)   # 1 x 1 x 28 x 28

            def __init__(self, upscale_factor):                                      ##upscale_factor  :上采样因素 就是放大倍数
                super(Net, self).__init__()

                self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))                ##self.conv2d = nn.Conv2d(in_channels=3,out_channels=64,kernel_size=4,stride=2,padding=1)
                self.conv2 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
                self.conv3 = nn.Conv2d(32, 1 * (upscale_factor ** 2), (3, 3), (1, 1), (1, 1))     ##此处的upscale_factor 就是为后续的上采样实现  
                self.pixel_shuffle = nn.PixelShuffle(upscale_factor)                 ##PixelShuffle  亚像素实现上采样: input:(m , c*upscale_factor**2 , h , w)
            #                                                                                                         oueput:(m , c , h*upscale_facter , w*upscale_facter)

            def forward(self, x):                                                     ##  nn.Module实现的层(layer)是一个特殊的类,都是由class Layer(nn.Module)定义,会自动提取可学习的参数
            #                                                                         ##  nn.functional中的函数更像是纯函数,由def functional(input)定义  一般定义不可学习参数的层,如激活、池化
                x = F.tanh(self.conv1(x))              ##定义激活函数 import torch.nn.functional as F
                x = F.tanh(self.conv2(x))
                x = F.sigmoid(self.pixel_shuffle(self.conv3(x)))
                return x


        if __name__ == "__main__":
            model = Net(upscale_factor=3)
            print(model)
    """

    loss = criterion(output, target)      #调用损失函数  一般用法:  criterion = LossCriterion()
    #                                      output是预测值     target是标签真实值             loss = criterion(y, Y)

    return loss, output


def on_sample(state):
    state['sample'].append(state['train'])         #给 train 这个键赋值


def reset_meters():
    meter_psnr.reset()         #清空:评价标准 和 损失数据 
    meter_loss.reset()


def on_forward(state):
    meter_psnr.add(state['output'].data, state['sample'][1])   #计算评价标准:tensor.data 切断反向传播,因为评价标准不要学习参数(可用tensor.detach()更好) add():把output 和 sample[1] 元素逐个相加
    meter_loss.add(state['loss'].item())                       #.team()返回一个可遍历的数组 ;计算的损失逐个相加


def on_start_epoch(state):
    reset_meters()               #清空:评价标准 和 损失数据 

    scheduler.step()      #在scheduler的step_size表示scheduler.step()每调用step_size次,对应的学习率就会按照策略调整一次。
    #所以如果scheduler.step()是放在mini-batch里面,那么step_size指的是经过这么多次迭代,学习率改变一次

    state['iterator'] = tqdm(state['iterator']) #iterator迭代器,返回迭代进度条提示信息
    #tqdm: 是一个快速,可扩展的Python进度条,可以在Python长循环中添加一个进度提示信息,用户只需要封装任意的迭代器tqdm(iterator)即可完成进度条


def on_end_epoch(state):
    print('[Epoch %d] Train Loss: %.4f (PSNR: %.2f db)' % (state['epoch'], meter_loss.value()[0], meter_psnr.value()))  #打印训练的迭代次数对应的 损失 峰值信噪比

    train_loss_logger.log(state['epoch'], meter_loss.value()[0])      #计算训练集损失和峰值信噪比的对数,以迭代次数为底
    train_psnr_logger.log(state['epoch'], meter_psnr.value())

    reset_meters()                     #清空:评价标准 和 损失数据

    engine.test(processor, val_loader)    #用engine 封装test的方法, processor()函数是向前传播计算预测值和损失; 
    #后面定义   :val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=64, shuffle=False)

    val_loss_logger.log(state['epoch'], meter_loss.value()[0])         #计算验证集损失和峰值信噪比的对数,以迭代次数为底
    val_psnr_logger.log(state['epoch'], meter_psnr.value())

    print('[Epoch %d] Val Loss: %.4f (PSNR: %.2f db)' % (state['epoch'], meter_loss.value()[0], meter_psnr.value()))    #打印验证的迭代次数对应的 损失 峰值信噪比

    torch.save(model.state_dict(), 'epochs/epoch_%d_%d.pt' % (UPSCALE_FACTOR, state['epoch']))  #state_dict():返回一个包含模块的整个状态的字典(返回model字典里是:前向卷积和激活函数)
    #save():将对象(返回的model字典)保存到磁盘文件中。


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Train Super Resolution') ##用argparse包中ArgumentParser类生成一个参数解析器 parser ;description是描述参数解析器的作用:训练超分
    parser.add_argument('--upscale_factor', default=3, type=int, help='super resolution upscale factor')  #增加一个  upscale_factor参数 :超分上采样因数
    parser.add_argument('--num_epochs', default=100, type=int, help='super resolution epochs number')   #增加一个  num_epochs参数:超分迭代次数
    opt = parser.parse_args()        ##用对象parse_args()获取参数解析器的参数到 opt里

    UPSCALE_FACTOR = opt.upscale_factor
    NUM_EPOCHS = opt.num_epochs

    train_set = DatasetFromFolder('data/train', upscale_factor=UPSCALE_FACTOR, input_transform=transforms.ToTensor(),target_transform=transforms.ToTensor())
    #调用 data_utils.py 里的类 DatasetFromFolder:对数据进行处理、转换、剪切 得到 image(训练用的data) 和 target(标签# ) 后面赋值给dataset 进行封装batch_size 进行训练
    
    val_set = DatasetFromFolder('data/val', upscale_factor=UPSCALE_FACTOR, input_transform=transforms.ToTensor(),target_transform=transforms.ToTensor())

    train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=64, shuffle=True)
    """
    先简单的介绍一下DataLoader
    它是PyTorch中数据读取的一个重要接口,该接口定义在dataloader.py中,只要是用PyTorch来训练模型基本都会用到该接口(除非用户重写…),
    该接口的目的:将自定义的Dataset根据batch size大小、是否shuffle等封装成一个Batch Size大小的Tensor,用于后面的训练。

    # torch.utils.data.DataLoader() 将自定义的Dataset根据batch size大小、是否shuffle等封装成一个Batch Size大小的Tensor,用于后面的训练。
       torch.utils.data.DataLoader(
                            dataset,#数据加载
                            batch_size = 1,#批处理大小设置
                            shuffle = False,#是否进项洗牌操作
                            sampler = None,#指定数据加载中使用的索引/键的序列
                            batch_sampler = None,#和sampler类似
                            num_workers = 0,#是否进行多进程加载数据设置
                            collat​​e_fn = None,#是否合并样本列表以形成一小批Tensor
                            pin_memory = False,#如果True,数据加载器会在返回之前将Tensors复制到CUDA固定内存
                            drop_last = False,#True如果数据集大小不能被批处理大小整除,则设置为删除最后一个不完整的批处理。
                            timeout = 0,#如果为正,则为从工作人员收集批处理的超时值
                            worker_init_fn = None )

    """
    val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=64, shuffle=False)   #对训练集和验证集设置  batch_size 批量处理

    model = Net(upscale_factor=UPSCALE_FACTOR)  #前向传播
    criterion = nn.MSELoss()                  #计算损失
    if torch.cuda.is_available():           #把model 里的计算添加到cuda 中计算
        model = model.cuda()
        criterion = criterion.cuda()

    print('# parameters:', sum(param.numel() for param in model.parameters()))

    optimizer = optim.Adam(model.parameters(), lr=1e-2)                         #优化方法为adma;  model.parameters :待优化的参数 ; lr:学习率
    scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)          #学习率变化:milestones表示学习率在 30 80 时候各变化一次 变化 gamma = 0.1 倍

    engine = Engine()          #封装test train val 的训练过程
    meter_loss = tnt.meter.AverageValueMeter()  #计算平均回归损失 
    meter_psnr = PSNRMeter()  #获取峰值信噪比

    train_loss_logger = VisdomPlotLogger('line', opts={'title': 'Train Loss'})    #可视化
    train_psnr_logger = VisdomPlotLogger('line', opts={'title': 'Train PSNR'})
    val_loss_logger = VisdomPlotLogger('line', opts={'title': 'Val Loss'})
    val_psnr_logger = VisdomPlotLogger('line', opts={'title': 'Val PSNR'})

    engine.hooks['on_sample'] = on_sample                                         #engine封装的hook,即训练的过程文档
    engine.hooks['on_forward'] = on_forward
    engine.hooks['on_start_epoch'] = on_start_epoch
    engine.hooks['on_end_epoch'] = on_end_epoch

    engine.train(processor, train_loader, maxepoch=NUM_EPOCHS, optimizer=optimizer) #engine封装的训练

结束!图片测试和视频测试都和 train.py 大同小异 供大家自己解读;
萌新第一次发博客求点赞 :)

本文地址:https://blog.csdn.net/qq_38477064/article/details/109641687