SRGAN: Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
程序员文章站
2023-12-24 22:20:45
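This post is based on the original paper "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network" (SRGAN, CVPR 2017) and Chapter 6 of the book 《生成对抗网络入门指南》 (An Introductory Guide to Generative Adversarial Networks). The complete code and a brief walkthrough are at the end of the post.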
I. Abstract: why use SRGAN?
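Faster and deeper convolutional neural networks (CNNs) have already brought large gains in the accuracy and speed of super-resolution (SR), but one central problem remains largely unsolved: how do we recover the finer texture details when super-resolving at large upscaling factors? This paper applies a generative adversarial network (GAN) to single-image super-resolution.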
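SRGAN is the first framework capable of inferring photo-realistic natural images for 4× upscaling factors. To achieve this, the authors propose a new objective, a perceptual loss made up of an adversarial loss and a content loss, and use a deep residual network (ResNet) as the generator: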
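- Adversarial loss: a discriminator network is trained to tell super-resolved images apart from original photo-realistic images, which pushes the generated images toward the manifold of natural images.
- Content loss: based on perceptual similarity instead of similarity in pixel space.
- A deep residual network (ResNet) can recover photo-realistic textures from heavily downsampled images.
- Mean opinion score (MOS) testing is used to judge image quality. The results show that the MOS of images produced by SRGAN is closer to that of the original high-resolution images than the MOS obtained with other state-of-the-art methods.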
II. Research on super-resolution (SR)
Super-resolution (SR) refers to techniques that estimate a high-resolution (HR) image from a low-resolution (LR) image.
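Most current supervised SR algorithms are built around optimizing an objective function, and they tend to lose high-frequency texture details, so the generated images look blurry. These algorithms usually minimize the mean squared error (MSE) between the recovered image and the ground-truth HR image; minimizing the MSE also maximizes the peak signal-to-noise ratio (PSNR), a common SR evaluation metric.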
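However, MSE and PSNR values do not reflect perceived visual quality well: the result with the highest PSNR is not necessarily the one that looks best.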
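To make the MSE/PSNR relationship concrete, here is a minimal pure-Python sketch. The helper names `mse` and `psnr` are ours, not from the post; it uses the standard definition PSNR = 10·log10(MAX²/MSE) with MAX = 255 for 8-bit images. It also shows why PSNR alone can mislead: a uniform brightness shift and a single badly wrong pixel give the same MSE, and therefore the same PSNR, although they look very different.

```python
import math


def mse(a, b):
    """Mean squared error between two equally sized flat pixel lists."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def psnr(a, b, max_val=255.0):
    """Peak signal-to-noise ratio in dB; a lower MSE always gives a higher PSNR."""
    err = mse(a, b)
    if err == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_val ** 2 / err)


reference = [100, 100, 100, 100]
shifted = [105, 105, 105, 105]   # every pixel slightly brighter
spiked = [100, 100, 100, 110]    # one pixel badly wrong

print(mse(reference, shifted), mse(reference, spiked))  # 25.0 25.0
print(round(psnr(reference, shifted), 2))               # 34.15
```

Both distorted images get exactly the same score, so PSNR cannot say which one a viewer would prefer.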
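In this paper, the authors propose SRGAN, which uses a ResNet as the generator network being optimized. Unlike previous work, they define a new perceptual loss computed on high-level feature maps of the VGG network, combined with a discriminator that encourages outputs that are hard to distinguish from real HR images. Below is an example of 4× super-resolution: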
III. SRGAN architecture
1. Goal: train a generating function G that, given a low-resolution (LR) input image, estimates the corresponding high-resolution (HR) image.
2. Architecture
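① Generator: a feed-forward CNN $G_{\theta_G}$ parameterized by $\theta_G$. Its parameters are optimized on training image pairs with an SR-specific loss function $l^{SR}$:

$$\hat{\theta}_G = \arg\min_{\theta_G} \frac{1}{N}\sum_{n=1}^{N} l^{SR}\left(G_{\theta_G}\left(I^{LR}_n\right), I^{HR}_n\right)$$

where $I^{HR}_n$ ($n = 1, \dots, N$) are the HR training images, $I^{LR}_n$ the corresponding LR (downsampled) versions, $\theta_G$ the generator parameters, and $l^{SR}$ the loss function defined under the objective function below.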
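The feed-forward generator is built on a ResNet structure that processes the input LR image: at its core are B = 16 residual blocks with identical layout (two 3×3 convolutions with 64 feature maps, each followed by batch normalization, with a ParametricReLU in between), after which two sub-pixel convolution layers increase the resolution.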
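② Discriminator: as in the original GAN, a discriminator $D_{\theta_D}$ is optimized alternately with $G_{\theta_G}$ to solve the adversarial min-max problem:

$$\min_{\theta_G}\max_{\theta_D}\; \mathbb{E}_{I^{HR}\sim p_{\text{train}}(I^{HR})}\left[\log D_{\theta_D}\left(I^{HR}\right)\right] + \mathbb{E}_{I^{LR}\sim p_{G}(I^{LR})}\left[\log\left(1 - D_{\theta_D}\left(G_{\theta_G}\left(I^{LR}\right)\right)\right)\right]$$

where $I^{HR}$ are HR training images and $I^{LR}$ their LR (downsampled) versions.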
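The discriminator is trained to distinguish real HR images from generated SR samples. It uses LeakyReLU activations (α = 0.2) and avoids max-pooling throughout. Its architecture follows the design of the VGG network: eight convolutional layers with 3×3 kernels whose number grows by a factor of 2 from 64 to 512, with strided convolutions reducing the image resolution each time the number of features doubles, followed by two dense layers and a final sigmoid that gives the probability that the input is a real HR image.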
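The following sketch traces the feature-map shape through the discriminator's eight convolutional layers for a 384×384×3 input, using the filter counts and strides from the paper's layer list (64, 64, 128, 128, 256, 256, 512, 512 filters, alternating stride 1 and 2). It assumes "same" padding, so a stride-2 convolution maps a size n to ceil(n/2); the padding mode is an assumption, not something the post states.

```python
import math

# (filters, stride) for the eight conv layers of the SRGAN discriminator
DISC_LAYERS = [(64, 1), (64, 2), (128, 1), (128, 2),
               (256, 1), (256, 2), (512, 1), (512, 2)]


def trace_shapes(height, width, layers=DISC_LAYERS):
    """Return the (H, W, C) shape after each conv layer, assuming 'same' padding."""
    shapes = []
    for filters, stride in layers:
        height = math.ceil(height / stride)
        width = math.ceil(width / stride)
        shapes.append((height, width, filters))
    return shapes


shapes = trace_shapes(384, 384)
for shape in shapes:
    print(shape)
h, w, c = shapes[-1]
print("flattened features:", h * w * c)  # 24 * 24 * 512 = 294912
```

Under these assumptions the flattened vector fed into the 1024-unit dense layer has 294,912 entries.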
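③ Objective function: $l^{SR}$ is the perceptual loss function, which measures the quality of a generated image with respect to perceptually relevant characteristics. It is a weighted sum of a content loss and an adversarial loss:

$$l^{SR} = l^{SR}_{X} + 10^{-3}\, l^{SR}_{Gen}$$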
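- Content loss

Pixel-wise MSE loss:

$$l^{SR}_{MSE} = \frac{1}{r^2WH}\sum_{x=1}^{rW}\sum_{y=1}^{rH}\left(I^{HR}_{x,y} - G_{\theta_G}\left(I^{LR}\right)_{x,y}\right)^2$$

where $r$ is the downscaling factor and $W$, $H$ are the width and height of the LR image. This is the most widely used optimization target for image SR in state-of-the-art approaches. However, solutions of MSE optimization problems often lack high-frequency content, which results in perceptually unsatisfying images with overly smooth textures.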
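Instead, the authors define a VGG loss on the ReLU activation layers of a pre-trained 19-layer VGG network (a standard VGG19 with ReLU activations and max-pooling layers), as the Euclidean distance between the feature representations of the reconstructed image and the reference HR image:

$$l^{SR}_{VGG/i,j} = \frac{1}{W_{i,j}H_{i,j}}\sum_{x=1}^{W_{i,j}}\sum_{y=1}^{H_{i,j}}\left(\phi_{i,j}\left(I^{HR}\right)_{x,y} - \phi_{i,j}\left(G_{\theta_G}\left(I^{LR}\right)\right)_{x,y}\right)^2$$

where $I^{HR}$ is the HR training image, $I^{LR}$ its LR (downsampled) version, $W_{i,j}$ and $H_{i,j}$ the dimensions of the respective feature map within VGG, and $\phi_{i,j}$ the feature map obtained by the j-th convolution (after activation) before the i-th max-pooling layer of VGG19.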
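- Adversarial loss (GAN loss)

This is the usual generative loss based on the discriminator's probabilities for the generated images, summed over all N training samples:

$$l^{SR}_{Gen} = \sum_{n=1}^{N} -\log D_{\theta_D}\left(G_{\theta_G}\left(I^{LR}_n\right)\right)$$

The generator minimizes $-\log D_{\theta_D}(G_{\theta_G}(I^{LR}))$ instead of $\log\left(1 - D_{\theta_D}(G_{\theta_G}(I^{LR}))\right)$ for better gradient behavior.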
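The claim above that MSE-optimized solutions are overly smooth can be illustrated with a toy example (ours, not one of the paper's experiments): suppose an LR patch is equally consistent with two sharp HR patches whose edge lies in opposite positions. The prediction that minimizes the expected MSE is their pixel-wise average, a flat gray patch that matches neither sharp candidate.

```python
def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


# Two equally likely sharp HR patches (a dark/bright edge in either order)
candidates = [[0, 255], [255, 0]]


def expected_mse(prediction):
    """Average MSE of a prediction over the equally likely HR candidates."""
    return sum(mse(prediction, c) for c in candidates) / len(candidates)


average = [sum(px) / len(candidates) for px in zip(*candidates)]  # pixel-wise mean
print(average)                      # [127.5, 127.5] -> a blurry gray patch
print(expected_mse(average))        # 16256.25
print(expected_mse(candidates[0]))  # 32512.5: committing to a sharp patch costs more
```

An MSE-trained generator is therefore pulled toward the blurry average; the adversarial term instead rewards outputs that look like one of the sharp, natural-looking candidates, which is the intuition behind adding it.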
Many later papers adopt this loss structure; the content loss in particular is very common in GAN-based image generation and artistic image generation.
IV. Experimental evaluation
MOS testing
V. Experiment code
Dataset: https://data.vision.ee.ethz.ch/cvl/DIV2K/
1. Imports and hyperparameter initialization
```python
import sys

import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.placeholder, tf.Session, ...)

sys.path.append('../')
import vgg19          # project-local module with pre-trained VGG19 (not part of TensorFlow)
import tfutil as t    # project-local layer/loss helpers (not part of TensorFlow)

tf.set_random_seed(777)  # reproducibility


class SRGAN:

    def __init__(self, s, batch_size=16, height=384, width=384, channel=3,
                 sample_num=1 * 1, sample_size=1,
                 df_dim=64, gf_dim=64, lr=1e-4, use_vgg19=True):
        """ Super-Resolution GAN Class
        # General Settings
        :param s: TF Session
        :param batch_size: training batch size, default 16
        :param height: input image height, default 384
        :param width: input image width, default 384
        :param channel: input image channel, default 3 (RGB)
        - in case of DIV2K-HR, image size is 384x384x3(HWC).
        # Output Settings
        :param sample_num: the number of output images, default 1
        :param sample_size: sample image size, default 1
        # For CNN model
        :param df_dim: discriminator filter, default 64
        :param gf_dim: generator filter, default 64
        # Training Option
        :param lr: learning rate, default 1e-4 (unused: the rate is fed through the self.lr placeholder)
        :param use_vgg19: using pre-trained vgg19 bottle-neck features, default True
        """

        self.s = s
        self.batch_size = batch_size

        self.height = height
        self.width = width
        self.channel = channel

        self.lr_image_shape = [None, self.height // 4, self.width // 4, self.channel]  # 4x downscaled input
        self.hr_image_shape = [None, self.height, self.width, self.channel]
        self.vgg_image_shape = [224, 224, 3]

        self.sample_num = sample_num
        self.sample_size = sample_size

        self.df_dim = df_dim
        self.gf_dim = gf_dim

        # Adam hyperparameters and learning-rate schedule
        self.beta1 = 0.9
        self.beta2 = 0.999
        self.lr_decay_rate = 1e-1
        self.lr_low_boundary = 1e-5
        self.lr_update_step = 1e5
        self.lr_update_epoch = 1000

        self.vgg_mean = [103.939, 116.779, 123.68]  # ImageNet mean in B, G, R order

        # pre-defined
        self.d_real = 0.
        self.d_fake = 0.
        self.d_loss = 0.
        self.g_adv_loss = 0.
        self.g_cnt_loss = 0.
        self.g_loss = 0.
        self.psnr = 0.

        self.use_vgg19 = use_vgg19
        self.vgg19 = None

        self.g = None

        self.adv_scaling = 1e-3  # weight of the adversarial loss (10^-3 in the paper)
        # The paper rescales the VGG features by 1/12.75, i.e. the feature MSE by 1/12.75^2 (~6e-3)
        self.cnt_scaling = 1. / (12.75 ** 2)

        self.d_op = None
        self.g_op = None
        self.g_init_op = None

        self.merged = None
        self.writer = None
        self.saver = None

        # Placeholders
        self.x_hr = tf.placeholder(tf.float32, shape=self.hr_image_shape, name="x-image-hr")  # (-1, 384, 384, 3)
        self.x_lr = tf.placeholder(tf.float32, shape=self.lr_image_shape, name="x-image-lr")  # (-1, 96, 96, 3)
        self.lr = tf.placeholder(tf.float32, name='lr')

        self.build_srgan()  # build SRGAN model
```
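The paper rescales the VGG feature maps by 1/12.75 so that the VGG loss is on a scale comparable to the MSE loss, and notes that this is equivalent to multiplying the loss by about 0.006. The short check below (pure Python, with made-up feature values) confirms the arithmetic behind `cnt_scaling`: scaling both inputs of an MSE by a factor a scales the MSE by a², so 1/12.75 on the features corresponds to 1/12.75² ≈ 6.15e-3 on the loss, not 1/12.75 ≈ 7.8e-2.

```python
import math


def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


real_feat = [3.0, -1.5, 8.0, 0.25]  # made-up VGG activations
fake_feat = [2.0, -0.5, 6.5, 1.0]

a = 1.0 / 12.75
loss_on_scaled_features = mse([a * v for v in real_feat], [a * v for v in fake_feat])
loss_times_a_squared = a ** 2 * mse(real_feat, fake_feat)

print(math.isclose(loss_on_scaled_features, loss_times_a_squared))  # True
print(round(a ** 2, 5), round(a, 5))  # 0.00615 0.07843
```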
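2. Building the generator and discriminator
① Discriminator: LeakyReLU activations, no max-pooling (the resolution is reduced with strided convolutions):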
```python
    def discriminator(self, x, reuse=None):
        """
        # Following a network architecture referred in the paper
        :param x: Input images (-1, 384, 384, 3)
        :param reuse: re-usability
        :return: logits (before sigmoid) for "x is a real HR image", shape (-1, 1)
        """
        with tf.variable_scope("discriminator", reuse=reuse):
            x = t.conv2d(x, self.df_dim, 3, 1, name='n64s1-1')
            x = tf.nn.leaky_relu(x)  # default alpha = 0.2, as in the paper

            strides = [2, 1]
            filters = [1, 2, 2, 4, 4, 8, 8]  # multiples of df_dim: 64, 128, 128, 256, 256, 512, 512

            for i, f in enumerate(filters):
                n_filters = f * self.df_dim  # fix: the multiplier must be scaled by df_dim
                x = t.conv2d(x, f=n_filters, k=3, s=strides[i % 2],
                             name='n%ds%d-%d' % (n_filters, strides[i % 2], i + 1))
                x = t.batch_norm(x, name='n%d-bn-%d' % (n_filters, i + 1))
                x = tf.nn.leaky_relu(x)

            x = tf.layers.flatten(x)  # (-1, 24 * 24 * 512) for a 384x384 input with 'same' padding
            x = t.dense(x, 1024, name='disc-fc-1')
            x = tf.nn.leaky_relu(x)
            x = t.dense(x, 1, name='disc-fc-2')
            # no sigmoid here: the losses apply sigmoid cross-entropy to the logits
            # x = tf.nn.sigmoid(x)
            return x
```
② Generator
```python
    def generator(self, x, reuse=None, is_train=True):
        """
        :param x: LR (Low Resolution) images, (-1, 96, 96, 3)
        :param reuse: scope re-usability
        :param is_train: is trainable, default True
        :return: SR (Super Resolution) images, (-1, 384, 384, 3)
        """
        with tf.variable_scope("generator", reuse=reuse):
            def residual_block(x, f, name="", _is_train=True):
                with tf.variable_scope(name):
                    shortcut = tf.identity(x, name='n64s1-shortcut')

                    x = t.conv2d(x, f, 3, 1, name="n64s1-1")
                    x = t.batch_norm(x, is_train=_is_train, name="n64s1-bn-1")
                    x = t.prelu(x, reuse=reuse, name='n64s1-prelu-1')
                    x = t.conv2d(x, f, 3, 1, name="n64s1-2")
                    x = t.batch_norm(x, is_train=_is_train, name="n64s1-bn-2")
                    x = tf.add(x, shortcut)  # identity skip connection
                    return x

            x = t.conv2d(x, self.gf_dim, 9, 1, name='n64s1-1')
            x = t.prelu(x, name='n64s1-prelu-1')

            skip_conn = tf.identity(x, name='skip_connection')

            # B = 16 residual blocks
            for i in range(1, 17):
                x = residual_block(x, self.gf_dim, name='b-residual_block_%d' % i, _is_train=is_train)

            x = t.conv2d(x, self.gf_dim, 3, 1, name='n64s1-3')
            x = t.batch_norm(x, is_train=is_train, name='n64s1-bn-3')
            x = tf.add(x, skip_conn)  # long skip connection around all residual blocks

            # two sub-pixel conv2d blocks, each upsampling by 2 (4x in total)
            for i in range(1, 3):
                x = t.conv2d(x, self.gf_dim * 4, 3, 1, name='n256s1-%d' % (i + 2))
                x = t.sub_pixel_conv2d(x, f=None, s=2)
                x = t.prelu(x, name='n256s1-prelu-%d' % i)

            x = t.conv2d(x, self.channel, 9, 1, name='n3s1')  # (-1, 384, 384, 3)
            x = tf.nn.tanh(x)  # output in [-1, 1]
            return x
```
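The generator upsamples with two sub-pixel convolution blocks: a convolution produces 4× as many channels (64 to 256), and a periodic shuffle rearranges an (H, W, C·r²) tensor into (H·r, W·r, C) with r = 2. The source of `t.sub_pixel_conv2d` is not shown in the post; the sketch below assumes it uses the same element ordering as TensorFlow's `tf.nn.depth_to_space` (NHWC) and works on nested Python lists so that it is easy to see where each value goes.

```python
def pixel_shuffle(x, r):
    """Rearrange an H x W x (C*r*r) nested list into (H*r) x (W*r) x C.

    Ordering follows tf.nn.depth_to_space (NHWC):
    out[h*r + i][w*r + j][c] = x[h][w][(i*r + j)*C + c]
    """
    height, width, depth = len(x), len(x[0]), len(x[0][0])
    if depth % (r * r) != 0:
        raise ValueError("channel count must be divisible by r*r")
    c_out = depth // (r * r)
    out = [[[0] * c_out for _ in range(width * r)] for _ in range(height * r)]
    for h in range(height):
        for w in range(width):
            for i in range(r):
                for j in range(r):
                    for c in range(c_out):
                        out[h * r + i][w * r + j][c] = x[h][w][(i * r + j) * c_out + c]
    return out


# 1 x 1 x 12 input (values 1..12), upscale factor 2 -> 2 x 2 x 3 output
x = [[list(range(1, 13))]]
print(pixel_shuffle(x, 2))
# [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
```

In the generator, each block turns a feature map with 256 channels into one with 64 channels at twice the height and width, so the two blocks take the 96×96 input resolution to 384×384.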
3. Building the VGG network
```python
    def build_vgg19(self, x, reuse=None):
        with tf.variable_scope("vgg19", reuse=reuse):
            # image re-scaling
            x = tf.cast((x + 1) / 2, dtype=tf.float32)  # [-1, 1] to [0, 1]
            x = tf.cast(x * 255., dtype=tf.float32)  # [0, 1] to [0, 255]

            # RGB -> BGR and subtract the per-channel mean (vgg_mean is in B, G, R order)
            r, g, b = tf.split(x, 3, 3)
            bgr = tf.concat([b - self.vgg_mean[0],
                             g - self.vgg_mean[1],
                             r - self.vgg_mean[2]], axis=3)

            self.vgg19 = vgg19.VGG19(bgr)

            net = self.vgg19.vgg19_net['conv5_4']

            return net  # conv5_4 feature map, used for the VGG content loss
```
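`build_vgg19` maps the generator's tanh output from [-1, 1] back to [0, 255], reorders RGB to BGR, and subtracts per-channel means; the values in `self.vgg_mean` are the B, G, R ImageNet means commonly used with Caffe-style pre-trained VGG weights. A per-pixel sketch of the same arithmetic in plain Python (the helper name `to_vgg_input` is ours):

```python
VGG_MEAN_BGR = (103.939, 116.779, 123.68)  # same values as self.vgg_mean


def to_vgg_input(rgb_pixel):
    """Map one RGB pixel in [-1, 1] to a mean-subtracted BGR pixel on VGG's [0, 255] scale."""
    r, g, b = ((v + 1.0) / 2.0 * 255.0 for v in rgb_pixel)  # [-1, 1] -> [0, 255]
    return (b - VGG_MEAN_BGR[0], g - VGG_MEAN_BGR[1], r - VGG_MEAN_BGR[2])


print(to_vgg_input((-1.0, -1.0, -1.0)))  # black pixel: (-103.939, -116.779, -123.68)
print(to_vgg_input((1.0, 0.0, -1.0)))    # full red, half green, no blue
```

Note that the first output channel is blue: swapping the order would silently feed VGG the wrong colors without raising any error.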
4. Building the SRGAN model
```python
    def build_srgan(self):
        # Generator
        self.g = self.generator(self.x_lr)

        # Discriminator (weights shared between real and generated inputs)
        d_real = self.discriminator(self.x_hr)
        d_fake = self.discriminator(self.g, reuse=True)

        # Losses
        # d_real_loss = -tf.reduce_mean(t.safe_log(d_real))
        # d_fake_loss = -tf.reduce_mean(t.safe_log(1. - d_fake))
        d_real_loss = t.sce_loss(d_real, tf.ones_like(d_real))   # real HR images -> label 1
        d_fake_loss = t.sce_loss(d_fake, tf.zeros_like(d_fake))  # generated SR images -> label 0
        self.d_loss = d_real_loss + d_fake_loss

        if self.use_vgg19:
            # VGG content loss on conv5_4 features of images resized to 224x224
            x_vgg_real = tf.image.resize_images(self.x_hr, size=self.vgg_image_shape[:2], align_corners=False)
            x_vgg_fake = tf.image.resize_images(self.g, size=self.vgg_image_shape[:2], align_corners=False)

            vgg_bottle_real = self.build_vgg19(x_vgg_real)
            vgg_bottle_fake = self.build_vgg19(x_vgg_fake, reuse=True)

            self.g_cnt_loss = self.cnt_scaling * t.mse_loss(vgg_bottle_fake, vgg_bottle_real, self.batch_size,
                                                            is_mean=True)
        else:
            # pixel-wise MSE content loss
            self.g_cnt_loss = t.mse_loss(self.g, self.x_hr, self.batch_size, is_mean=True)

        # non-saturating adversarial loss -log D(G(x_lr)), weighted by 1e-3
        # self.g_adv_loss = self.adv_scaling * tf.reduce_mean(-1. * t.safe_log(d_fake))
        self.g_adv_loss = self.adv_scaling * t.sce_loss(d_fake, tf.ones_like(d_fake))
        self.g_loss = self.g_adv_loss + self.g_cnt_loss

        def inverse_transform(img):
            return (img + 1.) * 127.5  # [-1, 1] -> [0, 255]

        # calculate PSNR
        g, x_hr = inverse_transform(self.g), inverse_transform(self.x_hr)
        self.psnr = t.psnr_loss(g, x_hr, self.batch_size)

        # Summary
        tf.summary.scalar("loss/d_real_loss", d_real_loss)
        tf.summary.scalar("loss/d_fake_loss", d_fake_loss)
        tf.summary.scalar("loss/d_loss", self.d_loss)
        tf.summary.scalar("loss/g_cnt_loss", self.g_cnt_loss)
        tf.summary.scalar("loss/g_adv_loss", self.g_adv_loss)
        tf.summary.scalar("loss/g_loss", self.g_loss)
        tf.summary.scalar("misc/psnr", self.psnr)
        tf.summary.scalar("misc/lr", self.lr)

        # Optimizer: split the trainable variables by variable scope
        t_vars = tf.trainable_variables()
        d_params = [v for v in t_vars if v.name.startswith('discriminator/')]
        g_params = [v for v in t_vars if v.name.startswith('generator/')]

        adam_kwargs = dict(learning_rate=self.lr, beta1=self.beta1, beta2=self.beta2)
        self.d_op = tf.train.AdamOptimizer(**adam_kwargs).minimize(loss=self.d_loss, var_list=d_params)
        self.g_op = tf.train.AdamOptimizer(**adam_kwargs).minimize(loss=self.g_loss, var_list=g_params)

        # pre-train G with the content loss only
        self.g_init_op = tf.train.AdamOptimizer(**adam_kwargs).minimize(loss=self.g_cnt_loss, var_list=g_params)

        # Merge summary
        self.merged = tf.summary.merge_all()

        # Model saver
        self.saver = tf.train.Saver(max_to_keep=2)
        self.writer = tf.summary.FileWriter('./model/', self.s.graph)
```
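The losses above are computed on discriminator logits through `t.sce_loss`, whose source is not shown; it presumably wraps a sigmoid cross-entropy such as `tf.nn.sigmoid_cross_entropy_with_logits`, which uses the numerically stable form max(z, 0) − z·y + log(1 + e^(−|z|)). The plain-Python sketch below reimplements that formula and compares, for a generated image that the discriminator confidently rejects (logit z = −5), the gradient of the non-saturating generator loss −log D(G(x)) (what `sce_loss(d_fake, ones)` computes) with that of the saturating log(1 − D(G(x))) from the original min-max game.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def sigmoid_ce(logit, label):
    """Numerically stable sigmoid cross-entropy on a logit (same formula as TF's op)."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


z = -5.0  # D is confident the SR image is fake: D(G(x)) = sigmoid(-5), about 0.0067

# Generator losses as functions of the logit z, and their derivatives dL/dz
non_saturating_loss = sigmoid_ce(z, 1.0)      # -log(sigmoid(z))
grad_non_saturating = -(1.0 - sigmoid(z))     # d/dz[-log sigmoid(z)]
saturating_loss = math.log(1.0 - sigmoid(z))  # log(1 - sigmoid(z)), minimized by G in the min-max game
grad_saturating = -sigmoid(z)                 # d/dz[log(1 - sigmoid(z))]

print(round(non_saturating_loss, 4), round(-math.log(sigmoid(z)), 4))
print(round(grad_non_saturating, 4), round(grad_saturating, 4))  # -0.9933 -0.0067
```

When D easily rejects the generated images, as is typical early in adversarial training, the saturating loss gives G almost no gradient while −log D still gives a gradient close to −1; this is why the generator is trained on −log D(G(I^LR)).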
5. Main function (training script)
```python
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import sys
import time

import numpy as np
import tensorflow as tf

sys.path.append('../')
import image_utils as iu                      # project-local helper (image saving)
from datasets import Div2KDataSet as DataSet  # project-local DIV2K loader

# Assumption: the SRGAN class (parts 1-4) is saved as srgan_model.py (hypothetical file name)
from srgan_model import SRGAN

np.random.seed(1337)

results = {
    'output': './gen_img/',
    'model': './model/SRGAN-model.ckpt'
}

train_step = {
    'batch_size': 16,
    'init_epochs': 100,
    'train_epochs': 1501,
    'global_step': 200001,
    'logging_interval': 100,
}


def main():
    start_time = time.time()  # Clocking start

    # Div2K - Track 1: Bicubic downscaling - x4 DataSet load
    """
    ds = DataSet(ds_path="/home/zero/hdd/DataSet/DIV2K/",
                 ds_name="X4",
                 use_save=True,
                 save_type="to_h5",
                 save_file_name="/home/zero/hdd/DataSet/DIV2K/DIV2K",
                 use_img_scale=True)
    """

    ds = DataSet(ds_hr_path="/home/zero/hdd/DataSet/DIV2K/DIV2K-hr.h5",
                 ds_lr_path="/home/zero/hdd/DataSet/DIV2K/DIV2K-lr.h5",
                 use_img_scale=True)

    hr, lr = ds.hr_images, ds.lr_images

    print("[+] Loaded HR image ", hr.shape)
    print("[+] Loaded LR image ", lr.shape)

    # GPU configure
    gpu_config = tf.GPUOptions(allow_growth=True)
    config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False, gpu_options=gpu_config)

    with tf.Session(config=config) as s:
        with tf.device("/gpu:1"):  # change to match your GPU setup (e.g. "/gpu:0")
            # SRGAN Model
            model = SRGAN(s, batch_size=train_step['batch_size'],
                          use_vgg19=False)

        # Initializing
        s.run(tf.global_variables_initializer())

        # Load model & Graph & Weights
        ckpt = tf.train.get_checkpoint_state('./model/')
        if ckpt and ckpt.model_checkpoint_path:
            # Restores from checkpoint
            model.saver.restore(s, ckpt.model_checkpoint_path)

            global_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            print("[+] global step : %d" % global_step, " successfully loaded")
        else:
            global_step = 0
            print('[-] No checkpoint file found')

        iters_per_epoch = ds.n_images // train_step['batch_size']
        start_epoch = global_step // iters_per_epoch

        rnd = np.random.randint(0, ds.n_images)
        sample_x_hr, sample_x_lr = hr[rnd], lr[rnd]

        sample_x_hr, sample_x_lr = \
            np.reshape(sample_x_hr, [1] + model.hr_image_shape[1:]), \
            np.reshape(sample_x_lr, [1] + model.lr_image_shape[1:])

        # Export the real reference images
        sample_hr_dir, sample_lr_dir = results['output'] + 'valid_hr.png', results['output'] + 'valid_lr.png'

        iu.save_images(sample_x_hr,
                       size=[1, 1],
                       image_path=sample_hr_dir,
                       inv_type='127')
        iu.save_images(sample_x_lr,
                       size=[1, 1],
                       image_path=sample_lr_dir,
                       inv_type='127')

        base_lr = 1e-4
        for epoch in range(start_epoch, train_step['train_epochs']):
            # Learning rate: decay by lr_decay_rate every lr_update_epoch epochs (also correct after a restore)
            learning_rate = max(base_lr * model.lr_decay_rate ** (epoch // model.lr_update_epoch),
                                model.lr_low_boundary)

            # Shuffle the training DataSet at the start of every epoch
            perm = np.random.permutation(ds.n_images)

            for i in range(iters_per_epoch):
                batch_idx = perm[i * train_step['batch_size']:(i + 1) * train_step['batch_size']]
                batch_x_hr, batch_x_lr = hr[batch_idx], lr[batch_idx]

                # reshape
                batch_x_hr = np.reshape(batch_x_hr, [train_step['batch_size']] + model.hr_image_shape[1:])
                batch_x_lr = np.reshape(batch_x_lr, [train_step['batch_size']] + model.lr_image_shape[1:])

                feed = {
                    model.x_hr: batch_x_hr,
                    model.x_lr: batch_x_lr,
                    model.lr: learning_rate,
                }

                d_loss, g_loss, g_init_loss = 0., 0., 0.
                if epoch <= train_step['init_epochs']:
                    # Pre-training: update only the G network with the content loss
                    _, g_init_loss = s.run([model.g_init_op, model.g_cnt_loss], feed_dict=feed)
                else:
                    # Adversarial training: update the D network, then the G network
                    _, d_loss = s.run([model.d_op, model.d_loss], feed_dict=feed)
                    _, g_loss = s.run([model.g_op, model.g_loss], feed_dict=feed)

                if i % train_step['logging_interval'] == 0:
                    # Print loss
                    if epoch <= train_step['init_epochs']:
                        print("[+] Epoch %04d Step %08d => " % (epoch, global_step),
                              " MSE loss : {:.8f}".format(g_init_loss))
                    else:
                        print("[+] Epoch %04d Step %08d => " % (epoch, global_step),
                              " D loss : {:.8f}".format(d_loss),
                              " G loss : {:.8f}".format(g_loss))

                    summary = s.run(model.merged, feed_dict=feed)

                    # Summary saver
                    model.writer.add_summary(summary, global_step)

                    # Generate an SR image from the fixed sample LR image
                    sample_x_lr = np.reshape(sample_x_lr, [model.sample_num] + model.lr_image_shape[1:])
                    samples = s.run(model.g, feed_dict={model.x_lr: sample_x_lr})

                    # Export image generated by model G
                    sample_dir = results['output'] + 'train_{:08d}.png'.format(global_step)
                    iu.save_images(samples,
                                   size=[1, 1],
                                   image_path=sample_dir,
                                   inv_type='127')

                    # Model save
                    model.saver.save(s, results['model'], global_step)

                global_step += 1

    end_time = time.time() - start_time  # Clocking end

    # Elapsed time (the session is closed by the with-block)
    print("[+] Elapsed time {:.8f}s".format(end_time))


if __name__ == '__main__':
    main()
```
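The training loop's bookkeeping can be checked in isolation. The sketch below uses hypothetical helper names and copies the constants from the code (base rate 1e-4, decay factor 0.1 every `lr_update_epoch` = 1000 epochs, floor 1e-5, batch size 16) to compute the learning rate for a given epoch and the epoch from which a restored `global_step` resumes. The dataset size of 800 images is only an example value (the size of the DIV2K training split); the real value is `ds.n_images`, which depends on how the data set was prepared.

```python
BASE_LR = 1e-4
LR_DECAY_RATE = 1e-1     # model.lr_decay_rate
LR_UPDATE_EPOCH = 1000   # model.lr_update_epoch
LR_LOW_BOUNDARY = 1e-5   # model.lr_low_boundary
BATCH_SIZE = 16          # train_step['batch_size']


def learning_rate_at(epoch):
    """Step decay: multiply by LR_DECAY_RATE every LR_UPDATE_EPOCH epochs, never below the floor."""
    return max(BASE_LR * LR_DECAY_RATE ** (epoch // LR_UPDATE_EPOCH), LR_LOW_BOUNDARY)


def resume_epoch(global_step, n_images, batch_size=BATCH_SIZE):
    """Epoch to resume from, given the step number parsed from a checkpoint name."""
    iters_per_epoch = n_images // batch_size
    return global_step // iters_per_epoch


n_images = 800  # example value only
for epoch in (0, 999, 1000, 1500):
    print(epoch, learning_rate_at(epoch))
print(resume_epoch(global_step=55000, n_images=n_images))  # 55000 // 50 = 1100
```

Deriving the rate from the epoch number, rather than decaying a running variable, also gives the right rate when training resumes from a checkpoint saved after epoch 1000.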
6. Results (generated images)
Input LR image:
Images generated during training (steps 0 to 55000):
Generated HR image:
Complete code

The complete code is the code of Section V, parts 1 to 5, organized as two files: the model file (parts 1 to 4, class SRGAN) and the training script (part 5). Because `from __future__` imports must be the first statements in a file, the two parts cannot simply be pasted into one file; save them separately (for example as srgan_model.py and srgan_train.py, file names chosen here for illustration) and import the class in the training script with `from srgan_model import SRGAN`. The helper modules vgg19, tfutil, image_utils and datasets are not included in the post.