
SRGAN: Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network


This post is based on the original paper SRGAN: Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network (CVPR 2017) and Chapter 6 of 《生成对抗网络入门指南》 (An Introductory Guide to Generative Adversarial Networks). The complete code, with brief commentary, is given in Section V.


I. Abstract: Why SRGAN?

Deeper and faster CNNs have already brought large gains in super-resolution accuracy and speed, but one central problem remains: how do we recover fine detail when upsampling an image by a large factor? This paper applies a GAN to the image super-resolution (SR) problem.

SRGAN is presented as the first framework capable of inferring photo-realistic natural images for 4x upscaling. To achieve this, the authors design a new objective function and build the generator as a deep ResNet that recovers realistic detail during training.

  1. The adversarial loss is produced by a discriminator trained to distinguish original photographs from super-resolved images; it pushes the generated images closer to natural images.
  2. The content loss is driven by perceptual (visual) similarity rather than similarity in pixel space.
  3. The deep ResNet generator can recover photo-realistic textures from heavily downsampled images.
  4. A mean-opinion-score (MOS) test serves as the perceptual benchmark; the final results show that the MOS of images produced by SRGAN is closer to that of the original high-resolution images than the MOS obtained with other state-of-the-art methods.

 

II. Research on Super-Resolution (SR)

Super-resolution (SR) refers to techniques that generate a high-resolution (HR) image from a low-resolution (LR) input.

Most widely adopted supervised SR algorithms optimize a pixel-wise objective function and lose high-frequency texture detail, so the generated images look blurry. These methods mostly take the mean squared error (MSE) as the objective; minimizing MSE simultaneously maximizes the peak signal-to-noise ratio (PSNR).

However, MSE and PSNR correlate poorly with perceptual quality: the image with the highest PSNR is not necessarily the perceptually best SR result.
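
To make the MSE-PSNR relationship concrete, here is a minimal NumPy sketch (an illustration, not from the paper) of PSNR computed from MSE for 8-bit images:

import numpy as np

def psnr(hr, sr, max_val=255.0):
    # PSNR = 10 * log10(MAX^2 / MSE); higher PSNR <=> lower MSE
    mse = np.mean((hr.astype(np.float64) - sr.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')  # identical images
    return 10.0 * np.log10(max_val ** 2 / mse)

MSE-optimal (and therefore PSNR-maximizing) solutions average over plausible textures, which is exactly why they look blurry despite scoring well.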
 


This paper proposes SRGAN, which uses a deep ResNet as the generator network to be optimized. Unlike previous work, it defines a new perceptual loss computed on high-level feature maps of a VGG network, combined with a discriminator that judges the super-resolved output. Below is an example of 4x-upscaled output:

[Figure: example of a 4x super-resolved image]

 

III. SRGAN Architecture

1. Goal: train a generating function G that, for a given low-resolution (LR) input image, estimates the corresponding high-resolution (HR) image.

 

2. Architecture

① Generator: a feed-forward CNN $G_{\theta_G}$. For the training data, an SR-specific loss is minimized with respect to the generator parameters $\theta_G$:

$$\hat{\theta}_G = \arg\min_{\theta_G} \frac{1}{N} \sum_{n=1}^{N} l^{SR}\left(G_{\theta_G}(I_n^{LR}),\ I_n^{HR}\right)$$

Here $I_n^{HR}$ ($n = 1, \dots, N$) are the high-resolution training images, $I_n^{LR}$ are their low-resolution (downsampled) versions, $\theta_G$ denotes the generator's weights and biases, and $l^{SR}$ is the loss function defined under the objective function (③) below.

 

Inside this feed-forward network, a ResNet structure (a stack of residual blocks) processes the input LR image.

[Figure: generator network architecture, built from residual blocks]

 

 

② Discriminator: following the original GAN formulation, the generator and discriminator play a min-max game:

$$\min_{\theta_G} \max_{\theta_D}\ \mathbb{E}_{I^{HR} \sim p_{\text{train}}(I^{HR})}\left[\log D_{\theta_D}(I^{HR})\right] + \mathbb{E}_{I^{LR} \sim p_G(I^{LR})}\left[\log\left(1 - D_{\theta_D}\left(G_{\theta_G}(I^{LR})\right)\right)\right]$$

Here $I^{HR}$ is a high-resolution training image and $I^{LR}$ is its low-resolution (downsampled) version.

 

The discriminator is trained to distinguish real HR images from generated SR samples. It uses LeakyReLU activations and avoids max-pooling throughout; its convolutional trunk grows from 64 to 512 filters, following the VGG design.

[Figure: discriminator network architecture]
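
In the implementation below (Section V), this min-max game is realized with sigmoid cross-entropy on the discriminator's logits. A minimal TensorFlow 1.x sketch, assuming d_real and d_fake are the logits for a real HR batch and a generated SR batch (the actual code wraps this in t.sce_loss):

import tensorflow as tf

d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_real, labels=tf.ones_like(d_real)))   # push D(I_HR) -> 1
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_fake, labels=tf.zeros_like(d_fake)))  # push D(G(I_LR)) -> 0
d_loss = d_loss_real + d_loss_fake                 # the discriminator's side of the game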

③ Objective function: $l^{SR}$ is the perceptual loss function used to score how good a generated image is. It is a weighted sum of a content loss $l^{SR}_X$ and an adversarial loss $l^{SR}_{Gen}$:

$$l^{SR} = \underbrace{l^{SR}_X}_{\text{content loss}} + \underbrace{10^{-3}\, l^{SR}_{Gen}}_{\text{adversarial loss}}$$

  • Content loss

The pixel-wise MSE loss is

$$l^{SR}_{MSE} = \frac{1}{r^2 W H} \sum_{x=1}^{rW} \sum_{y=1}^{rH} \left(I^{HR}_{x,y} - G_{\theta_G}(I^{LR})_{x,y}\right)^2$$

where $r$ is the upscaling factor and $W \times H$ is the size of the LR image. This is the most widely used optimization target in state-of-the-art SR work, but MSE-optimal solutions often lack high-frequency content, resulting in perceptually unsatisfying, overly smooth textures.

Instead, SRGAN defines a content loss on the feature maps of a pre-trained 19-layer VGG network:

$$l^{SR}_{VGG/i.j} = \frac{1}{W_{i,j} H_{i,j}} \sum_{x=1}^{W_{i,j}} \sum_{y=1}^{H_{i,j}} \left(\phi_{i,j}(I^{HR})_{x,y} - \phi_{i,j}\left(G_{\theta_G}(I^{LR})\right)_{x,y}\right)^2$$

Here $I^{HR}$ is a high-resolution training image, $I^{LR}$ is its low-resolution (downsampled) version, $W_{i,j}$ and $H_{i,j}$ are the dimensions of the respective feature maps, and $\phi_{i,j}$ denotes the feature map obtained by the $j$-th convolution (after activation) before the $i$-th max-pooling layer of the VGG19 network.
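
A minimal TensorFlow 1.x sketch of this content loss, assuming phi_real and phi_fake are the $\phi$ feature maps of the HR image and the generated image (the code in Section V obtains them from the repository's vgg19.VGG19 wrapper, using its conv5_4 layer):

import tensorflow as tf

def vgg_content_loss(phi_real, phi_fake):
    # mean squared difference over the W_ij x H_ij feature-map dimensions
    return tf.reduce_mean(tf.squared_difference(phi_real, phi_fake))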

 

  • Adversarial loss (GAN loss)

This is the usual generative component of the GAN: the adversarial loss is defined on the discriminator's probabilities over all training samples,

$$l^{SR}_{Gen} = \sum_{n=1}^{N} -\log D_{\theta_D}\left(G_{\theta_G}(I^{LR})\right)$$

where $D_{\theta_D}(G_{\theta_G}(I^{LR}))$ is the probability that the reconstructed image is a natural HR image. The paper minimizes $-\log D(\cdot)$ rather than $\log(1 - D(\cdot))$ for better gradient behavior.
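
With a logit-producing discriminator, this $-\log D$ term is exactly sigmoid cross-entropy against "real" labels, which is how Section V implements it. A sketch, with d_fake the discriminator logits on generated images:

g_adv_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_fake, labels=tf.ones_like(d_fake)))  # equals the mean of -log D(G(I_LR))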

 

Many later papers adopt this loss structure; in GAN-based art and image generation in particular, a content loss of this kind is extremely common.

 

IV. Experimental Evaluation

MOS testing: raters assign each super-resolved image a quality score from 1 (bad) to 5 (excellent); across the benchmark datasets, SRGAN's MOS is the closest to that of the original HR images.

[Figure: MOS test results comparing SRGAN with other SR methods]

 

V. Experiment Code

Dataset: https://data.vision.ee.ethz.ch/cvl/DIV2K/
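
DIV2K Track 1 pairs each HR image with a bicubically x4-downscaled LR version. A minimal Pillow sketch (hypothetical file path) of producing the 384x384 HR / 96x96 LR pair shapes this model expects:

from PIL import Image

hr = Image.open('DIV2K_train_HR/0001.png').crop((0, 0, 384, 384))  # HR patch
lr = hr.resize((96, 96), Image.BICUBIC)  # x4 bicubic downscale, the Track 1 setting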

1. Imports and hyperparameter initialization

import tensorflow as tf

import vgg19

import sys

sys.path.append('../')
import tfutil as t

tf.set_random_seed(777)  # reproducibility


class SRGAN:

    def __init__(self, s, batch_size=16, height=384, width=384, channel=3,
                 sample_num=1 * 1, sample_size=1,
                 df_dim=64, gf_dim=64, lr=1e-4, use_vgg19=True):

        """ Super-Resolution GAN Class
        # General Settings
        :param s: TF Session
        :param batch_size: training batch size, default 16
        :param height: input image height, default 384
        :param width: input image width, default 384
        :param channel: input image channel, default 3 (RGB)
        - in case of DIV2K-HR, image size is 384x384x3(HWC).

        # Output Settings
        :param sample_num: the number of output images, default 1
        :param sample_size: sample image size, default 1

        # For CNN model
        :param df_dim: discriminator filter, default 64
        :param gf_dim: generator filter, default 64

        # Training Option
        :param lr: learning rate, default 1e-4
        :param use_vgg19: use pre-trained VGG19 bottleneck features for the content loss, default True
        """

        self.s = s
        self.batch_size = batch_size

        self.height = height
        self.width = width
        self.channel = channel

        self.lr_image_shape = [None, self.height // 4, self.width // 4, self.channel]
        self.hr_image_shape = [None, self.height, self.width, self.channel]

        self.vgg_image_shape = [224, 224, 3]

        self.sample_num = sample_num
        self.sample_size = sample_size

        self.df_dim = df_dim
        self.gf_dim = gf_dim

        self.beta1 = 0.9
        self.beta2 = 0.999

        self.lr_decay_rate = 1e-1
        self.lr_low_boundary = 1e-5
        self.lr_update_step = 1e5
        self.lr_update_epoch = 1000

        self.vgg_mean = [103.939, 116.779, 123.68]

        # pre-defined
        self.d_real = 0.
        self.d_fake = 0.
        self.d_loss = 0.
        self.g_adv_loss = 0.
        self.g_cnt_loss = 0.
        self.g_loss = 0.
        self.psnr = 0.

        self.use_vgg19 = use_vgg19
        self.vgg19 = None

        self.g = None

        self.adv_scaling = 1e-3
        self.cnt_scaling = 1. / 12.75  # 6e-3

        self.d_op = None
        self.g_op = None
        self.g_init_op = None

        self.merged = None
        self.writer = None
        self.saver = None

        # Placeholders
        self.x_hr = tf.placeholder(tf.float32, shape=self.hr_image_shape, name="x-image-hr")  # (-1, 384, 384, 3)
        self.x_lr = tf.placeholder(tf.float32, shape=self.lr_image_shape, name="x-image-lr")  # (-1, 96, 96, 3)

        self.lr = tf.placeholder(tf.float32, name='lr')

        self.build_srgan()  # build SRGAN model

 

2. Building the generator and discriminator

① Discriminator: LeakyReLU activations throughout, with strided convolutions in place of max-pooling:

    def discriminator(self, x, reuse=None):
        """
        # Following a network architecture referred in the paper
        :param x: Input images (-1, 384, 384, 3)
        :param reuse: re-usability
        :return: logits scoring whether the input is a real HR image or a generated SR image
        """
        with tf.variable_scope("discriminator", reuse=reuse):
            x = t.conv2d(x, self.df_dim, 3, 1, name='n64s1-1')
            x = tf.nn.leaky_relu(x)

            strides = [2, 1]
            filters = [1, 2, 2, 4, 4, 8, 8]  # channel multipliers, giving 64 -> 512 filters as in the paper

            for i, f in enumerate(filters):
                f *= self.df_dim  # e.g. 2 * 64 = 128 filters ('n128s2' in the paper's notation)
                x = t.conv2d(x, f=f, k=3, s=strides[i % 2], name='n%ds%d-%d' % (f, strides[i % 2], i + 1))
                x = t.batch_norm(x, name='n%d-bn-%d' % (f, i + 1))
                x = tf.nn.leaky_relu(x)

            x = tf.layers.flatten(x)  # (-1, 24 * 24 * 512) for a 384x384 input

            x = t.dense(x, 1024, name='disc-fc-1')
            x = tf.nn.leaky_relu(x)

            x = t.dense(x, 1, name='disc-fc-2')
            # no sigmoid here: raw logits are returned, because sce_loss
            # (sigmoid cross-entropy) in build_srgan applies the sigmoid itself
            return x

② Generator

    def generator(self, x, reuse=None, is_train=True):
        """
        :param x: LR (Low Resolution) images, (-1, 96, 96, 3)
        :param reuse: scope re-usability
        :param is_train: is trainable, default True
        :return: SR (Super Resolution) images, (-1, 384, 384, 3)
        """

        with tf.variable_scope("generator", reuse=reuse):
            def residual_block(x, f, name="", _is_train=True):
                with tf.variable_scope(name):
                    shortcut = tf.identity(x, name='n64s1-shortcut')

                    x = t.conv2d(x, f, 3, 1, name="n64s1-1")
                    x = t.batch_norm(x, is_train=_is_train, name="n64s1-bn-1")
                    x = t.prelu(x, reuse=reuse, name='n64s1-prelu-1')
                    x = t.conv2d(x, f, 3, 1, name="n64s1-2")
                    x = t.batch_norm(x, is_train=_is_train, name="n64s1-bn-2")
                    x = tf.add(x, shortcut)

                    return x

            x = t.conv2d(x, self.gf_dim, 9, 1, name='n64s1-1')
            x = t.prelu(x, name='n64s1-prelu-1')

            skip_conn = tf.identity(x, name='skip_connection')

            # B residual blocks
            for i in range(1, 17):  # B = 16 residual blocks
                x = residual_block(x, self.gf_dim, name='b-residual_block_%d' % i, _is_train=is_train)

            x = t.conv2d(x, self.gf_dim, 3, 1, name='n64s1-3')
            x = t.batch_norm(x, is_train=is_train, name='n64s1-bn-3')

            x = tf.add(x, skip_conn)

            # sub-pixel conv2d blocks
            for i in range(1, 3):
                x = t.conv2d(x, self.gf_dim * 4, 3, 1, name='n256s1-%d' % (i + 2))
                x = t.sub_pixel_conv2d(x, f=None, s=2)
                x = t.prelu(x, name='n256s1-prelu-%d' % i)

            x = t.conv2d(x, self.channel, 9, 1, name='n3s1')  # (-1, 384, 384, 3)
            x = tf.nn.tanh(x)
            return x
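
t.sub_pixel_conv2d above is the repository's helper for sub-pixel (pixel-shuffle) upsampling; a minimal equivalent sketch with tf.depth_to_space:

def sub_pixel_conv2d(x, s=2):
    # rearrange (N, H, W, C*s^2) -> (N, H*s, W*s, C): channels become pixels
    return tf.depth_to_space(x, s)

Each of the two upsampling blocks first expands to gf_dim * 4 = 256 channels with a 3x3 convolution, then doubles the spatial resolution, giving the overall x4 upscaling.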

 

3. Building the VGG19 feature extractor

    def build_vgg19(self, x, reuse=None):
        with tf.variable_scope("vgg19", reuse=reuse):
            # image re-scaling
            x = tf.cast((x + 1) / 2, dtype=tf.float32)  # [-1, 1] to [0, 1]
            x = tf.cast(x * 255., dtype=tf.float32)     # [0, 1]  to [0, 255]

            r, g, b = tf.split(x, 3, 3)
            bgr = tf.concat([b - self.vgg_mean[0],
                             g - self.vgg_mean[1],
                             r - self.vgg_mean[2]], axis=3)

            self.vgg19 = vgg19.VGG19(bgr)

            net = self.vgg19.vgg19_net['conv5_4']

            return net  # last layer

 

4. Building the SRGAN model

    def build_srgan(self):
        # Generator
        self.g = self.generator(self.x_lr)

        # Discriminator
        d_real = self.discriminator(self.x_hr)
        d_fake = self.discriminator(self.g, reuse=True)

        # Losses
        # d_real_loss = -tf.reduce_mean(t.safe_log(d_real))
        # d_fake_loss = -tf.reduce_mean(t.safe_log(1. - d_fake))
        d_real_loss = t.sce_loss(d_real, tf.ones_like(d_real))
        d_fake_loss = t.sce_loss(d_fake, tf.zeros_like(d_fake))
        self.d_loss = d_real_loss + d_fake_loss

        if self.use_vgg19:
            x_vgg_real = tf.image.resize_images(self.x_hr, size=self.vgg_image_shape[:2], align_corners=False)
            x_vgg_fake = tf.image.resize_images(self.g, size=self.vgg_image_shape[:2], align_corners=False)

            vgg_bottle_real = self.build_vgg19(x_vgg_real)
            vgg_bottle_fake = self.build_vgg19(x_vgg_fake, reuse=True)

            self.g_cnt_loss = self.cnt_scaling * t.mse_loss(vgg_bottle_fake, vgg_bottle_real, self.batch_size,
                                                            is_mean=True)
        else:
            self.g_cnt_loss = t.mse_loss(self.g, self.x_hr, self.batch_size, is_mean=True)

        # self.g_adv_loss = self.adv_scaling * tf.reduce_mean(-1. * t.safe_log(d_fake))
        self.g_adv_loss = self.adv_scaling * t.sce_loss(d_fake, tf.ones_like(d_fake))
        self.g_loss = self.g_adv_loss + self.g_cnt_loss

        def inverse_transform(img):
            return (img + 1.) * 127.5

        # calculate PSNR
        g, x_hr = inverse_transform(self.g), inverse_transform(self.x_hr)
        self.psnr = t.psnr_loss(g, x_hr, self.batch_size)

        # Summary
        tf.summary.scalar("loss/d_real_loss", d_real_loss)
        tf.summary.scalar("loss/d_fake_loss", d_fake_loss)
        tf.summary.scalar("loss/d_loss", self.d_loss)
        tf.summary.scalar("loss/g_cnt_loss", self.g_cnt_loss)
        tf.summary.scalar("loss/g_adv_loss", self.g_adv_loss)
        tf.summary.scalar("loss/g_loss", self.g_loss)
        tf.summary.scalar("misc/psnr", self.psnr)
        tf.summary.scalar("misc/lr", self.lr)

        # Optimizer
        t_vars = tf.trainable_variables()
        d_params = [v for v in t_vars if v.name.startswith('d')]
        g_params = [v for v in t_vars if v.name.startswith('g')]

        self.d_op = tf.train.AdamOptimizer(learning_rate=self.lr,
                                           beta1=self.beta1, beta2=self.beta2).minimize(loss=self.d_loss,
                                                                                        var_list=d_params)
        self.g_op = tf.train.AdamOptimizer(learning_rate=self.lr,
                                           beta1=self.beta1, beta2=self.beta2).minimize(loss=self.g_loss,
                                                                                        var_list=g_params)

        # pre-train
        self.g_init_op = tf.train.AdamOptimizer(learning_rate=self.lr,
                                                beta1=self.beta1, beta2=self.beta2).minimize(loss=self.g_cnt_loss,
                                                                                             var_list=g_params)

        # Merge summary
        self.merged = tf.summary.merge_all()

        # Model saver
        self.saver = tf.train.Saver(max_to_keep=2)
        self.writer = tf.summary.FileWriter('./model/', self.s.graph)

 

5. Main function

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import tensorflow as tf
import numpy as np

import sys
import time

sys.path.append('../')
import image_utils as iu
from datasets import Div2KDataSet as DataSet


np.random.seed(1337)


results = {
    'output': './gen_img/',
    'model': './model/SRGAN-model.ckpt'
}

train_step = {
    'batch_size': 16,
    'init_epochs': 100,
    'train_epochs': 1501,
    'global_step': 200001,
    'logging_interval': 100,
}


def main():
    start_time = time.time()  # Clocking start

    # Div2K - Track 1: Bicubic downscaling - x4 DataSet load
    """
    ds = DataSet(ds_path="/home/zero/hdd/DataSet/DIV2K/",
                 ds_name="X4",
                 use_save=True,
                 save_type="to_h5",
                 save_file_name="/home/zero/hdd/DataSet/DIV2K/DIV2K",
                 use_img_scale=True)
    """
    ds = DataSet(ds_hr_path="/home/zero/hdd/DataSet/DIV2K/DIV2K-hr.h5",
                 ds_lr_path="/home/zero/hdd/DataSet/DIV2K/DIV2K-lr.h5",
                 use_img_scale=True)

    hr, lr = ds.hr_images, ds.lr_images

    print("[+] Loaded HR image ", hr.shape)
    print("[+] Loaded LR image ", lr.shape)

    # GPU configure
    gpu_config = tf.GPUOptions(allow_growth=True)
    config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False, gpu_options=gpu_config)

    with tf.Session(config=config) as s:
        with tf.device("/gpu:1"):  # change to the GPU device available on your machine
            # SRGAN Model
            model = SRGAN(s, batch_size=train_step['batch_size'],
                                use_vgg19=False)

        # Initializing
        s.run(tf.global_variables_initializer())

        # Load model & Graph & Weights
        ckpt = tf.train.get_checkpoint_state('./model/')
        if ckpt and ckpt.model_checkpoint_path:
            # Restores from checkpoint
            model.saver.restore(s, ckpt.model_checkpoint_path)

            global_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            print("[+] global step : %d" % global_step, " successfully loaded")
        else:
            global_step = 0
            print('[-] No checkpoint file found')

        start_epoch = global_step // (ds.n_images // train_step['batch_size'])

        rnd = np.random.randint(0, ds.n_images)
        sample_x_hr, sample_x_lr = hr[rnd], lr[rnd]

        sample_x_hr, sample_x_lr = \
            np.reshape(sample_x_hr, [1] + model.hr_image_shape[1:]), \
            np.reshape(sample_x_lr, [1] + model.lr_image_shape[1:])

        # Export real image
        # valid_image_height = model.sample_size
        # valid_image_width = model.sample_size
        sample_hr_dir, sample_lr_dir = results['output'] + 'valid_hr.png', results['output'] + 'valid_lr.png'

        # Generated image save
        iu.save_images(sample_x_hr,
                       size=[1, 1],
                       image_path=sample_hr_dir,
                       inv_type='127')

        iu.save_images(sample_x_lr,
                       size=[1, 1],
                       image_path=sample_lr_dir,
                       inv_type='127')

        learning_rate = 1e-4
        for epoch in range(start_epoch, train_step['train_epochs']):
            pointer = 0
            for i in range(ds.n_images // train_step['batch_size']):
                start = pointer
                pointer += train_step['batch_size']

                if pointer > ds.n_images:  # if 1 epoch is ended
                    # Shuffle training DataSet
                    perm = np.arange(ds.n_images)
                    np.random.shuffle(perm)

                    hr, lr = hr[perm], lr[perm]

                    start = 0
                    pointer = train_step['batch_size']

                end = pointer

                batch_x_hr, batch_x_lr = hr[start:end], lr[start:end]

                # reshape
                batch_x_hr = np.reshape(batch_x_hr, [train_step['batch_size']] + model.hr_image_shape[1:])
                batch_x_lr = np.reshape(batch_x_lr, [train_step['batch_size']] + model.lr_image_shape[1:])

                # Phase 1: for the first init_epochs, pre-train G alone on the content loss (SRResNet initialization)
                d_loss, g_loss, g_init_loss = 0., 0., 0.
                if epoch <= train_step['init_epochs']:
                    _, g_init_loss = s.run([model.g_init_op, model.g_cnt_loss],
                                           feed_dict={
                                               model.x_hr: batch_x_hr,
                                               model.x_lr: batch_x_lr,
                                               model.lr: learning_rate,
                                           })
                # Phase 2: adversarial training, alternating D and G updates
                else:
                    _, d_loss = s.run([model.d_op, model.d_loss],
                                      feed_dict={
                                          model.x_hr: batch_x_hr,
                                          model.x_lr: batch_x_lr,
                                          model.lr: learning_rate,
                                      })

                    _, g_loss = s.run([model.g_op, model.g_loss],
                                      feed_dict={
                                          model.x_hr: batch_x_hr,
                                          model.x_lr: batch_x_lr,
                                          model.lr: learning_rate,
                                      })

                if i % train_step['logging_interval'] == 0:
                    # Print loss
                    if epoch <= train_step['init_epochs']:
                        print("[+] Epoch %04d Step %08d => " % (epoch, global_step),
                              " MSE loss : {:.8f}".format(g_init_loss))
                    else:
                        print("[+] Epoch %04d Step %08d => " % (epoch, global_step),
                              " D loss : {:.8f}".format(d_loss),
                              " G loss : {:.8f}".format(g_loss))

                        summary = s.run(model.merged,
                                        feed_dict={
                                            model.x_hr: batch_x_hr,
                                            model.x_lr: batch_x_lr,
                                            model.lr: learning_rate,
                                        })

                        # Summary saver
                        model.writer.add_summary(summary, global_step)

                    # Training G model with sample image and noise
                    sample_x_lr = np.reshape(sample_x_lr, [model.sample_num] + model.lr_image_shape[1:])
                    samples = s.run(model.g,
                                    feed_dict={
                                        model.x_lr: sample_x_lr,
                                        model.lr: learning_rate,
                                    })

                    # Export image generated by model G
                    # sample_image_height = model.output_height
                    # sample_image_width = model.output_width
                    sample_dir = results['output'] + 'train_{:08d}.png'.format(global_step)

                    # Generated image save
                    iu.save_images(samples,
                                   size=[1, 1],
                                   image_path=sample_dir,
                                   inv_type='127')

                    # Model save
                    model.saver.save(s, results['model'], global_step)

                # Learning Rate update
                if epoch and epoch % model.lr_update_epoch == 0:
                    learning_rate *= model.lr_decay_rate
                    learning_rate = max(learning_rate, model.lr_low_boundary)

                global_step += 1

    end_time = time.time() - start_time  # Clocking end

    # Elapsed time
    print("[+] Elapsed time {:.8f}s".format(end_time))

    # Close tf.Session
    s.close()


if __name__ == '__main__':
    main()

 

6. Results (generated images)

Input LR image:

[Image: low-resolution input]

Samples during training (steps 0 to 55000):

[Images: generated samples at intermediate training steps, 0 through 55000]

Generated HR image:

[Image: generated high-resolution output]

 
