
[Code Reading] WarpGAN: Automatic Caricature Generation


Code link

Reference book: 《Tensorflow 实战Google深度学习框架》 (TensorFlow: Google Deep Learning Framework in Practice)

I think reading Chapter 3 of it gives a clearer picture of how TensorFlow builds and trains a neural network.

1. train.py

This file defines the main function:

def main(args):

Initialization:

    # Initialization for running
    if config.save_model:
        log_dir = utils.create_log_dir(config, config_file)
        summary_writer = tf.summary.FileWriter(log_dir, network.graph)
    if config.restore_model:
        network.restore_model(config.restore_model, config.restore_scopes)

    proc_func = lambda images: preprocess(images, config, True)
    trainset.start_batch_queue(config.batch_size, proc_func=proc_func)

All the config settings here come from WarpGAN\config\default.py.

The dataset loading and initialization are handled in WarpGAN\utils\dataset.py.
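
The post does not reproduce the config file; judging only from the fields referenced in the code shown in this post, here is a minimal sketch of what config/default.py might contain (the values are illustrative placeholders, not the repo's actual defaults):

# Hypothetical sketch of WarpGAN/config/default.py.
# Field names are those referenced in this post; values are placeholders.
name = 'default'
num_epochs = 100
epoch_size = 1000
batch_size = 2
summary_interval = 100
save_model = True
restore_model = None        # checkpoint path, or None to train from scratch
restore_scopes = None
image_size = (256, 256)
channels = 3
keep_prob = 1.0
coef_adv = 1.0              # weight of the class adversarial loss
coef_patch_adv = 1.0        # weight of the patch adversarial loss
coef_idt = 10.0             # weight of the identity (reconstruction) loss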

Main loop:

    # Main Loop
    print('\nStart Training\nname: {}\n# epochs: {}\nepoch_size: {}\nbatch_size: {}\n'.format(
            config.name, config.num_epochs, config.epoch_size, config.batch_size))
    global_step = 0
    start_time = time.time()
    for epoch in range(config.num_epochs):

        if epoch == 0: test(network, config, log_dir, global_step)

        # Training
        for step in range(config.epoch_size):
            # Prepare input
            learning_rate = utils.get_updated_learning_rate(global_step, config)
            batch = trainset.pop_batch_queue()

            wl, sm, global_step = network.train(batch['images'], batch['labels'], batch['is_photo'], learning_rate, config.keep_prob)

            wl['lr'] = learning_rate

            # Display
            if step % config.summary_interval == 0:
                duration = time.time() - start_time
                start_time = time.time()
                utils.display_info(epoch, step, duration, wl)
                if config.save_model:
                    summary_writer.add_summary(sm, global_step=global_step)
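
The learning rate is recomputed from the global step on every iteration by utils.get_updated_learning_rate. A minimal sketch of one way such a function could work, assuming a piecewise-constant (step-decay) schedule stored in a hypothetical config.lr_schedule dict; the repo's actual implementation may differ:

def get_updated_learning_rate(global_step, config):
    # config.lr_schedule: {start_step: learning_rate, ...} -- hypothetical field
    lr = config.learning_rate
    for start_step in sorted(config.lr_schedule):
        if global_step >= start_step:
            lr = config.lr_schedule[start_step]
    return lr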

wl, sm, global_step = network.train(batch['images'], batch['labels'], batch['is_photo'], learning_rate, config.keep_prob)

This line is the key: it invokes one training step of the network.

2. warpgan.py

This file defines the WarpGAN computation graph: the forward pass and the loss functions.

Training a neural network can be summarized in the following three steps (a minimal sketch follows the list):

1) Define the network structure and the forward-pass output.

2) Define the loss function (computed from the forward-pass output) and the backpropagation optimization algorithm.

3) Create a session (tf.Session()) and repeatedly run the backpropagation optimization on the training data.
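
Here is a minimal, self-contained TensorFlow 1.x sketch of these three steps, using a toy linear-regression graph (nothing to do with WarpGAN itself):

import numpy as np
import tensorflow as tf  # TensorFlow 1.x

# 1) Define the network structure and the forward pass
x = tf.placeholder(tf.float32, [None, 1], name='x')
y = tf.placeholder(tf.float32, [None, 1], name='y')
w = tf.Variable(tf.zeros([1, 1]), name='w')
b = tf.Variable(tf.zeros([1]), name='b')
y_pred = tf.matmul(x, w) + b

# 2) Define the loss and the backpropagation optimization algorithm
loss = tf.reduce_mean(tf.square(y_pred - y))
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

# 3) Create a session and repeatedly run the optimizer on training data
data_x = np.random.rand(100, 1).astype(np.float32)
data_y = 3.0 * data_x + 1.0
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(200):
        _, loss_val = sess.run([train_op, loss], feed_dict={x: data_x, y: data_y})
    print('final loss:', loss_val)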

    def train(self, images_batch, labels_batch, switch_batch, learning_rate, keep_prob):
        images_A = images_batch[~switch_batch]
        images_B = images_batch[switch_batch]
        labels_A = labels_batch[~switch_batch]
        labels_B = labels_batch[switch_batch]
        scales_A = np.ones((images_A.shape[0]))
        scales_B = np.ones((images_B.shape[0]))
        feed_dict = {   self.images_A: images_A,
                        self.images_B: images_B,
                        self.labels_A: labels_A,
                        self.labels_B: labels_B,
                        self.scales_A: scales_A,
                        self.scales_B: scales_B,
                        self.learning_rate: learning_rate,
                        self.keep_prob: keep_prob,
                        self.phase_train: True,}
        _, wl, sm = self.sess.run([self.train_op, self.watch_list, self.summary_op], feed_dict = feed_dict)

        step = self.sess.run(self.global_step)

        return wl, sm, step

The train function is called by train.py above; it corresponds to the session-running step (step 3).
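
Note that switch_batch (the batch['is_photo'] array passed in by train.py) is a boolean NumPy array, so indexing with it splits the mixed batch into the two domains: images_B = images_batch[switch_batch] are the photos, images_A the caricatures. A minimal NumPy illustration:

import numpy as np

images_batch = np.zeros((4, 256, 256, 3), dtype=np.float32)  # dummy mixed batch
switch_batch = np.array([True, False, True, False])           # is_photo flags

images_A = images_batch[~switch_batch]  # flag False: caricatures
images_B = images_batch[switch_batch]   # flag True: photos
print(images_A.shape, images_B.shape)   # (2, 256, 256, 3) (2, 256, 256, 3)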

Here self.train_op, self.watch_list, self.summary_op and self.global_step are all operations in the graph.

We mainly focus on self.train_op, the operation that updates the parameters.

It is defined in the initialize function, which builds the forward pass and the loss functions:

    def initialize(self, config, num_classes=None):
        '''
            Initialize the graph from scratch according to config.
        '''
        with self.graph.as_default():
            with self.sess.as_default():
                # Set up placeholders
                h, w = config.image_size
                channels = config.channels
                self.images_A = tf.placeholder(tf.float32, shape=[None, h, w, channels], name='images_A')
                self.images_B = tf.placeholder(tf.float32, shape=[None, h, w, channels], name='images_B')
                self.labels_A = tf.placeholder(tf.int32, shape=[None], name='labels_A')
                self.labels_B = tf.placeholder(tf.int32, shape=[None], name='labels_B')
                self.scales_A = tf.placeholder(tf.float32, shape=[None], name='scales_A')
                self.scales_B = tf.placeholder(tf.float32, shape=[None], name='scales_B')

                self.learning_rate = tf.placeholder(tf.float32, name='learning_rate')
                self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
                self.phase_train = tf.placeholder(tf.bool, name='phase_train')
                self.global_step = tf.Variable(0, trainable=False, dtype=tf.int32, name='global_step')

                self.setup_network_model(config, num_classes)

                # Build generator
                encode_A, styles_A = self.encoder(self.images_A)
                encode_B, styles_B = self.encoder(self.images_B)

                deform_BA, render_BA, ldmark_pred, ldmark_diff = self.decoder(encode_B, self.scales_B, None)
                render_AA = self.decoder(encode_A, self.scales_A, styles_A, texture_only=True)
                render_BB = self.decoder(encode_B, self.scales_B, styles_B, texture_only=True)

                self.styles_A = tf.identity(styles_A, name='styles_A')
                self.styles_B = tf.identity(styles_B, name='styles_B')
                self.deform_BA = tf.identity(deform_BA, name='deform_BA')
                self.ldmark_pred = tf.identity(ldmark_pred, name='ldmark_pred')
                self.ldmark_diff = tf.identity(ldmark_diff, name='ldmark_diff')


                # Build discriminator for real images
                patch_logits_A, logits_A = self.discriminator(self.images_A)
                patch_logits_B, logits_B = self.discriminator(self.images_B)
                patch_logits_BA, logits_BA = self.discriminator(deform_BA)                          

                # Show images in TensorBoard
                image_grid_A = tf.stack([self.images_A, render_AA], axis=1)[:1]
                image_grid_B = tf.stack([self.images_B, render_BB], axis=1)[:1]
                image_grid_BA = tf.stack([self.images_B, deform_BA], axis=1)[:1]
                image_grid = tf.concat([image_grid_A, image_grid_B, image_grid_BA], axis=0)
                image_grid = tf.reshape(image_grid, [-1] + list(self.images_A.shape[1:]))
                image_grid = self.image_grid(image_grid, (3,2))
                tf.summary.image('image_grid', image_grid)


                # Build all losses
                self.watch_list = {}
                loss_list_G  = []
                loss_list_D  = []
               
                # Adversarial loss for deform_BA
                loss_D, loss_G = self.cls_adv_loss(logits_A, logits_B, logits_BA,
                    self.labels_A, self.labels_B, self.labels_B, num_classes)
                loss_D, loss_G = config.coef_adv*loss_D, config.coef_adv*loss_G

                self.watch_list['LDg'] = loss_D
                self.watch_list['LGg'] = loss_G
                loss_list_D.append(loss_D)
                loss_list_G.append(loss_G)

                # Patch Adversarial loss for deform_BA
                loss_D, loss_G = self.patch_adv_loss(patch_logits_A, patch_logits_B, patch_logits_BA)
                loss_D, loss_G = config.coef_patch_adv*loss_D, config.coef_patch_adv*loss_G

                self.watch_list['LDp'] = loss_D
                self.watch_list['LGp'] = loss_G
                loss_list_D.append(loss_D)
                loss_list_G.append(loss_G)

                # Identity Mapping (Reconstruction) loss
                loss_idt_A = tf.reduce_mean(tf.abs(render_AA - self.images_A), name='idt_loss_A')
                loss_idt_A = config.coef_idt * loss_idt_A

                loss_idt_B = tf.reduce_mean(tf.abs(render_BB - self.images_B), name='idt_loss_B')
                loss_idt_B = config.coef_idt * loss_idt_B

                self.watch_list['idtA'] = loss_idt_A
                self.watch_list['idtB'] = loss_idt_B
                loss_list_G.append(loss_idt_A+loss_idt_B)


                # Collect all losses
                reg_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES), name='reg_loss')
                self.watch_list['reg_loss'] = reg_loss
                loss_list_G.append(reg_loss)
                loss_list_D.append(reg_loss)


                loss_G = tf.add_n(loss_list_G, name='loss_G')
                grads_G = tf.gradients(loss_G, self.G_vars)

                loss_D = tf.add_n(loss_list_D, name='loss_D')
                grads_D = tf.gradients(loss_D, self.D_vars)

                # Training Operators
                train_ops = []

                opt_G = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5, beta2=0.9)
                opt_D = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5, beta2=0.9)
                apply_G_gradient_op = opt_G.apply_gradients(list(zip(grads_G, self.G_vars)))
                apply_D_gradient_op = opt_D.apply_gradients(list(zip(grads_D, self.D_vars)))

                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                train_ops.extend([apply_G_gradient_op, apply_D_gradient_op] + update_ops)

                train_ops.append(tf.assign_add(self.global_step, 1))
                self.train_op = tf.group(*train_ops)

                # Collect TF summary
                for k,v in self.watch_list.items():
                    tf.summary.scalar('losses/' + k, v)
                tf.summary.scalar('learning_rate', self.learning_rate)
                self.summary_op = tf.summary.merge_all()

                # Initialize variables
                self.sess.run(tf.local_variables_initializer())
                self.sess.run(tf.global_variables_initializer())
                self.saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=99)
 

Based on the forward pass defined here, we can draw the following forward-propagation diagram:

[Figure: forward-propagation diagram]

Combining this with the network architecture figure in the paper, we can draw the following, more detailed diagram:

[Figure: detailed data-flow diagram, combined with the paper's architecture]

3. WarpGAN\models\default.py

This file defines the detailed structures of the three networks (encoder, decoder, discriminator) that warpgan.py calls.
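
Before diving in, the call signatures that warpgan.py relies on can be summarized as follows (inferred from the calls in initialize above; bodies omitted):

# Interfaces of the three sub-networks as used in warpgan.py (inferred):
#   encoder(images)                                      -> (encoded, styles)
#   decoder(encoded, scales, styles)                     -> (deformed, rendered, ldmark_pred, ldmark_diff)
#   decoder(encoded, scales, styles, texture_only=True)  -> rendered
#   discriminator(images)                                -> (patch_logits, logits)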

To find out how the landmarks are learned, I mainly read WarpController, the sub-network of the decoder that generates the facial control points:

                    with tf.variable_scope('WarpController'):

                        print('-- WarpController')

                        net = encoded
                        warp_input = tf.identity(images_rendered, name='warp_input')

                        net = slim.flatten(net)

                        net = slim.fully_connected(net, 128, scope='fc1')
                        print('module fc1 shape:', [dim.value for dim in net.shape])

                        num_ldmark = 16

                        # Predict the control points
                        ldmark_mean = (np.random.normal(0,50, (num_ldmark,2)) + np.array([[0.5*h,0.5*w]])).flatten()
                        ldmark_mean = tf.Variable(ldmark_mean.astype(np.float32), name='ldmark_mean')
                        print('ldmark_mean shape:', [dim.value for dim in ldmark_mean.shape])

                        ldmark_pred = slim.fully_connected(net, num_ldmark*2, 
                            weights_initializer=tf.truncated_normal_initializer(stddev=1.0),
                            normalizer_fn=None, activation_fn=None, biases_initializer=None, scope='fc_ldmark')
                        ldmark_pred = ldmark_pred + ldmark_mean
                        print('ldmark_pred shape:', [dim.value for dim in ldmark_pred.shape])
                        ldmark_pred = tf.identity(ldmark_pred, name='ldmark_pred')
                 

                        # Predict the displacements
                        ldmark_diff = slim.fully_connected(net, num_ldmark*2, 
                            normalizer_fn=None,  activation_fn=None, scope='fc_diff')
                        print('ldmark_diff shape:', [dim.value for dim in ldmark_diff.shape])
                        ldmark_diff = tf.identity(ldmark_diff, name='ldmark_diff')
                        ldmark_diff = tf.identity(tf.reshape(scales,[-1,1]) * ldmark_diff, name='ldmark_diff_scaled')



                        src_pts = tf.reshape(ldmark_pred, [-1, num_ldmark ,2])
                        dst_pts = tf.reshape(ldmark_pred + ldmark_diff, [-1, num_ldmark, 2])

                        diff_norm = tf.reduce_mean(tf.norm(src_pts-dst_pts, axis=[1,2]))
                        # tf.summary.scalar('diff_norm', diff_norm)
                        # tf.summary.scalar('mark', ldmark_pred[0,0])

                        images_transformed, dense_flow = sparse_image_warp(warp_input, src_pts, dst_pts,
                                regularization_weight = 1e-6, num_boundary_points=0)
                        dense_flow = tf.identity(dense_flow, name='dense_flow')

My understanding is as follows:

1) Landmarks: ldmark_pred = fc_ldmark(net) + ldmark_mean. ldmark_mean is a tf.Variable initialized once with random points around the image center (0.5h, 0.5w) and then updated by training like any other parameter; the fully connected layer fc_ldmark, whose input is the encoded image, predicts a residual that is added to this mean to give the landmark prediction.

2) Landmark displacements: ldmark_diff is predicted from the encoded image by another fully connected layer (fc_diff) and then multiplied by the per-sample scales.
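
Below is a minimal standalone sketch of this predict-landmarks-then-warp pattern, with shapes reduced and the landmark network collapsed into single fully connected layers; tf.contrib.image.sparse_image_warp in TensorFlow 1.x has the same call signature as the sparse_image_warp used above:

import numpy as np
import tensorflow as tf  # TensorFlow 1.x

h, w, num_ldmark = 64, 64, 16
images = tf.placeholder(tf.float32, [None, h, w, 3], name='images')
features = tf.placeholder(tf.float32, [None, 128], name='features')  # stand-in for the encoded image

# Landmarks = learned mean (initialized once around the image center) + network residual
ldmark_mean = tf.Variable(
    (np.random.normal(0, 5, (num_ldmark, 2)) + np.array([[0.5*h, 0.5*w]])).flatten().astype(np.float32),
    name='ldmark_mean')
ldmark_pred = tf.layers.dense(features, num_ldmark*2, use_bias=False) + ldmark_mean

# Per-landmark displacements, also predicted from the encoded image
ldmark_diff = tf.layers.dense(features, num_ldmark*2)

src_pts = tf.reshape(ldmark_pred, [-1, num_ldmark, 2])
dst_pts = tf.reshape(ldmark_pred + ldmark_diff, [-1, num_ldmark, 2])

# Warp the image so that pixels at src_pts move onto dst_pts
warped, dense_flow = tf.contrib.image.sparse_image_warp(
    images, src_pts, dst_pts, regularization_weight=1e-6, num_boundary_points=0)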

The paper explains the relationship between deformation and style transfer as follows (translated):

Unlike other visual style-transfer tasks, converting photos to caricatures involves both texture differences and geometric transformations. Texture exaggerates local fine-grained features, such as the depth of wrinkles, while geometric deformation allows global features, such as face shape, to be exaggerated. Conventional style-transfer networks aim to reconstruct an image from feature space with a decoder network. Because the decoder is a group of nonlinear local filters, it is inherently limited in spatial variation: when there is a large geometric discrepancy between the input and output domains, it suffers from poor image quality and severe information loss. Warping-based methods, on the other hand, are limited by their inability to change content and fine-grained details. Therefore both the style-transfer module and the deformation module are essential parts of the learning framework.

As shown in the figure below, without either module the generator cannot close the gap between photos and caricatures, the balance of the adversarial game between generator and discriminator is broken, and the results collapse.

[Figure: collapsed results when either the style-transfer or the warping module is removed]

So in this work, style transfer and deformation must be carried out jointly; it is not enough to just locate landmarks and change the shape.

I recently came across another paper, CariGANs, which also deforms the face based on landmarks; it seems worth reading next.