Decoupled Learning for Conditional Adversarial Networks
程序员文章站
2023-12-24 21:13:58
...
文章提出里在已有的ED+GAN的基础上,添见一个生成网络,即ED//GAN,网络结构如下,
上图中左边为传统的GAN网络,Enc+Dec相当于生成网络,D为判别网络,构造GAN损失函数,以及生成图片与输入的重构误差(L1损失函数,这种网络结构我们熟悉的有pix2pix,cyclegan.
上图中右边为本文提出的网络结构,即在ED+GAN的基础上,添加一个生成网络,相当于有两个生成网络.两个生成网络的目的是,两个生成网络分别可以学习图像的不同特征,例如,一个生成网络用于生成低频特征,另一个用于生成高频特征,
最后的结果为两个生成网络的和.判别网络用于判断最后生成的图片,以及目标输入图片的真假.损失函数同样为GAN loss,以及生成网络Enc+Dec网络生成图片与输入图片的重构误差,也就是希望生成图片的低频特征与输入图片尽量相似.
作者提供了github代码:https://github.com/ZZUTK/Decoupled-Learning-Conditional-GAN
代码包括在pix2pix,CAAE模型上的EN//GAN的结构,下面我以pix2pix的EN//GAN模型为例,分析代码.
代码中,Enc+Dec,生成网络generator,生成网络G,generator_p,的结构都与pix2pix的生成网络结构相同,
将输入图片输入两个网络,
self.const_B = self.generator(self.real_A)
self.res_B = self.generator_p(self.real_A)
将两者相加,得到最后的生成图片,
self.fake_B = self.const_B + self.res_B
与pix2pix一样,将生成图片,目标图片分别与输入图片串联,输入判别网络,判别网络结构也与pix2pix判别网络相同,
self.real_AB = tf.concat( [self.real_A, self.real_B],3)
self.fake_AB = tf.concat([self.real_A, self.fake_B],3)
self.D, self.D_logits = self.discriminator(self.real_AB, reuse=False)
self.D_, self.D_logits_ = self.discriminator(self.fake_AB, reuse=True)
判别网络损失函数为GAN loss,
self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits, labels=tf.ones_like(self.D)))
self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.zeros_like(self.D_)))
self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.ones_like(self.D_)))
生成网络损失函数,重构误差函数为Enc+Dec输出与输入图片的重构误差,以及GAN loss,
self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.ones_like(self.D_)))
self.const_loss = tf.reduce_mean(tf.abs(self.real_B - self.const_B))
生成效果对比,