20.生成对抗网络
文章目录
本课程来自深度之眼deepshare.net,部分截图来自课程视频。
Chloe H. 提供:GAN训练的tip,
https://chloes-dl.com/2019/11/19/tricks-and-tips-for-training-a-gan/
生成对抗网络(GAN)是什么?
GAN(Generative Adversarial Nets):生成对抗网络——一种可以生成特定分布数据的模型
文献:Generative Adversarial Nets. Ian Goodfellow. 2014
堃哥说:
Adversarial training is the coolest thing since sliced bread.I’ ve listed a bunch of relevant papers in a previous answer. Expect more impressive results with this technique in the coming years.
下面是用GAN生成的64张人脸。有些比较畸形。。。
inference
1.输入:
用高斯分布随机的采样一些噪声
2.构建模型,加载参数:
这里注意,模型inference时只用到了Generator(生成器),不需要Discriminator(判别器)
3.inference,把输入放到Generator中就可以生成虚假数据。
fake_data=net_g(fixed_noise). detach(). cpu()
GAN网络结构
以下三个图片(每个图片都是讲的GAN结构)分别来自:
《Recent Progress on Generative Adversarial Networks(GANs):A Survey》
《How Generative Adversarial Networks and Its Variants Work:An Overview of GAN》
《Generative Adversarial Networks_A Survey and Taxonomy》
G代表生成器,D代表判别器,z是输入向量,输入向量通过生成器后,得到一个生成的结果,如果是人脸图片生成,这个的G(z)就是一个图片tensor,然后结合训练数据x,通过判别器给出图片是真还是假(D是二分类网络。)
如何训练GAN?
训练目的
1.对于D:对真样本输出高概率
2.对于G:输出使D会给出高概率的数据
GAN的训练模式与监督学习训练模式不一样的地方:需要注意的是,监督学习中损失函数的目标是让模型的输出值尽量的逼近真实值;在GAN中输出值不是逼近真实值,而是使得输出值的分布接近真实值的分布。
下面看具体步骤,二次元警告。。。。(李宏毅的笔记里面也有相应内容)
step1:训练D
输入:真实数据加G生成的假数据
输出:二分类概率
上图中是更新一次D的过程
step2:训练G
输入:随机噪声z
输出:分类概率——D(G(z))
上图中输出如果是0.13,那么差异为1-0.13,我们的目标是D输出的目标概率是越高越好,最好就是1,这里只有0.13,说明还不够好,需要继续训练G。
然后回到step1继续循环,知道满足收敛条件。
下面对GAN论文中对算法的文字进行一些解释
1.整个算法是一个大的for循环,可以根据图中的最长的横线分为两个部分,上面部分是训练判别器的,下面部分是训练生成器的。
2.先看训练判别器部分,这个部分是有一个for循环包围着的(1号箭头),这个是早期GAN的设置,意思是先要通过几次迭代训练几次判别器,后来经过实践证明,这里实际上是不需要的,只用训练一次就ok了,所以这里的循环次数k我们可以设置为1。
3.在训练判别器时,先分别从噪声和真实数据中进行采样,然后计算损失函数,注意在更新损失函数,用的是ascending梯度,原因分析:损失函数有两项,第一项是真实数据,我们希望这个的概率是越大越好(2号箭头),第二项是虚假数据,这个概率我们希望是越小越好,但是又有一个1-这一项,所以整个第二项也是越大越好(3号箭头),整体更新是变大的趋势,所以用的随机梯度上升法。
4.训练生成器部分,先从噪声中采样(这里的采样数据可以和上面部分的相同,也可以不同,感觉可以这样是因为我们在乎的是数据的分布,而不是具体的数据)
5.同理,生成器希望这个损失函数的值通过判别器判别后是真实数据(生成器要骗过判别器),所以这项是越大越好(4号箭头),则整体是越小越好(5号箭头)。因此在生成器部分用的是随机梯度下降法
6.可以看出,由于是对抗,在设计损失函数的时候,一个是梯度上升,一个是梯度下降;另外两个损失函数有一项是一样的,看图中绿线部分。
训练DCGAN实现人脸生成
《Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks》
Generator:卷积结构的模型
输入是100维的随机噪声,然后通过transpose的卷积生成一个64643的rgb图片
注意:输入在pytorch中用tensor表示为:(1,100,1,1)
第一个1 是batch,后面两个1是高和宽。
Discriminator:卷积结构的模型
老师很懒,直接把上面的结构旋转180度,输入是64643的rgb图像,不过输出是二分类。
DCGAN实现人脸生成
数据:CelebA人脸数据。
数据项目:http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
不是用的原项目的人脸,而是用矫正过的。
22万人脸矫正图:
https://pan.baidu.com/s/1JDrl82vTjgFsmKQ0SPNtzA 密码:41g7 失效
矫正前:
人脸所在位置以及比例都不确定
矫正后,是通过五个人脸关键点(中心化)以及人脸所占比例进行了矫正:
构建transform的时候,需要把数据尺度变换到-1~ 1区间,因为随机采用的生成器的值也是这个区间,所以这里不追求0均值的分布,而是追求区间一致。
生成器的超参数初始化代码:
class Generator(nn.Module):
def __init__(self, nz=100, ngf=128, nc=3):#输入的维度是100,特征图数量是128,输出是3d张量
super(Generator, self).__init__()
self.main = nn.Sequential(
# input is Z, going into a convolution
nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),#ngf * 8=1024,对应到结构图中的一个卷积模块
nn.BatchNorm2d(ngf * 8),
nn.ReLU(True),
# state size. (ngf*8) x 4 x 4
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True),
# state size. (ngf*4) x 8 x 8
nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True),
# state size. (ngf*2) x 16 x 16
nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(True),
# state size. (ngf) x 32 x 32
nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
nn.Tanh()
# state size. (nc) x 64 x 64
)
二分类的loss
# step3: loss
criterion=nn. BCELoss()
判别器和生成器的训练迭代过程的代码
############################
# (1) Update D network
###########################
net_d.zero_grad()
# create training data
real_img = data.to(device)
b_size = real_img.size(0)
real_label = torch.full((b_size,), real_idx, device=device)#real_idx是真实图片的lable
noise = torch.randn(b_size, nz, 1, 1, device=device)#输入是4d张量,第一个维度是batchsize,nz是100维
fake_img = net_g(noise)
fake_label = torch.full((b_size,), fake_idx, device=device)#fake_idx是假图片lable
# train D with real img
out_d_real = net_d(real_img)
loss_d_real = criterion(out_d_real.view(-1), real_label)
# train D with fake img
out_d_fake = net_d(fake_img.detach())
loss_d_fake = criterion(out_d_fake.view(-1), fake_label)
# backward
loss_d_real.backward()
loss_d_fake.backward()
loss_d = loss_d_real + loss_d_fake
# Update D
optimizerD.step()
# record probability
d_x = out_d_real.mean().item() # D(x)
d_g_z1 = out_d_fake.mean().item() # D(G(z1))
#以上完成一次判别器的更新
############################
# (2) Update G network
###########################
net_g.zero_grad()
label_for_train_g = real_label # 1
out_d_fake_2 = net_d(fake_img)
loss_g = criterion(out_d_fake_2.view(-1), label_for_train_g)
loss_g.backward()#只更新生成器,不改变判别器
optimizerG.step()#
# record probability
d_g_z2 = out_d_fake_2.mean().item() # D(G(z2))
# Output training stats
if i % 10 == 0:
print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
% (epoch, num_epochs, i, len(train_loader),
loss_d.item(), loss_g.item(), d_x, d_g_z1, d_g_z2))
# Save Losses for plotting later
G_losses.append(loss_g.item())
D_losses.append(loss_d.item())
训练过程中的注意事项:
1.特征图数量ngf是原始模型128,如果改为64,效果会变差,但是训练速度快一些
2.标签值的平滑处理,这里用的是1和0,可以平滑为:0.9和0.1
GAN的应用
https://medium.com/@jonathan_hui/gan-some-cool-applications-of-gans-4c9ecca35900(失效)
GAN的应用:《CycleGAN》
GAN的应用:《PixelDTGAN》
GAN的应用:《SRGAN》
GAN的应用:
Progressive GAN
GAN的应用:
《StackGAN》根据文本生成图片
GAN的应用:
《Context Encoders》
GAN的应用:
《Pix2Pix》
GAN的应用:
《ICGAN》
GAN推荐github:https://github.com/nightrome/really-awesome-gan