欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

GAN的Loss的比较研究(3)——Wasserstein Loss理解(1)

程序员文章站 2022-03-09 13:06:01
...

前两篇文章讨论了传统GAN的Loss,该Loss有些不足的地方,导致了GAN的训练十分困难,表现为:1、模式坍塌,即生成样本的多样性不足;2、不稳定,收敛不了。Martin Arjovsky在《Towards principled methods for training generative adversarial networks》、《Wasserstein GAN》文章中,对传统Loss造成训练困难的原因进行了讨论:因为真实样本的概率分布Pr与生成器生成的样本概率分布Pg的支撑集不同,又由于两者的流型(Manifold)的维度皆小于样本空间的维度,因而两者的流型基本上是不可能完全对齐的,因而即便有少量相交的点,它们在两个概率流型上的测度为0,可忽略,因而可以将两个概率的流型看成是可分离的,因而若是一个最优的判别器去判断则一定可以百分百将这两个流型分开,即无论我们的生成器如何努力皆获得不了分类误差的信息,这便是GAN训练困难的重要原因,有一篇博文(《令人拍案叫绝的Wasserstein GAN》)对上述两篇文章做了深入浅出的解释,总结一下是:
用KL Divergence和JS Divergence作为两个概率的差异的衡量,最关键的问题是若两个概率的支撑集不重叠,就无法让那个参数化的、可移动的概率分布慢慢地移动过来,以拟合目标分布。
于是文章提出一种新的Loss定义,即Wasserstein Distance,它可以作为两个概率分布的距离衡量指标,其定义如下:

W(Pr,Pg)=infγΠ(Pr,Pg)E(x,y)γ[xy](1)

W(Pr,Pg)是概率分布Pr,Pg的距离,它是两个在同一空间上(即维度相同)的随机变量x,y之差的范数均值的下确界。假设PrPr都是Rd上的概率分布,则两个空间相乘,构成Rd×d概率空间,在此空间中,找出所有在Rd边界分布为Pr和在另外一边Rd边界分布为Pr的所有分布,它们构成一个集合,即Π(Pr,Pg)。在此集合中,我们任意抽取一个元素,即γ,它是一个在Rd×d上的分布,由它抽样出的样本,皆在Rd×d上,这些样本点分别投影在两个互补的Rd边界上,就是x和y,取其差的范数求平均,即可得到一个值W。最后,在不同γ分布下,求出W的下界,即是Wasserstein Distance。
该距离的优点是:
1、若两个概率分布完全重合时,W(Pr,Pg)=0
2、是对称的,KL Divergence不对称
3、即使两个分布的支撑不相交,亦可以衡量,并在满足一定条件下可微,具备了后向传输的能力。
但式(1)不能直接计算,因为我们无法得到所有Π(Pr,Pg),因而需要另辟蹊径,《Wasserstein GAN》(完整的证明在《Optimal Transport: Old and New》中,“最优转换”有空真应该好好学习一下。)指出(1)式可以转换成:
W(Pr,Pg)=supfL1ExPr[f(x)]ExPg[f(x)](2)

f(x)是函数集fL1中的一个函数。fL1表示满足1-Lipschitz条件的函数集。
(Lipschitz条件是一个比通常连续更强的光滑性条件。直觉上,Lipschitz连续函数限制了函数改变的速度,符合利Lipschitz条件的函数的斜率,必小于一个称为Lipschitz常数的实数)
满足K-Lipschitz条件的函数集:
1、xRdf(x)R,是实值函数
2、f(x)在定义域内都有
|f(x1)f(x2)|K×|x1x2|(3)

(3)式中”||”表示度量空间的距离,此处可以是一次、二次范数。(2)式中要求的是1-Lipschitz条件,也可以用K-Lipschitz条件代替,此时有:
KW(Pr,Pg)=supfLKExPr[f(x)]ExPg[f(x)](4)

(2)式与(4)式只有系数的不同,因此在使用时,可以选定一个K值。(4)式如何理解呢?见(4)式右边其中一项——ExPr[f(x)]

  1. x是一个来自于随机分布Pr的样本
  2. f(x)表示样本x的一个函数变换,该函数满足Lipschitz条件
  3. 取多个样本函数的平均值,即:1Mi=1Mf(xi),xiPr

(4)式要求得到上确界,上确界的具体函数形式我们不知道,但我们可以用神经网络来逼近它,这是判别器(Discriminator)的作用,也就是Discriminator网络充当了f(x)的角色,因此(4)等价于:

KW(Pr,Pg)=maxθ(ExPr[f(x)]ExPg[f(x)])=maxθ(1Mi=1,xiPrMfθ(xi)1Nj=1,xjPgNfθ(xj))(5)

(5)式表示两个分布(真实图像分布与生成图像分布)的距离,在确定K值后,就可以用它来定义判别器(Discriminator)的Loss_D,(损失函数要求最小值,(5)式需要求反)即有:
Loss_D=minθ(1Mi=1,xiPrMfθ(xi)+1Nj=1,xjPgNfθ(xj))(6)θ=argminθ(Loss_D)(7)

Discriminator希望此距离越大越好,通过调节判别器参数(θ)达到目标;生成器(Generator)则不然,它希望此距离越小越好,(5)式定义的距离对于生成器有如下关系:
KW(Pr,Pg)=maxθ(1Mi=1,xiPrMfθ(xi)1Nj=1,zjPzNfθ(gω(zj)))(8)

(5)式中的随机分布Pg是生成图像的分布,它是由随机分布zPz经过生成器映(gω(z)))而得到的一个分布,生成器只能调节其自身参数ω,因而其它与之无关的项均不用考虑,由此定义生成器的损失函数是Loss_G为:
Loss_G=1Nj=1,zjPzNfθ(gω(zj))(9)ω=argminω(Loss_G)(10)

以下是一段摘自https://github.com/Streamrock/PyTorch-GAN/blob/master/implementations/wgan/wgan.py的代码:

for epoch in range(opt.n_epochs):

    for i, (imgs, _) in enumerate(dataloader):

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images
        fake_imgs = generator(z).detach()
        # Adversarial loss
        loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))

        loss_D.backward()
        optimizer_D.step()

        # Clip weights of discriminator
        for p in discriminator.parameters():
            p.data.clamp_(-opt.clip_value, opt.clip_value)

        # Train the generator every n_critic iterations
        if i % opt.n_critic == 0:

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Generate a batch of images
            gen_imgs = generator(z)
            # Adversarial loss
            loss_G = -torch.mean(discriminator(gen_imgs))

            loss_G.backward()
            optimizer_G.step()

剩下的就是要保证fθ(x)满足K-Lipschitz条件,《Wasserstein GAN》做了一个简单地处理,因为判别器是由神经网络构成的,因此对每层的线性算子中参数做了一个夹逼,限制其取值范围,就可以实现。如上面代码的这个部分:

# Clip weights of discriminator
for p in discriminator.parameters():
    p.data.clamp_(-opt.clip_value, opt.clip_value)

此取值范围既不能太大,又不能太小,到底取多少合适呢?《Wasserstein GAN》没做讨论,留给后来有心人。

相关标签: GAN