欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Generative Adversarial Networks

程序员文章站 2023-12-24 23:26:15
...

Minimax objective function

样本来源共有两部分:xPdatazP(z)

两个网络:判别器Dθd,生成器Gθg

首先从判别器的视角来看,分别考虑两部分数据:

对于第一部分的real样本xPdata,只和判别器Dθd打交道,判别器对于real样本预测为real的概率为Dθd(x),最大化这个概率(取对数),有

maxθdExPdata logDθd(x)

对于第二部分样本zP(z),送入生成器后得到fake样本Gθg(z)

因为在第一部分样本中,我们已经有maxθd,所以我们需要为判别器再构造出一个概率最大化的式子,于是想到最大化判别器预测fake样本为fake的概率,即1Dθd(Gθg(z)),最大化这个概率(取对数),有

maxθdEzP(z)log(1Dθd(Gθg(z)))

最终,将这两个maxθd整合起来,可得判别器的优化目标
maxθd[ExPdata logDθd(x)+EzP(z)log(1Dθd(Gθg(z)))]

然后从生成器的视角来看,生成器与判别器的目标恰好相反,于是将maxθd改为minθg,并且注意到ExPdata logDθd(x)这一项与生成器无关,可以去掉,最终可得

minθgEzP(z)log(1Dθd(Gθg(z)))

综上所述,整合两部分样本后,GAN的训练目标为
minθgmaxθd[ExPdata logDθd(x)+EzP(z)log(1Dθd(Gθg(z)))]

Training

在训练时需要把判别器和生成器的训练目标拆开(好不容易整合起来了又要分开?!)

训练D的目标

maxθd[ExPdata logDθd(x)+EzP(z)log(1Dθd(Gθg(z)))]

添加负号,将max变为min,变为一种loss,我们希望该loss越小越好(省略期望符号E

minθd[logDθd(x)log(1Dθd(Gθg(z)))]

D_loss = tf.reduce_mean( -tf.log( D_real ) - tf.log( 1 - D_fake ) )

我们稍微使用一点技巧,就可以转化为交叉熵

minθd[1_logDθd(x)(10_)log(1Dθd(Gθg(z)))]

上式可以理解为对real样本Dθd(x)赋予标签1_,对fake样本Dθd(Gθg(z))赋予标签0_,因此上式等价于

minθd[LBCE(1,Dθd(x))+LBCE(0,Dθd(Gθg(z)))]

D_loss_real = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits( logits=D_logit_real, labels=label_one ) )
D_loss_fake = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits( logits=D_logit_fake, labels=label_zero ) )
D_loss = D_loss_real + D_loss_fake

训练G的目标

minθgEzP(z)log(1Dθd(Gθg(z)))

为了避免“flat gradient”,我们将“最小化判别器预测fake样本为fake的概率”转换为“最大化判别器预测fake样本为real的概率”

maxθgEzP(z)logDθd(Gθg(z))

添加负号,将max变为min,变为一种loss,我们希望该loss越小越好(省略期望符号E

minθglogDθd(Gθg(z))

G_loss = tf.reduce_mean( -tf.log( D_fake ) )

同样地,稍微使用一点技巧,就可以转化为交叉熵

minθg1_logDθd(Gθg(z))

上式可以理解为对fake样本Dθd(Gθg(z))赋予标签1_(原本fake样本应该赋予标签0_,但为了fool判别器,赋予相反的标签1_),因此上式等价于

minθdLBCE(1,Dθd(Gθg(z)))

G_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits( logits=D_logit_fake, labels=label_one ) )

Question:这是否就是How to Train a GAN? Tips and tricks to make GANs work中所说的“Flip labels when training generator: real = fake, fake = real”?
目前找到的Answer:Trick 2 explanation #15,从vijayvee的回答中,个人认为Answer是Yes

关于“flat gradient”的解释

对于fake样本Gθg(z),可以分为如下两类

Dθd(Gθg(z))落在区间[0,1]中靠近0的一侧是那些生成质量差的fake样本

Dθd(Gθg(z))落在区间[0, 1]中靠近1的一侧是那些生成质量好的fake样本

Generative Adversarial Networks

Generative Adversarial Networks

上图来自cs231n,横轴为Dθd(Gθg(z)),纵轴为两个优化目标(蓝绿曲线)的值

对于蓝线,由于左边比较“flat”(斜率小),右边斜率大,因此那些生成质量好的fake样本的梯度会占据主导,这显然不是我们想要的

对于绿线,由于左边斜率大,右边斜率小,因此那些生成质量差的fake样本的梯度会占据主导,于是参数θg更新后,会着重提升那些生成质量差的样本的生成质量,这是我们想要的

上一篇:

下一篇: