欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

超分辨率 测试集处理(基于插值下采样)

程序员文章站 2022-03-09 13:07:25
...

测试集生成

代码样例基于Set5 测试集 其他测试集类似。
在超分辨率训练集的训练过程中,为了节省训练时间,通常将训练集按照倍率提前生成。

方案

1.原图剪切:为了使得缩放后的图像与原图像保持像素上的一一对应关系,因此需要将原图的w,h 都能被缩放因子整除。一般超分辨率的缩放因子设置为【2,3,4,8】.所以提取最小公倍数 24.原图需要被剪切为w,h都能被24整除。
2.LR图像剪切:在这里使用的torchvision.transforms 包。

代码

from os import listdir
from os.path import join
from torchvision.transforms import Compose,  CenterCrop, Resize
from PIL import Image
import os

def is_imagefile(image):
    return any(image.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG',
                                                               '.JPEG','bmp','BMP'])

def calculate_valid_crop_size(crop_size,upscale_factor):
    return crop_size - (crop_size % upscale_factor)

def hr_transform(crop_size):
    return Compose([
        CenterCrop(crop_size),
    ])

def lr_transform(crop_size):
    return Compose([
        Resize(crop_size, interpolation=Image.BICUBIC),
    ])


def produce_image(data_dir,scale):
    filename = [join(data_dir, x) for x in listdir(data_dir) if is_imagefile(x)]

    for x in filename:
        images_name = x.split('/')[-1]
        images_name = images_name.split('.')[0]
        x_image = Image.open(x)
        (w,h) = x_image.size
        print(w,h)
        nw = calculate_valid_crop_size(w,24)
        nh = calculate_valid_crop_size(h,24)
        hr_size = hr_transform((nh,nw))
        x_image = hr_size(x_image)
        print(x_image)
        save_image(x_image,scale,images_name)

def save_image(x_image,scale,images_name):
    output_lr_dir = 'TestDataSR/Set5/LR'
    output_hr_dir = 'TestDataSR/Set5/HR'

    x_image.save(os.path.join(output_hr_dir,images_name + '.bmp'))
    for s in scale:
        os.makedirs(
            os.path.join(output_lr_dir,'X{}'.format(s)),
            exist_ok= True
        )
        path = os.path.join(output_lr_dir,'X{}'.format(s) + '/' + images_name + '_X{}'.format(s) + '.bmp')
        (nw,nh) = x_image.size
        lr_size = lr_transform((nh // s, nw // s))
        xr_image = lr_size(x_image)
        xr_image.save(path)




data_dir = 'TestDataSR/Set5'
scale = [2,3,4,8]
produce_image(data_dir,scale)
相关标签: 超分辨率