BEGAN: Boundary Equilibrium Generative Adversarial Networks的理解
程序员文章站
2022-03-09 13:33:19
...
BEGAN: Boundary Equilibrium Generative Adversarial Networks的理解
这是一篇2017年5月上传到arXiv上的文章,作者是David Berthelot,来自Google。Boundary Equilibrium译作“边界均衡”,文章创新的地方主要有以下几个地方:
- 应用auto-encoder实现Discriminator
- Discriminator的Loss_D由输入原图(input_img)与Decoder恢复的输出图(recover_img)之间的逐点error构成
因而将产生两个Loss_D,分别为真图判别损失Loss_D_real,以及伪图判别损失Loss_D_fake。 - Loss_D可看成是随机的分布,由real_img所形成的Loss_D_Real分布与由Generator生成的假图(fake_img)所形成的Loss_D_Fake分布,出现了两个分布,用Wasserstein Distance(简称WD)来衡量这两个分布的距离。Discriminator的目标是尽量拉开这两个分布的距离,而Generator的目标是缩小这两个分布的距离——GAN的基本思想。
- 引入了一个均衡的概念来调节Discriminator训练时的两个目标的比重:目标1,是提高auto-encoder的重构能力,即auto-encoder恢复输入input_img的能力;目标2,提高D的分辨真伪的能力。该均衡控制量是可以变动的,就像是电路中的反馈环,构成了反馈比例控制(Proportional Control)迭代机制。
本文是以WD的出发点来解释和构造GAN的,以下是Wasserstein Distance的定义:
WD本来就是用来衡量两个分布的距离的,知乎上有一篇文章讲得很详细:https://www.zhihu.com/question/39872326?sort=created
在BEGAN中,和是两个分布,代表由real_img在Discriminator上生成的Loss_D,即Loss_D_real,而代表fake_img在Discriminator上生成的Loss_D,即Loss_D_fake。便是衡量这两个分布的距离。
(2)式右边是求1次范数均值的下确界。是服从的随机样本,同理,是服从的随机样本,它们的联合分布服从,此中有一个约束条件,即是联合分布服的边沿分布必须是和。的所有可能形式构成一个概率空间,因此是的一个元素。在中取最小值的那个联合分布是所求的目标分布,它的期望就是所求距离。
作为Discriminator希望此距离越大越好,但最优联合分布的形式是未知的,因而直接求十分困难,因而需要用可变下界来渐近之,通过Jensen不等式有:
其中和分别是和的均值。于是可得的下界,有:
将(1)代入有:
要尽量增加距离,只有两种情况:
选(b)较合理,因为当D训练好时,对于真图这边的Error,我们是希望误差越小越好的,即趋向0。因而,(4)变形为
。尽量提升下界,即求的最大值,但常见的ML后向传递计算搜索的是Loss的最小值,因而在求Loss_D时需要对(7)求反,如下:
当经过理想的训练过程后,D应该分辨不出真伪,即:,因而有:
但由于在训练过程中(9)式两边并不匹配,左边会比右边衰减得快,因为Generator的生成过程收敛速度较慢,因而在(8)式右端第二项中添加一个可变的系数对它进行调整,平衡减式两端数值,该系数就是所谓的均衡(Equilibrium)——。的调整是一个迭代过程,如同电路的反馈环路,的迭代关系如下:
(10)式中是一个超级参数,取值可以是0.001。(8)式经过均衡的调整,变为了(12)式,这样的目的是让(12)式的前后两项不要相差太大,起到一个制衡(Trade off)的作用。
以下是用pytorch实现的一次训练迭代过程:
# ---------------------
# Train Discriminator
# ---------------------
optimizer_D.zero_grad()
# Measure discriminator's ability to classify real from generated samples
d_real = discriminator(real_imgs)
d_fake = discriminator(gen_imgs.detach())
d_loss_real = torch.mean(torch.abs(d_real - real_imgs))
d_loss_fake = torch.mean(torch.abs(d_fake - gen_imgs.detach()))
d_loss = d_loss_real - k * d_loss_fake
d_loss.backward()
optimizer_D.step()
#----------------
# Update weights
#----------------
diff = torch.mean(gamma * d_loss_real - d_loss_fake)
# Update weight term for fake samples
k = k + lambda_k * diff.item()
k = min(max(k, 0), 1) # Constraint to interval [0, 1]
还有一篇BEGAN的翻译,可以以看看:https://blog.csdn.net/m0_37561765/article/details/77512692
本文的参考:
1、代码
2、文章