ESPCN 论文复现、代码注解、超级详细
程序员文章站
2022-04-07 18:06:44
ESPCN 论文复现、代码注解、超级详细环境文章解释的代码顺序即为代码阅读顺序;环境:ubuntu16.04pytorch 1.2torchvision 0.40cuda 9.2python 3.6github: [https://github.com/leftthomas/ESPCN](https://github.com/leftthomas/ESPCN)上面github应该是官方代码,但是用的pytorch0.4以前的版本应该是,cuda是8.0,python是2.7...
ESPCN 论文复现、代码注解、超级详细
环境
文章解释的代码顺序即为代码阅读顺序;
环境:
ubuntu16.04
pytorch 1.2
torchvision 0.40
cuda 9.2
python 3.6
github: [https://github.com/leftthomas/ESPCN](https://github.com/leftthomas/ESPCN)
上面github应该是官方代码,但是用的pytorch0.4以前的版本应该是,cuda是8.0,python是2.7
而本文是在官网上的基础上改动了几个函数并翻译,供和我一样的萌新参考。如发现有错误请提醒一下大家共同学习
data_utils.py
import argparse
import os
from os import listdir
from os.path import join
from PIL import Image
from torch.utils.data.dataset import Dataset
"""
torch.utils.data.Dataset是代表自定义数据集方法的抽象类,你可以自己定义你的数据类继承这个抽象类,非常简单,只需要定义__len__和__getitem__这两个方法就可以。
通过继承torch.utils.data.Dataset的这个抽象类,我们可以定义好我们需要的数据类。当我们通过迭代的方式来取得每一个数据,但是这样很难实现取batch,shuffle或者多线程读取数据,
所以pytorch还提供了一个简单的方法来做这件事情,通过torch.utils.data.DataLoader类来定义一个新的迭代器,
用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入。
总之,通过torch.utils.data.Dataset和torch.utils.data.DataLoader这两个类,使数据的读取变得非常简单,快捷。
__len__(self) 定义当被len()函数调用时的行为(返回容器中元素的个数)
__getitem__(self)定义获取容器中指定元素的行为,相当于self[key],即允许类对象可以有索引操作
"""
from torchvision.transforms import Compose, CenterCrop, Scale
from tqdm import tqdm
def is_image_file(filename): #图片文件夹,以及打开格式
return any(filename.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG'])
"""
endswith():方法用于判断字符串是否以指定后缀结尾,如果以指定后缀结尾返回True,否则返回False。可选参数"start"与"end"为检索字符串的开始与结束位置
extension: 类扩展
any():python doc中得说明,意思就是当传入空可迭代对象时返回False,当可迭代对象中有任意一个不为False,则返回True
即有后面任意一种扩展名都返回 True
"""
def is_video_file(filename): #视屏文件夹,以及打开格式
return any(filename.endswith(extension) for extension in ['.mp4', '.avi', '.mpg', '.mkv', '.wmv', '.flv'])
def calculate_valid_crop_size(crop_size, upscale_factor): #计算有效切割尺寸
return crop_size - (crop_size % upscale_factor)
def input_transform(crop_size, upscale_factor): #输入图片切割
return Compose([ #Conpose 用来管理transfrom(数据扩大操作)中的各种类
CenterCrop(crop_size), #将图片进行中心切割得到 给定size大小图片并用scale缩放
Scale(crop_size // upscale_factor, interpolation=Image.BICUBIC) #scale 是缩放 // 代表整除 interpolation是插入类型
])
def target_transform(crop_size): #标签图片切割
return Compose([
CenterCrop(crop_size)
])
class DatasetFromFolder(Dataset): #从文件夹获取数据
def __init__(self, dataset_dir, upscale_factor, input_transform=None, target_transform=None):
super(DatasetFromFolder, self).__init__()
self.image_dir = dataset_dir + '/SRF_' + str(upscale_factor) + '/data' #路径是 D:\论文代码\ESPCN网络框架\ESPCN-master\data\train\SRF_3\data
self.target_dir = dataset_dir + '/SRF_' + str(upscale_factor) + '/target' #路径是D:\论文代码\ESPCN网络框架\ESPCN-master\data\train\SRF_3\target
#join 连接两个或多个路径名组件,根据需要插入'/'。如果任何组件是绝对路径,则之前的所有路径组件将被丢弃。最后部分为空将生成以分隔符结束的路径
#listdir == os
#join == os.path
self.image_filenames = [join(self.image_dir, x) for x in listdir(self.image_dir) if is_image_file(x)] #遍历image_dir文件夹下的文件 , 符合is_image_file(x)定义格式文件,则把该文件和路径连接
self.target_filenames = [join(self.target_dir, x) for x in listdir(self.target_dir) if is_image_file(x)] #target_dir遍历文件夹下的文件 , 符合is_image_file(x)定义格式文件,则把该文件和路径连接
self.input_transform = input_transform #两个空张量,
self.target_transform = target_transform
def __getitem__(self, index):
image, _, _ = Image.open(self.image_filenames[index]).convert('YCbCr').split() #image.open()打开图片
#self.image_filenames([index])是从( dataset_dir + '/SRF_' + str(upscale_factor) + '/data' )
#中索引一张图片用图像处理库 PIL 实现转换其格式为'YCbCr',并赋值给image
#PIL的九种不同模式:1(黑白图,二值图),L(灰度图),P(八色彩图),RGB,RGBA,CMYK(印刷四分色图),YCbCr(模式“YCbCr”为24位彩色图像,Y亮度Cb蓝色色度Cr红色色度),I,F
target, _, _ = Image.open(self.target_filenames[index]).convert('YCbCr').split() #转换数据图像的格式
if self.input_transform:
image = self.input_transform(image) #将转换后的image 图片切割至给定大小
if self.target_transform:
target = self.target_transform(target) #将转换后的target 图片切割至给定大小
return image, target
def __len__(self):
return len(self.image_filenames)
def generate_dataset(data_type, upscale_factor):
images_name = [x for x in listdir('data/VOC2012/' + data_type) if is_image_file(x)]
#listdir('data/VOC2012/' + data_type)里有扩展名为 ['.mp4', '.avi', '.mpg', '.mkv', '.wmv', '.flv'] 的文件则返回文件并赋值到 image_name 中
crop_size = calculate_valid_crop_size(256, upscale_factor) #返回 crop_size = crop_size - (crop_size % upscale_factor) 256 - (256 % 3)= 255 %:求余数
lr_transform = input_transform(crop_size, upscale_factor) #把图片中心剪切:crop_size = 255 大小 ; 缩放 crop_size // upscale_factor = 255 // 3 = 0
hr_transform = target_transform(crop_size) #target 和 input 一样操作
root = 'data/' + data_type #data_type是 'data/' 下待创建的: train ; val ; test ;
if not os.path.exists(root): #os.makedirs :判断路径存不存在,不i存在创建它
os.makedirs(root)
path = root + '/SRF_' + str(upscale_factor) #合并路径
if not os.path.exists(path):
os.makedirs(path)
image_path = path + '/data'
if not os.path.exists(image_path):
os.makedirs(image_path)
target_path = path + '/target'
if not os.path.exists(target_path):
os.makedirs(target_path)
"""
这一串一直检擦路径存在于否,不存在就创建
直到创建:D:\论文代码\ESPCN网络框架\ESPCN-master\data \data_type\SRF_2\target(标签)
:D:\论文代码\ESPCN网络框架\ESPCN-master\data \data_type\SRF_2\data(训练集)
: data_type是 'data/' 下待创建的: train ; val ; test ;
"""
for image_name in tqdm(images_name, desc='generate ' + data_type + ' dataset with upscale factor = '+ str(upscale_factor) + ' from VOC2012'):
#from VOC2012 generate data_type(train,val,test) dataset with upscale factor = upscale_factor;
"""
tqdm:是一个快速,可扩展的Python进度条,可以在Python长循环中添加一个进度提示信息,用户只需要封装任意的迭代器tqdm(iterator)即可完成进度条
#images_name = [x for x in listdir('data/VOC2012/' + data_type) if is_image_file(x)]
#listdir('data/VOC2012/' + data_type)里有扩展名为 ['.mp4', '.avi', '.mpg', '.mkv', '.wmv', '.flv'] 的文件名则返回文件并赋值到 image_name 中
"""
image = Image.open('data/VOC2012/' + data_type + '/' + image_name) #将image_name 存的文件名的文件赋值给image
target = image.copy()
image = lr_transform(image) #将VOC2012里面的图片中心切割缩放后放到image 中
target = hr_transform(target) #将VOC2012里面的图片中心切割后放到target 中
image.save(image_path + '/' + image_name) #把处理后的image 放到 image_path + '/' + image_name 中
#image_path = \data\data_type\SRF_2\data(训练集)
target.save(target_path + '/' + image_name) #把处理后的target 放到 target_path + '/' + image_name 中
#target_path = \data\data_type\SRF_2\data(训练集)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Generate Super Resolution Dataset') #用argparse包中ArgumentParser类生成一个参数解析器 parser ;description是描述参数解析器的作用:生成超分辨率数据集
parser.add_argument('--upscale_factor', default=3, type=int, help='super resolution upscale factor') #增加一个参数 upscale_factor ,后面是该参数的介绍
opt = parser.parse_args() #用对象parse_args()获取参数解析器的参数
UPSCALE_FACTOR = opt.upscale_factor #将解析器的参数赋值给 UPSCALE_FACTOR
generate_dataset(data_type='train', upscale_factor=UPSCALE_FACTOR) #调用generate_dataset()函数对 VOC2012里面的数据经过转换、剪切、分别制作训练集和验证集
generate_dataset(data_type='val', upscale_factor=UPSCALE_FACTOR)
model.py
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self, upscale_factor):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
self.conv2 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
self.conv3 = nn.Conv2d(32, 1 * (upscale_factor ** 2), (3, 3), (1, 1), (1, 1))
self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
def forward(self, x):
x = F.tanh(self.conv1(x))
x = F.tanh(self.conv2(x))
x = F.sigmoid(self.pixel_shuffle(self.conv3(x)))
return x
if __name__ == "__main__":
model = Net(upscale_factor=3)
print(model)
psnrmeter.py
from math import log10
import torch
from torchnet.meter import meter
"""
meter:提供了一种标准化的方法来测量一系列不同的测量值,这使得测量模型的各种属性变得很容易
三个方法:
add() 这给 meter 增加了一个观测值
value() 它将所有的观测结果都考虑在内,返回 meter 的值
reset() 这样就删除了之前添加的所有观测结果,重新设置了 meter
"""
class PSNRMeter(meter.Meter):
def __init__(self): #每次调用这个类都刷新一次
super(PSNRMeter, self).__init__()
self.reset()
def reset(self): #这样就删除了之前添加的所有观测结果,重新设置了 meter
self.n = 0
self.sesum = 0.0
def add(self, output, target): #这给 meter 增加了一个观测值
if not torch.is_tensor(output) and not torch.is_tensor(target): #确保out target 是张量
output = torch.from_numpy(output)
target = torch.from_numpy(target)
output = output.cpu() #cpu计算 变量 output 、 target
target = target.cpu()
self.n += output.numel() #n = n + output.nume1() n应该是标签数目
self.sesum += torch.sum((output - target) ** 2)
def value(self): #它将所有的观测结果都考虑在内,返回 meter 的值
mse = self.sesum / max(1, self.n)
psnr = 10 * log10(1 / mse)
return psnr #psnr是评价标准: 峰值信噪比
train.py
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torchnet as tnt
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from torchnet.engine import Engine
"""
torchnet 是用于 torch 的代码复用和模块化编程的框架
四种方法:
Dataset : 各种不同的方式处理数据
Engine: 各种机器学习算法
Meter: 性能度量评估
Log: 计算对数
Engine:
它将训练过程和测试过程进行包装,抽象成一个类,提供train和test方法和一个hooks.(这部分文档是问题的)文档中的描述应该是torch.tensor中的hook,原理一致,
只不过tensor中的hook是在变量forward或者backward的时候执行(两种hook)
hooks包括on_start, on_sample, on_forward, on_update, on_end_epoch, on_end,可以自己制定函数,在开始,load数据,forward,更新还有epoch结束以及训练结束时执行。
一般是用开查看和保存模型训练过程的一些结果
"""
from torchnet.logger import VisdomPlotLogger
#用于记录一些评估结果和可视化
from tqdm import tqdm
from data_utils import DatasetFromFolder
from model import Net
from psnrmeter import PSNRMeter
def processor(sample): #定义处理器
"""
processor()函数:
输入 :sample是输入的数据集包含训练集数据(data)和标签(target)
输出 :预测值 y 和损失值 loss
"""
data, target, training = sample
data = Variable(data) #from torch.autograd import Variable 就是将变量 grad_requirs = True
target = Variable(target)
if torch.cuda.is_available(): #检测 GPU 是佛可用
data = data.cuda() #将变量放入cuda中
target = target.cuda()
output = model(data) #调用model 这个data应该是放大因子upscale_factor output 是输出预测值
"""
#import torch.nn as nn
#import torch.nn.functional as F
#class Net(nn.Module):
## torch.nn是专门为神经网络设计的模块化接口。nn构建于autograd之上,可以用来定义和运行神经网络。nn.Module是nn中十分重要的类,包含网络各层的定义及forward方法。
##定义自已的网络:
需要继承nn.Module类,并实现forward方法。
一般把网络中具有可学习参数的层放在构造函数__init__()中,
不具有可学习参数的层(如ReLU)可放在构造函数中,也可不放在构造函数中(而在forward中使用nn.functional来代替)
只要在nn.Module的子类中定义了forward函数,backward函数就会被自动实现(利用Autograd)。
在forward函数中可以使用任何Variable支持的函数,毕竟在整个pytorch构建的图中,是Variable在流动。还可以使用
if,for,print,log等python语法.
注:Pytorch基于nn.Module构建的模型中,只支持mini-batch的Variable输入方式,
比如,只有一张输入图片,也需要变成 N x C x H x W 的形式:
input_image = torch.FloatTensor(1, 28, 28)
input_image = Variable(input_image)
input_image = input_image.unsqueeze(0) # 1 x 1 x 28 x 28
def __init__(self, upscale_factor): ##upscale_factor :上采样因素 就是放大倍数
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2)) ##self.conv2d = nn.Conv2d(in_channels=3,out_channels=64,kernel_size=4,stride=2,padding=1)
self.conv2 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
self.conv3 = nn.Conv2d(32, 1 * (upscale_factor ** 2), (3, 3), (1, 1), (1, 1)) ##此处的upscale_factor 就是为后续的上采样实现
self.pixel_shuffle = nn.PixelShuffle(upscale_factor) ##PixelShuffle 亚像素实现上采样: input:(m , c*upscale_factor**2 , h , w)
# oueput:(m , c , h*upscale_facter , w*upscale_facter)
def forward(self, x): ## nn.Module实现的层(layer)是一个特殊的类,都是由class Layer(nn.Module)定义,会自动提取可学习的参数
# ## nn.functional中的函数更像是纯函数,由def functional(input)定义 一般定义不可学习参数的层,如激活、池化
x = F.tanh(self.conv1(x)) ##定义激活函数 import torch.nn.functional as F
x = F.tanh(self.conv2(x))
x = F.sigmoid(self.pixel_shuffle(self.conv3(x)))
return x
if __name__ == "__main__":
model = Net(upscale_factor=3)
print(model)
"""
loss = criterion(output, target) #调用损失函数 一般用法: criterion = LossCriterion()
# output是预测值 target是标签真实值 loss = criterion(y, Y)
return loss, output
def on_sample(state):
state['sample'].append(state['train']) #给 train 这个键赋值
def reset_meters():
meter_psnr.reset() #清空:评价标准 和 损失数据
meter_loss.reset()
def on_forward(state):
meter_psnr.add(state['output'].data, state['sample'][1]) #计算评价标准:tensor.data 切断反向传播,因为评价标准不要学习参数(可用tensor.detach()更好) add():把output 和 sample[1] 元素逐个相加
meter_loss.add(state['loss'].item()) #.team()返回一个可遍历的数组 ;计算的损失逐个相加
def on_start_epoch(state):
reset_meters() #清空:评价标准 和 损失数据
scheduler.step() #在scheduler的step_size表示scheduler.step()每调用step_size次,对应的学习率就会按照策略调整一次。
#所以如果scheduler.step()是放在mini-batch里面,那么step_size指的是经过这么多次迭代,学习率改变一次
state['iterator'] = tqdm(state['iterator']) #iterator迭代器,返回迭代进度条提示信息
#tqdm: 是一个快速,可扩展的Python进度条,可以在Python长循环中添加一个进度提示信息,用户只需要封装任意的迭代器tqdm(iterator)即可完成进度条
def on_end_epoch(state):
print('[Epoch %d] Train Loss: %.4f (PSNR: %.2f db)' % (state['epoch'], meter_loss.value()[0], meter_psnr.value())) #打印训练的迭代次数对应的 损失 峰值信噪比
train_loss_logger.log(state['epoch'], meter_loss.value()[0]) #计算训练集损失和峰值信噪比的对数,以迭代次数为底
train_psnr_logger.log(state['epoch'], meter_psnr.value())
reset_meters() #清空:评价标准 和 损失数据
engine.test(processor, val_loader) #用engine 封装test的方法, processor()函数是向前传播计算预测值和损失;
#后面定义 :val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=64, shuffle=False)
val_loss_logger.log(state['epoch'], meter_loss.value()[0]) #计算验证集损失和峰值信噪比的对数,以迭代次数为底
val_psnr_logger.log(state['epoch'], meter_psnr.value())
print('[Epoch %d] Val Loss: %.4f (PSNR: %.2f db)' % (state['epoch'], meter_loss.value()[0], meter_psnr.value())) #打印验证的迭代次数对应的 损失 峰值信噪比
torch.save(model.state_dict(), 'epochs/epoch_%d_%d.pt' % (UPSCALE_FACTOR, state['epoch'])) #state_dict():返回一个包含模块的整个状态的字典(返回model字典里是:前向卷积和激活函数)
#save():将对象(返回的model字典)保存到磁盘文件中。
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Train Super Resolution') ##用argparse包中ArgumentParser类生成一个参数解析器 parser ;description是描述参数解析器的作用:训练超分
parser.add_argument('--upscale_factor', default=3, type=int, help='super resolution upscale factor') #增加一个 upscale_factor参数 :超分上采样因数
parser.add_argument('--num_epochs', default=100, type=int, help='super resolution epochs number') #增加一个 num_epochs参数:超分迭代次数
opt = parser.parse_args() ##用对象parse_args()获取参数解析器的参数到 opt里
UPSCALE_FACTOR = opt.upscale_factor
NUM_EPOCHS = opt.num_epochs
train_set = DatasetFromFolder('data/train', upscale_factor=UPSCALE_FACTOR, input_transform=transforms.ToTensor(),target_transform=transforms.ToTensor())
#调用 data_utils.py 里的类 DatasetFromFolder:对数据进行处理、转换、剪切 得到 image(训练用的data) 和 target(标签# ) 后面赋值给dataset 进行封装batch_size 进行训练
val_set = DatasetFromFolder('data/val', upscale_factor=UPSCALE_FACTOR, input_transform=transforms.ToTensor(),target_transform=transforms.ToTensor())
train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=64, shuffle=True)
"""
先简单的介绍一下DataLoader
它是PyTorch中数据读取的一个重要接口,该接口定义在dataloader.py中,只要是用PyTorch来训练模型基本都会用到该接口(除非用户重写…),
该接口的目的:将自定义的Dataset根据batch size大小、是否shuffle等封装成一个Batch Size大小的Tensor,用于后面的训练。
# torch.utils.data.DataLoader() 将自定义的Dataset根据batch size大小、是否shuffle等封装成一个Batch Size大小的Tensor,用于后面的训练。
torch.utils.data.DataLoader(
dataset,#数据加载
batch_size = 1,#批处理大小设置
shuffle = False,#是否进项洗牌操作
sampler = None,#指定数据加载中使用的索引/键的序列
batch_sampler = None,#和sampler类似
num_workers = 0,#是否进行多进程加载数据设置
collate_fn = None,#是否合并样本列表以形成一小批Tensor
pin_memory = False,#如果True,数据加载器会在返回之前将Tensors复制到CUDA固定内存
drop_last = False,#True如果数据集大小不能被批处理大小整除,则设置为删除最后一个不完整的批处理。
timeout = 0,#如果为正,则为从工作人员收集批处理的超时值
worker_init_fn = None )
"""
val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=64, shuffle=False) #对训练集和验证集设置 batch_size 批量处理
model = Net(upscale_factor=UPSCALE_FACTOR) #前向传播
criterion = nn.MSELoss() #计算损失
if torch.cuda.is_available(): #把model 里的计算添加到cuda 中计算
model = model.cuda()
criterion = criterion.cuda()
print('# parameters:', sum(param.numel() for param in model.parameters()))
optimizer = optim.Adam(model.parameters(), lr=1e-2) #优化方法为adma; model.parameters :待优化的参数 ; lr:学习率
scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1) #学习率变化:milestones表示学习率在 30 80 时候各变化一次 变化 gamma = 0.1 倍
engine = Engine() #封装test train val 的训练过程
meter_loss = tnt.meter.AverageValueMeter() #计算平均回归损失
meter_psnr = PSNRMeter() #获取峰值信噪比
train_loss_logger = VisdomPlotLogger('line', opts={'title': 'Train Loss'}) #可视化
train_psnr_logger = VisdomPlotLogger('line', opts={'title': 'Train PSNR'})
val_loss_logger = VisdomPlotLogger('line', opts={'title': 'Val Loss'})
val_psnr_logger = VisdomPlotLogger('line', opts={'title': 'Val PSNR'})
engine.hooks['on_sample'] = on_sample #engine封装的hook,即训练的过程文档
engine.hooks['on_forward'] = on_forward
engine.hooks['on_start_epoch'] = on_start_epoch
engine.hooks['on_end_epoch'] = on_end_epoch
engine.train(processor, train_loader, maxepoch=NUM_EPOCHS, optimizer=optimizer) #engine封装的训练
结束!图片测试和视频测试都和 train.py 大同小异 供大家自己解读;
萌新第一次发博客求点赞 :)
本文地址:https://blog.csdn.net/qq_38477064/article/details/109641687