Generating CIFAR-10 Images with DCGAN
DCGAN (Deep Convolutional GAN):
Contributions of DCGAN
- Proposed a class of GANs built on convolutional neural networks, called DCGANs, whose training is stable in most settings.
- Compared with other unsupervised methods, the image features extracted by the DCGAN discriminator are more effective and better suited to image classification tasks.
- Through training, a DCGAN learns meaningful filters.
- The DCGAN generator preserves the "continuity" of the mapping from latent space to image space (a short interpolation sketch follows this list).
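This continuity is usually demonstrated by interpolating between two latent vectors and checking that the generated images change smoothly. A minimal sketch of my own, assuming a trained generator `netG` with latent size `nz = 100` as defined in the script further down (the output filename is illustrative):

```python
import torch
import torchvision.utils as vutils

# Linearly interpolate between two random latent codes and save the resulting
# images; a smooth visual transition illustrates the latent-space-to-image
# "continuity" described above. Assumes a trained `netG` from the script below.
nz = 100
z0 = torch.randn(1, nz, 1, 1)
z1 = torch.randn(1, nz, 1, 1)
steps = 8
zs = torch.cat([z0 + (z1 - z0) * t / (steps - 1) for t in range(steps)], dim=0)

with torch.no_grad():
    imgs = netG(zs.to(next(netG.parameters()).device))
vutils.save_image(imgs, 'output/latent_interpolation.png', normalize=True, nrow=steps)
```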
DCGAN model
In fact, DCGAN is the name for a whole class of GANs: any GAN that satisfies the following design guidelines (which read more like a collection of tricks) can be called a DCGAN.
- Use an all-convolutional network. Instead of spatial pooling, use strided convolutions, which lets the network learn its own spatial downsampling. Note: the generator, which needs to upsample, uses fractionally-strided (transposed) convolutions, while the discriminator generally uses integer-strided convolutions.
- Avoid fully connected layers on top of the convolutional layers. Fully connected layers add model stability but slow down convergence. Typically, the generator's input noise is drawn from a uniform distribution, and the discriminator's last convolutional layer is flattened and fed into a single-node sigmoid output.
- Apply batch normalization in every layer except the generator's output layer and the discriminator's input layer. Batch normalization keeps each unit's input at zero mean and unit variance, so the network still receives sufficiently strong gradients even with poor initialization.
- In the generator, use Tanh as the activation function of the output layer and ReLU in all other layers. In the discriminator, use leaky ReLU activations. (A minimal block-level sketch of these conventions follows this list.)
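To make these conventions concrete, here is a minimal sketch (my own illustration, with arbitrary channel counts) of one generator upsampling block and one discriminator downsampling block, showing the stride-based resampling, batch-norm placement, and activation choices described above:

```python
import torch
import torch.nn as nn

# Generator-style upsampling block: a fractionally-strided (transposed)
# convolution doubles the spatial resolution, followed by batch norm and ReLU.
gen_block = nn.Sequential(
    nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(True),
)

# Discriminator-style downsampling block: a strided convolution halves the
# resolution (no pooling), followed by batch norm and leaky ReLU.
disc_block = nn.Sequential(
    nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(256),
    nn.LeakyReLU(0.2, inplace=True),
)

print(gen_block(torch.randn(1, 256, 8, 8)).shape)     # torch.Size([1, 128, 16, 16])
print(disc_block(torch.randn(1, 128, 16, 16)).shape)  # torch.Size([1, 256, 8, 8])
```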
Folder layout and configuration:

![Project folder layout](https://img-blog.csdnimg.cn/20210321162042566.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L0RldmlsX0I=,size_16,color_FFFFFF,t_70)

- data: stores the dataset
- output: stores the generated images
- weights: stores the saved model parameters

gan.py:
```python
from __future__ import print_function
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

if __name__ == '__main__':
    cudnn.benchmark = True

    # Setting the manual seed to a constant gives reproducible output
    manualSeed = random.randint(1, 10000)
    print("Random Seed: ", manualSeed)
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)

    # Load the dataset
    dataset = dset.CIFAR10(root="./data", download=True,
                           transform=transforms.Compose([
                               transforms.Resize(64),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
    nc = 3

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=128,
                                             shuffle=True, num_workers=2)

    # Check whether to use the GPU or the CPU
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Number of GPUs available
    ngpu = 1
    # Size of the input noise vector
    nz = 100
    # Number of generator feature maps
    ngf = 64
    # Number of discriminator feature maps
    ndf = 64

    # Custom weight initialization applied to the networks
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    class Generator(nn.Module):
        def __init__(self, ngpu):
            super(Generator, self).__init__()
            self.ngpu = ngpu
            self.main = nn.Sequential(
                # input is Z, going into a convolution
                nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
                nn.BatchNorm2d(ngf * 8),
                nn.ReLU(True),
                # state size. (ngf*8) x 4 x 4
                nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf * 4),
                nn.ReLU(True),
                # state size. (ngf*4) x 8 x 8
                nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf * 2),
                nn.ReLU(True),
                # state size. (ngf*2) x 16 x 16
                nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf),
                nn.ReLU(True),
                # state size. (ngf) x 32 x 32
                nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
                nn.Tanh()
                # state size. (nc) x 64 x 64
            )

        def forward(self, input):
            if input.is_cuda and self.ngpu > 1:
                output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
            else:
                output = self.main(input)
            return output

    netG = Generator(ngpu).to(device)
    netG.apply(weights_init)
    # Load saved model parameters (optional, for resuming training)
    # netG.load_state_dict(torch.load('weights/netG_epoch_24.pth'))
    print(netG)

    class Discriminator(nn.Module):
        def __init__(self, ngpu):
            super(Discriminator, self).__init__()
            self.ngpu = ngpu
            self.main = nn.Sequential(
                # input is (nc) x 64 x 64
                nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf) x 32 x 32
                nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 2),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf*2) x 16 x 16
                nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 4),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf*4) x 8 x 8
                nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 8),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf*8) x 4 x 4
                nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
                nn.Sigmoid()
            )

        def forward(self, input):
            if input.is_cuda and self.ngpu > 1:
                output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
            else:
                output = self.main(input)
            return output.view(-1, 1).squeeze(1)

    netD = Discriminator(ngpu).to(device)
    netD.apply(weights_init)
    # Load saved model parameters (optional, for resuming training)
    # netD.load_state_dict(torch.load('weights/netD_epoch_24.pth'))
    print(netD)

    criterion = nn.BCELoss()

    # Set up the optimizers
    optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

    fixed_noise = torch.randn(128, nz, 1, 1, device=device)
    real_label = 1
    fake_label = 0
    niter = 25
    g_loss = []
    d_loss = []

    for epoch in range(niter):
        for i, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            # train with real data
            netD.zero_grad()
            real_cpu = data[0].to(device)
            batch_size = real_cpu.size(0)
            label = torch.full((batch_size,), real_label, device=device)
            output = netD(real_cpu)
            output = output.to(torch.float32)
            label = label.to(torch.float32)
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()

            # train with fake
            noise = torch.randn(batch_size, nz, 1, 1, device=device)
            fake = netG(noise)
            label.fill_(fake_label)
            output = netD(fake.detach())
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            output = netD(fake)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()

            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f' % (
                epoch, niter, i, len(dataloader), errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            # save the output
            if i % 100 == 0:
                print('saving the output')
                vutils.save_image(real_cpu, 'output/real_samples.png', normalize=True)
                fake = netG(fixed_noise)
                vutils.save_image(fake.detach(), 'output/fake_samples_epoch_%03d.png' % (epoch), normalize=True)

        # do checkpointing
        torch.save(netG.state_dict(), 'weights/netG_epoch_%d.pth' % (epoch))
        torch.save(netD.state_dict(), 'weights/netD_epoch_%d.pth' % (epoch))
```
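Once training has written checkpoints into weights/, a saved generator can be reloaded to sample new images. A minimal sketch of my own, reusing `Generator`, `ngpu`, `nz`, and `device` from the script above (the checkpoint filename is illustrative):

```python
import torch
import torchvision.utils as vutils

# Rebuild the generator with the same hyperparameters used for training and
# load a saved checkpoint; the epoch number in the filename is illustrative.
netG = Generator(ngpu).to(device)
netG.load_state_dict(torch.load('weights/netG_epoch_24.pth', map_location=device))
netG.eval()

# Sample a batch of latent vectors and save the generated images.
with torch.no_grad():
    noise = torch.randn(64, nz, 1, 1, device=device)
    samples = netG(noise)
vutils.save_image(samples, 'output/generated_samples.png', normalize=True)
```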
Next steps:
Load the parameters saved from a classifier model into the GAN's discriminator. First, make the model architectures consistent; second, keep the model inputs and outputs consistent.
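One way to do this, sketched here under the assumption that the classifier checkpoint lives at weights/classifier.pth (a hypothetical path) and that some of its layers share parameter names and shapes with the discriminator, is to copy over only the matching entries of the state dict:

```python
import torch

# Hypothetical transfer of classifier weights into the discriminator:
# only parameters whose names and shapes match are copied over.
classifier_state = torch.load('weights/classifier.pth', map_location=device)
disc_state = netD.state_dict()

transferred = {k: v for k, v in classifier_state.items()
               if k in disc_state and v.shape == disc_state[k].shape}
disc_state.update(transferred)
netD.load_state_dict(disc_state)
print('Transferred %d / %d parameter tensors' % (len(transferred), len(disc_state)))
```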