[Code Reading] WarpGAN: Automatic Caricature Generation
Reference book: 《TensorFlow 实战Google深度学习框架》.
I recommend reading its Chapter 3 first; it gives a clearer picture of how TensorFlow builds and trains a neural network.
1. train.py
This file defines the main function:
def main(args):
Initialization:
    # Initalization for running
    if config.save_model:
        log_dir = utils.create_log_dir(config, config_file)
        summary_writer = tf.summary.FileWriter(log_dir, network.graph)
    if config.restore_model:
        network.restore_model(config.restore_model, config.restore_scopes)

    proc_func = lambda images: preprocess(images, config, True)
    trainset.start_batch_queue(config.batch_size, proc_func=proc_func)
The config settings used here all come from WarpGAN\config\default.py.
Reading the dataset and setting up the batch queue are handled by WarpGAN\utils\dataset.py.
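Note that the config is just a plain Python module whose attributes train.py reads directly. Below is a hypothetical sketch of the fields used above and later in warpgan.py; the names follow the code, but the values are purely illustrative and are not the ones in WarpGAN\config\default.py:

# Hypothetical config excerpt -- field names follow the code, values are illustrative only
name = 'default'            # experiment name used for the log directory
num_epochs = 100            # outer loop count in train.py
epoch_size = 5000           # training steps per epoch
batch_size = 2              # images popped from the batch queue per step
image_size = (256, 256)     # (h, w) for the image placeholders
channels = 3
keep_prob = 1.0             # dropout keep probability passed to network.train
summary_interval = 100      # how often to print and log summaries
save_model = True
restore_model = None        # checkpoint path to restore, or None
coef_adv = 1.0              # loss weights consumed in warpgan.py
coef_patch_adv = 1.0
coef_idt = 10.0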
Main loop:
    # Main Loop
    print('\nStart Training\nname: {}\n# epochs: {}\nepoch_size: {}\nbatch_size: {}\n'.format(
        config.name, config.num_epochs, config.epoch_size, config.batch_size))
    global_step = 0
    start_time = time.time()
    for epoch in range(config.num_epochs):

        if epoch == 0: test(network, config, log_dir, global_step)

        # Training
        for step in range(config.epoch_size):
            # Prepare input
            learning_rate = utils.get_updated_learning_rate(global_step, config)
            batch = trainset.pop_batch_queue()

            wl, sm, global_step = network.train(batch['images'], batch['labels'], batch['is_photo'], learning_rate, config.keep_prob)

            wl['lr'] = learning_rate

            # Display
            if step % config.summary_interval == 0:
                duration = time.time() - start_time
                start_time = time.time()
                utils.display_info(epoch, step, duration, wl)

            if config.save_model:
                summary_writer.add_summary(sm, global_step=global_step)
wl, sm, global_step = network.train(batch['images'], batch['labels'], batch['is_photo'], learning_rate, config.keep_prob)
This line is the key: it calls the network's training step.
2. warpgan.py
This file defines the WarpGAN computation graph: the forward pass and the loss functions.
Training a neural network can be summarized in three steps (a minimal sketch follows the list):
1) Define the network structure and the output of the forward pass
2) Define the loss function (computed from the forward-pass output) and the back-propagation optimization algorithm
3) Create a session (tf.Session()) and repeatedly run the back-propagation optimizer on the training data
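As a minimal, self-contained sketch of these three steps in TF 1.x style (a toy regression model, not WarpGAN itself):

import numpy as np
import tensorflow as tf  # TF 1.x, same API style as the repository

# 1) Network structure and forward pass: a tiny one-layer regression model
x = tf.placeholder(tf.float32, shape=[None, 4], name='x')
y = tf.placeholder(tf.float32, shape=[None, 1], name='y')
w = tf.Variable(tf.random_normal([4, 1]), name='w')
y_pred = tf.matmul(x, w)

# 2) Loss computed from the forward-pass output, plus the optimizer
loss = tf.reduce_mean(tf.square(y_pred - y))
train_op = tf.train.AdamOptimizer(0.01).minimize(loss)

# 3) Session: repeatedly run the training op on (here random) data
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(100):
        xs = np.random.rand(8, 4).astype(np.float32)
        ys = xs.sum(axis=1, keepdims=True)
        _, loss_val = sess.run([train_op, loss], feed_dict={x: xs, y: ys})

WarpGAN follows exactly this pattern: initialize() below builds steps 1) and 2), and the train() method runs step 3).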
def train(self, images_batch, labels_batch, switch_batch, learning_rate, keep_prob):
    images_A = images_batch[~switch_batch]
    images_B = images_batch[switch_batch]
    labels_A = labels_batch[~switch_batch]
    labels_B = labels_batch[switch_batch]
    scales_A = np.ones((images_A.shape[0]))
    scales_B = np.ones((images_B.shape[0]))
    feed_dict = {self.images_A: images_A,
                 self.images_B: images_B,
                 self.labels_A: labels_A,
                 self.labels_B: labels_B,
                 self.scales_A: scales_A,
                 self.scales_B: scales_B,
                 self.learning_rate: learning_rate,
                 self.keep_prob: keep_prob,
                 self.phase_train: True}

    _, wl, sm = self.sess.run([self.train_op, self.watch_list, self.summary_op], feed_dict=feed_dict)

    step = self.sess.run(self.global_step)

    return wl, sm, step
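Note that switch_batch is the is_photo flag passed in from train.py, so the boolean mask splits each mixed batch into the two domains: A gets the samples where is_photo is False (the caricatures) and B gets the samples where is_photo is True (the photos). A small NumPy sketch of that indexing, with toy data just to show the semantics:

import numpy as np

images_batch = np.arange(4 * 2 * 2).reshape(4, 2, 2).astype(np.float32)  # 4 toy "images"
switch_batch = np.array([True, False, True, False])                      # True = photo

images_A = images_batch[~switch_batch]  # rows where is_photo is False -> domain A
images_B = images_batch[switch_batch]   # rows where is_photo is True  -> domain B
print(images_A.shape, images_B.shape)   # (2, 2, 2) (2, 2, 2)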
This train function is called by train.py above; it is step 3), running the session on the training data.
Here self.train_op, self.watch_list, self.summary_op, and self.global_step are operations/tensors in the graph.
We mainly care about self.train_op, the operation that actually updates the parameters.
It is defined in the initialize function, which builds the forward pass and the loss functions:
def initialize(self, config, num_classes=None):
    '''
        Initialize the graph from scratch according to config.
    '''
    with self.graph.as_default():
        with self.sess.as_default():
            # Set up placeholders
            h, w = config.image_size
            channels = config.channels
            self.images_A = tf.placeholder(tf.float32, shape=[None, h, w, channels], name='images_A')
            self.images_B = tf.placeholder(tf.float32, shape=[None, h, w, channels], name='images_B')
            self.labels_A = tf.placeholder(tf.int32, shape=[None], name='labels_A')
            self.labels_B = tf.placeholder(tf.int32, shape=[None], name='labels_B')
            self.scales_A = tf.placeholder(tf.float32, shape=[None], name='scales_A')
            self.scales_B = tf.placeholder(tf.float32, shape=[None], name='scales_B')

            self.learning_rate = tf.placeholder(tf.float32, name='learning_rate')
            self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
            self.phase_train = tf.placeholder(tf.bool, name='phase_train')
            self.global_step = tf.Variable(0, trainable=False, dtype=tf.int32, name='global_step')

            self.setup_network_model(config, num_classes)

            # Build generator
            encode_A, styles_A = self.encoder(self.images_A)
            encode_B, styles_B = self.encoder(self.images_B)

            deform_BA, render_BA, ldmark_pred, ldmark_diff = self.decoder(encode_B, self.scales_B, None)

            render_AA = self.decoder(encode_A, self.scales_A, styles_A, texture_only=True)
            render_BB = self.decoder(encode_B, self.scales_B, styles_B, texture_only=True)

            self.styles_A = tf.identity(styles_A, name='styles_A')
            self.styles_B = tf.identity(styles_B, name='styles_B')
            self.deform_BA = tf.identity(deform_BA, name='deform_BA')
            self.ldmark_pred = tf.identity(ldmark_pred, name='ldmark_pred')
            self.ldmark_diff = tf.identity(ldmark_diff, name='ldmark_diff')

            # Build discriminator for real images
            patch_logits_A, logits_A = self.discriminator(self.images_A)
            patch_logits_B, logits_B = self.discriminator(self.images_B)
            patch_logits_BA, logits_BA = self.discriminator(deform_BA)

            # Show images in TensorBoard
            image_grid_A = tf.stack([self.images_A, render_AA], axis=1)[:1]
            image_grid_B = tf.stack([self.images_B, render_BB], axis=1)[:1]
            image_grid_BA = tf.stack([self.images_B, deform_BA], axis=1)[:1]
            image_grid = tf.concat([image_grid_A, image_grid_B, image_grid_BA], axis=0)
            image_grid = tf.reshape(image_grid, [-1] + list(self.images_A.shape[1:]))
            image_grid = self.image_grid(image_grid, (3, 2))
            tf.summary.image('image_grid', image_grid)

            # Build all losses
            self.watch_list = {}
            loss_list_G = []
            loss_list_D = []

            # Advesarial loss for deform_BA
            loss_D, loss_G = self.cls_adv_loss(logits_A, logits_B, logits_BA,
                self.labels_A, self.labels_B, self.labels_B, num_classes)
            loss_D, loss_G = config.coef_adv*loss_D, config.coef_adv*loss_G

            self.watch_list['LDg'] = loss_D
            self.watch_list['LGg'] = loss_G
            loss_list_D.append(loss_D)
            loss_list_G.append(loss_G)

            # Patch Advesarial loss for deform_BA
            loss_D, loss_G = self.patch_adv_loss(patch_logits_A, patch_logits_B, patch_logits_BA)
            loss_D, loss_G = config.coef_patch_adv*loss_D, config.coef_patch_adv*loss_G

            self.watch_list['LDp'] = loss_D
            self.watch_list['LGp'] = loss_G
            loss_list_D.append(loss_D)
            loss_list_G.append(loss_G)

            # Identity Mapping (Reconstruction) loss
            loss_idt_A = tf.reduce_mean(tf.abs(render_AA - self.images_A), name='idt_loss_A')
            loss_idt_A = config.coef_idt * loss_idt_A
            loss_idt_B = tf.reduce_mean(tf.abs(render_BB - self.images_B), name='idt_loss_B')
            loss_idt_B = config.coef_idt * loss_idt_B

            self.watch_list['idtA'] = loss_idt_A
            self.watch_list['idtB'] = loss_idt_B
            loss_list_G.append(loss_idt_A + loss_idt_B)

            # Collect all losses
            reg_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES), name='reg_loss')
            self.watch_list['reg_loss'] = reg_loss
            loss_list_G.append(reg_loss)
            loss_list_D.append(reg_loss)

            loss_G = tf.add_n(loss_list_G, name='loss_G')
            grads_G = tf.gradients(loss_G, self.G_vars)

            loss_D = tf.add_n(loss_list_D, name='loss_D')
            grads_D = tf.gradients(loss_D, self.D_vars)

            # Training Operaters
            train_ops = []

            opt_G = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5, beta2=0.9)
            opt_D = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5, beta2=0.9)
            apply_G_gradient_op = opt_G.apply_gradients(list(zip(grads_G, self.G_vars)))
            apply_D_gradient_op = opt_D.apply_gradients(list(zip(grads_D, self.D_vars)))

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            train_ops.extend([apply_G_gradient_op, apply_D_gradient_op] + update_ops)

            train_ops.append(tf.assign_add(self.global_step, 1))
            self.train_op = tf.group(*train_ops)

            # Collect TF summary
            for k, v in self.watch_list.items():
                tf.summary.scalar('losses/' + k, v)
            tf.summary.scalar('learning_rate', self.learning_rate)
            self.summary_op = tf.summary.merge_all()

            # Initialize variables
            self.sess.run(tf.local_variables_initializer())
            self.sess.run(tf.global_variables_initializer())
            self.saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=99)
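Collecting the terms appended to loss_list_G and loss_list_D above, the two totals that the Adam optimizers minimize are (in the notation of the watch list):

loss_G = LGg + LGp + (idtA + idtB) + reg_loss
       = coef_adv * L_G_cls + coef_patch_adv * L_G_patch + coef_idt * (L1(render_AA, images_A) + L1(render_BB, images_B)) + reg_loss
loss_D = LDg + LDp + reg_loss
       = coef_adv * L_D_cls + coef_patch_adv * L_D_patch + reg_loss

The generator gradients are taken only with respect to self.G_vars and the discriminator gradients only with respect to self.D_vars; both apply-gradient ops, the UPDATE_OPS collection, and the global_step increment are grouped into self.train_op, so a single sess.run(self.train_op) performs one G step and one D step.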
From the forward pass defined here we can draw the following forward-pass diagram:
Combined with the architecture diagram in the paper, we get the following detailed process diagram:
3. WarpGAN\models\default.py
This file defines the detailed structure of the three sub-networks, encoder, decoder, and discriminator, which warpgan.py calls.
To see how the control points are learned, I focused on WarpController, the sub-network inside the decoder that predicts the facial control points.
with tf.variable_scope('WarpController'):

    print('-- WarpController')

    net = encoded

    warp_input = tf.identity(images_rendered, name='warp_input')

    net = slim.flatten(net)

    net = slim.fully_connected(net, 128, scope='fc1')
    print('module fc1 shape:', [dim.value for dim in net.shape])

    num_ldmark = 16

    # Predict the control points
    ldmark_mean = (np.random.normal(0,50, (num_ldmark,2)) + np.array([[0.5*h,0.5*w]])).flatten()
    ldmark_mean = tf.Variable(ldmark_mean.astype(np.float32), name='ldmark_mean')
    print('ldmark_mean shape:', [dim.value for dim in ldmark_mean.shape])

    ldmark_pred = slim.fully_connected(net, num_ldmark*2,
        weights_initializer=tf.truncated_normal_initializer(stddev=1.0),
        normalizer_fn=None, activation_fn=None, biases_initializer=None, scope='fc_ldmark')
    ldmark_pred = ldmark_pred + ldmark_mean
    print('ldmark_pred shape:', [dim.value for dim in ldmark_pred.shape])
    ldmark_pred = tf.identity(ldmark_pred, name='ldmark_pred')

    # Predict the displacements
    ldmark_diff = slim.fully_connected(net, num_ldmark*2,
        normalizer_fn=None, activation_fn=None, scope='fc_diff')
    print('ldmark_diff shape:', [dim.value for dim in ldmark_diff.shape])
    ldmark_diff = tf.identity(ldmark_diff, name='ldmark_diff')
    ldmark_diff = tf.identity(tf.reshape(scales,[-1,1]) * ldmark_diff, name='ldmark_diff_scaled')

    src_pts = tf.reshape(ldmark_pred, [-1, num_ldmark, 2])
    dst_pts = tf.reshape(ldmark_pred + ldmark_diff, [-1, num_ldmark, 2])

    diff_norm = tf.reduce_mean(tf.norm(src_pts-dst_pts, axis=[1,2]))
    # tf.summary.scalar('diff_norm', diff_norm)
    # tf.summary.scalar('mark', ldmark_pred[0,0])

    images_transformed, dense_flow = sparse_image_warp(warp_input, src_pts, dst_pts,
        regularization_weight=1e-6, num_boundary_points=0)
    dense_flow = tf.identity(dense_flow, name='dense_flow')
My understanding is as follows:
1) Control points: ldmark_mean is a trainable variable initialized once as random offsets around the image center (0.5*h, 0.5*w); the fully connected layer fc_ldmark predicts a per-image correction from the encoded image, and the sum ldmark_pred = fc_ldmark(net) + ldmark_mean gives the final control points.
2) Control-point displacements: ldmark_diff is predicted from the same encoded features by another fully connected layer (fc_diff) and multiplied by scales; src_pts and dst_pts = src_pts + ldmark_diff then drive the image warp (see the sketch below).
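The warp itself is done by sparse_image_warp, which fits a smooth dense flow from the sparse control-point pairs (a thin-plate-spline style interpolation) and resamples warp_input accordingly. The repository ships its own implementation of this function; TF 1.x also exposes an equivalent op, tf.contrib.image.sparse_image_warp. A minimal usage sketch with made-up control points, just to show the call shape (an illustration, not WarpGAN's own code path):

import numpy as np
import tensorflow as tf  # TF 1.x

image = tf.placeholder(tf.float32, [1, 256, 256, 3])
src_pts = tf.placeholder(tf.float32, [1, 16, 2])  # 16 control points per image
dst_pts = tf.placeholder(tf.float32, [1, 16, 2])  # where each point should move to

# Interpolates a dense flow field from the sparse point pairs and warps the image
warped, dense_flow = tf.contrib.image.sparse_image_warp(
    image, src_pts, dst_pts, regularization_weight=1e-6, num_boundary_points=0)

with tf.Session() as sess:
    out = sess.run(warped, feed_dict={
        image: np.random.rand(1, 256, 256, 3).astype(np.float32),
        src_pts: np.random.rand(1, 16, 2).astype(np.float32) * 255,
        dst_pts: np.random.rand(1, 16, 2).astype(np.float32) * 255,
    })
print(out.shape)  # (1, 256, 256, 3)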
The paper explains the relationship between warping and style transfer as follows (my translation):
Unlike other visual style-transfer tasks, converting a photo into a caricature involves both texture differences and geometric transformation. Texture exaggerates local fine-grained features, such as the depth of wrinkles, while geometric deformation exaggerates global features, such as face shape. Conventional style-transfer networks reconstruct the image from feature space with a decoder network. Because the decoder is a set of nonlinear local filters, it is inherently limited in spatial variation, so when there is a large geometric gap between the input and output domains the decoded images are of poor quality and lose a lot of information. Warping-based methods, on the other hand, are limited by being unable to change content and fine-grained details. Therefore both the style-transfer module and the deformation module are indispensable parts of the learning framework.
As the figure below shows, without either module the generator cannot bridge the gap between photos and caricatures; the adversarial balance between generator and discriminator is broken, and training collapses.
So in this paper style transfer and warping must happen together; it is not enough to just find control points and change the shape.
I recently came across another paper, CariGANs, which, like this one, warps the face based on landmarks; it is worth reading next.