欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

BEGAN: Boundary Equilibrium Generative Adversarial Networks的理解

程序员文章站 2022-03-09 13:33:19
...

BEGAN: Boundary Equilibrium Generative Adversarial Networks的理解

这是一篇2017年5月上传到arXiv上的文章,作者是David Berthelot,来自Google。Boundary Equilibrium译作“边界均衡”,文章创新的地方主要有以下几个地方:

  • 应用auto-encoder实现Discriminator
  • Discriminator的Loss_D由输入原图(input_img)与Decoder恢复的输出图(recover_img)之间的逐点error构成
    L(v)=|vD(v)|(1)

    因而将产生两个Loss_D,分别为真图判别损失Loss_D_real,以及伪图判别损失Loss_D_fake。
  • Loss_D可看成是随机的分布,由real_img所形成的Loss_D_Real分布与由Generator生成的假图(fake_img)所形成的Loss_D_Fake分布,出现了两个分布,用Wasserstein Distance(简称WD)来衡量这两个分布的距离。Discriminator的目标是尽量拉开这两个分布的距离,而Generator的目标是缩小这两个分布的距离——GAN的基本思想。
  • 引入了一个均衡的概念来调节Discriminator训练时的两个目标的比重:目标1,是提高auto-encoder的重构能力,即auto-encoder恢复输入input_img的能力;目标2,提高D的分辨真伪的能力。该均衡控制量是可以变动的,就像是电路中的反馈环,构成了反馈比例控制(Proportional Control)迭代机制。

本文是以WD的出发点来解释和构造GAN的,以下是Wasserstein Distance的定义:

W(u1,u2)=infγΓ(u1,u2)E(x1,x2)γ[|x1x2|](2)

WD本来就是用来衡量两个分布的距离的,知乎上有一篇文章讲得很详细:https://www.zhihu.com/question/39872326?sort=created
在BEGAN中,u1u2是两个分布,u1代表由real_img在Discriminator上生成的Loss_D,即Loss_D_real,而u2代表fake_img在Discriminator上生成的Loss_D,即Loss_D_fake。W(u1,u2)便是衡量这两个分布的距离。
(2)式右边是求1次范数均值的下确界。x1是服从u1的随机样本,同理,x2是服从u2的随机样本,它们的联合分布服从γ,此中有一个约束条件,即是联合分布服γ的边沿分布必须是u1u2γ的所有可能形式构成一个概率空间Γ(u1,u2),因此γΓ的一个元素。在Γ(u1,u2)中取最小值的那个联合分布γ是所求的目标分布,它的期望E(x1,x2)γ[|x1x2|]就是所求距离。
作为Discriminator希望此距离越大越好,但最优联合分布γ的形式是未知的,因而直接求十分困难,因而需要用可变下界来渐近之,通过Jensen不等式有:
infE[|x1x2|]inf|E[x1x2]|=|m1m2|(3)

其中m1m2分别是x1x2的均值。于是可得W(u1,u2)的下界,有:
W(u1,u2)|m1m2|(4)

将(1)代入有:
{m1=Evu1|vD(v)|m2=EG(zG)|G(zG)D(G(zG))|(5)

要尽量增加距离,只有两种情况:
{m1m20(a)or{m10m2(b)(6)

选(b)较合理,因为当D训练好时,对于真图这边的Error,我们是希望误差越小越好的,即m1趋向0。因而,(4)变形为
W(u1,u2)m2m1(7)
。尽量提升下界,即求m2m1的最大值,但常见的ML后向传递计算搜索的是Loss的最小值,因而在求Loss_D时需要对(7)求反,如下:
LD=m1m2=E(L(x))E(L(G(zD)))(8)

当经过理想的训练过程后,D应该分辨不出真伪,即:W(u1,u2)0,因而有:
E(L(x))=E(L(G(zD)))(9)

但由于在训练过程中(9)式两边并不匹配,左边会比右边衰减得快,因为Generator的生成过程收敛速度较慢,因而在(8)式右端第二项中添加一个可变的系数对它进行调整,平衡减式两端数值,该系数就是所谓的均衡(Equilibrium)——ktkt的调整是一个迭代过程,如同电路的反馈环路,kt的迭代关系如下:
kt+1=kt+λk(γL(x)L(G(zG)))(10)γ=E(L(G(zG)))E(L(x))(11)LD=E(L(x))ktE(L(G(zD)))(12)kt is clamped to [0,1]

(10)式中λk是一个超级参数,取值可以是0.001。(8)式经过均衡的调整,变为了(12)式,这样的目的是让(12)式的前后两项不要相差太大,起到一个制衡(Trade off)的作用。
以下是用pytorch实现的一次训练迭代过程:

# ---------------------
#  Train Discriminator
# ---------------------

optimizer_D.zero_grad()

# Measure discriminator's ability to classify real from generated samples
d_real = discriminator(real_imgs)
d_fake = discriminator(gen_imgs.detach())

d_loss_real = torch.mean(torch.abs(d_real - real_imgs))
d_loss_fake = torch.mean(torch.abs(d_fake - gen_imgs.detach()))
d_loss = d_loss_real - k * d_loss_fake

d_loss.backward()
optimizer_D.step()

#----------------
# Update weights
#----------------

diff = torch.mean(gamma * d_loss_real - d_loss_fake)

# Update weight term for fake samples
k = k + lambda_k * diff.item()
k = min(max(k, 0), 1) # Constraint to interval [0, 1]

还有一篇BEGAN的翻译,可以以看看:https://blog.csdn.net/m0_37561765/article/details/77512692
本文的参考:
1、代码
2、文章


相关标签: GAN