Generative Adversarial Networks
Minimax objective function
样本来源共有两部分:,
两个网络:判别器,生成器
首先从判别器的视角来看,分别考虑两部分数据:
对于第一部分的real样本,只和判别器打交道,判别器对于real样本预测为real的概率为,最大化这个概率(取对数),有
对于第二部分样本,送入生成器后得到fake样本
因为在第一部分样本中,我们已经有,所以我们需要为判别器再构造出一个概率最大化的式子,于是想到最大化判别器预测fake样本为fake的概率,即,最大化这个概率(取对数),有
最终,将这两个整合起来,可得判别器的优化目标
然后从生成器的视角来看,生成器与判别器的目标恰好相反,于是将改为,并且注意到这一项与生成器无关,可以去掉,最终可得
综上所述,整合两部分样本后,GAN的训练目标为
Training
在训练时需要把判别器和生成器的训练目标拆开(好不容易整合起来了又要分开?!)
训练D的目标
添加负号,将变为,变为一种loss,我们希望该loss越小越好(省略期望符号)
D_loss = tf.reduce_mean( -tf.log( D_real ) - tf.log( 1 - D_fake ) )
我们稍微使用一点技巧,就可以转化为交叉熵
上式可以理解为对real样本赋予标签,对fake样本赋予标签,因此上式等价于
D_loss_real = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits( logits=D_logit_real, labels=label_one ) )
D_loss_fake = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits( logits=D_logit_fake, labels=label_zero ) )
D_loss = D_loss_real + D_loss_fake
训练G的目标
为了避免“flat gradient”,我们将“最小化判别器预测fake样本为fake的概率”转换为“最大化判别器预测fake样本为real的概率”
添加负号,将变为,变为一种loss,我们希望该loss越小越好(省略期望符号)
G_loss = tf.reduce_mean( -tf.log( D_fake ) )
同样地,稍微使用一点技巧,就可以转化为交叉熵
上式可以理解为对fake样本赋予标签(原本fake样本应该赋予标签,但为了fool判别器,赋予相反的标签),因此上式等价于
G_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits( logits=D_logit_fake, labels=label_one ) )
Question:这是否就是How to Train a GAN? Tips and tricks to make GANs work中所说的“Flip labels when training generator: real = fake, fake = real”?
目前找到的Answer:Trick 2 explanation #15,从vijayvee的回答中,个人认为Answer是Yes
关于“flat gradient”的解释
对于fake样本,可以分为如下两类
落在区间中靠近0的一侧是那些生成质量差的fake样本
落在区间[0, 1]中靠近1的一侧是那些生成质量好的fake样本
上图来自cs231n,横轴为,纵轴为两个优化目标(蓝绿曲线)的值
对于蓝线,由于左边比较“flat”(斜率小),右边斜率大,因此那些生成质量好的fake样本的梯度会占据主导,这显然不是我们想要的
对于绿线,由于左边斜率大,右边斜率小,因此那些生成质量差的fake样本的梯度会占据主导,于是参数更新后,会着重提升那些生成质量差的样本的生成质量,这是我们想要的
推荐阅读
-
Generative Adversarial Networks
-
SRAGN: Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
-
MelGAN: Generative Adversarial Networks for Conditional Waveform Synthesis笔记
-
An introduction to Generative Adversarial Networks (with code in TensorFlow)
-
Decoupled Learning for Conditional Adversarial Networks
-
Generative Adversarial Networks
-
Feature Pyramid Networks for Object Detection 阅读笔记
-
目标检测 Feature Pyramid Networks for Object Detection(FPN)论文笔记
-
详解Docker-compose networks 的例子
-
论文笔记:SlowFast Networks for Video Recognition