GAN: Generative Adversarial Networks
- After experimenting with GANs for a while, I have a few takeaways. When using ordinary networks as the generator and discriminator (for example, fully connected networks), be sure to apply BatchNormalization; without batch normalization it is hard to get good results. Also, the last layer of the generator should use tanh() (that is my recommendation; sigmoid also works, and you can look up the difference between the two in this setting yourself). A sketch of the tanh() variant follows this list.
- This is a PyTorch implementation of a GAN.
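If you go with the tanh() output instead of sigmoid, the data transform and the generator's last layer have to agree on a [-1, 1] value range. A minimal sketch of that variant, assuming the same single-channel MNIST setup used below (the code in this post keeps Sigmoid with a plain ToTensor transform):

import torch.nn as nn
import torchvision

# scale single-channel images from [0, 1] to [-1, 1] to match a Tanh output
transform_tanh = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,)),
])

# ...and the generator would end with
#     nn.ConvTranspose2d(64, 1, 2, 2, 0, bias=False),
#     nn.Tanh(),
# instead of nn.Sigmoid().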
Import the required libraries
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np  # used by np.transpose in the plotting cells below
import matplotlib.pylab as plt
from matplotlib import animation
from IPython.display import HTML
Set up some constants
BATCH_SIZE = 100
IMG_CHANNELS = 1
NUM_Z = 100
NUM_GENERATOR_FEATURES = 64
NUM_DISCRIMINATOR_FEATURES = 64
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# fixed noise, reused after every epoch to visualize how the generator improves
INPUTS_G = torch.randn(BATCH_SIZE, NUM_Z, 1, 1, device=DEVICE)
# INPUTS_G = torch.randn(BATCH_SIZE, NUM_Z, device=DEVICE)
Load the dataset (MNIST)
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])
# ds = torchvision.datasets.cifar.CIFAR10(root="data", train=True, transform=transform, download=True)
ds = torchvision.datasets.mnist.MNIST(root="data", train=True, transform=transform, download=True)
ds_loader = torch.utils.data.DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)
Inspect the data
img_batch, lab_batch = next(iter(ds_loader))
img_batch.shape, lab_batch.shape
Plot images from the dataset
plt.figure(figsize=(8, 8), dpi=80)
plt.imshow(torchvision.utils.make_grid(img_batch, nrow=10, padding=2, pad_value=1, normalize=True).permute(1, 2, 0))
plt.tight_layout()
plt.axis("off")
Define the generator and discriminator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # ConvTranspose2d output size: (o - 1) * s - 2 * p + w
        # (o = input size, s = stride, p = padding, w = kernel width)
        self.main = nn.Sequential(
            # 100 x 1 x 1 --> 512 x 4 x 4
            nn.ConvTranspose2d(NUM_Z, NUM_GENERATOR_FEATURES * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(NUM_GENERATOR_FEATURES * 8),
            nn.ReLU(True),
            # 512 x 4 x 4 --> 256 x 8 x 8
            nn.ConvTranspose2d(NUM_GENERATOR_FEATURES * 8, NUM_GENERATOR_FEATURES * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(NUM_GENERATOR_FEATURES * 4),
            nn.ReLU(True),
            # 256 x 8 x 8 --> 128 x 16 x 16
            nn.ConvTranspose2d(NUM_GENERATOR_FEATURES * 4, NUM_GENERATOR_FEATURES * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(NUM_GENERATOR_FEATURES * 2),
            nn.ReLU(True),
            # 128 x 16 x 16 --> 64 x 14 x 14
            nn.ConvTranspose2d(NUM_GENERATOR_FEATURES * 2, NUM_GENERATOR_FEATURES * 1, 1, 1, 1, bias=False),
            nn.BatchNorm2d(NUM_GENERATOR_FEATURES * 1),
            nn.ReLU(True),
            # 64 x 14 x 14 --> 1 x 28 x 28
            nn.ConvTranspose2d(NUM_GENERATOR_FEATURES * 1, IMG_CHANNELS, 2, 2, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.main(x)


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # 1 x 28 x 28 --> 256 x 14 x 14
            nn.Conv2d(IMG_CHANNELS, NUM_DISCRIMINATOR_FEATURES * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(NUM_DISCRIMINATOR_FEATURES * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 256 x 14 x 14 --> 128 x 7 x 7
            nn.Conv2d(NUM_DISCRIMINATOR_FEATURES * 4, NUM_DISCRIMINATOR_FEATURES * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(NUM_DISCRIMINATOR_FEATURES * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 128 x 7 x 7 --> 64 x 3 x 3
            nn.Conv2d(NUM_DISCRIMINATOR_FEATURES * 2, NUM_DISCRIMINATOR_FEATURES * 1, 4, 2, 1, bias=False),
            nn.BatchNorm2d(NUM_DISCRIMINATOR_FEATURES * 1),
            nn.LeakyReLU(0.2, inplace=True),
            # 64 x 3 x 3 --> 1 x 1 x 1
            nn.Conv2d(NUM_DISCRIMINATOR_FEATURES * 1, 1, 3, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x).view(-1)
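The comment (o - 1) * s - 2 * p + w is the ConvTranspose2d output-size formula (with dilation and output_padding left at their defaults). Applying it layer by layer shows how the 1 x 1 noise vector grows to a 28 x 28 image: 1 -> 4 -> 8 -> 16 -> 14 -> 28. A small sanity-check loop over the Generator defined above (this snippet is mine, not from the original post):

x = torch.randn(1, NUM_Z, 1, 1)
for layer in Generator().main:
    x = layer(x)
    if isinstance(layer, nn.ConvTranspose2d):
        # e.g. the fourth deconvolution shrinks the map: (16 - 1) * 1 - 2 * 1 + 1 = 14
        # and the last one grows it back: (14 - 1) * 2 - 2 * 0 + 2 = 28
        print(layer.kernel_size, layer.stride, layer.padding, "->", tuple(x.shape))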
Test the defined models
noise = torch.randn(BATCH_SIZE, NUM_Z, 1, 1)
generator = Generator()
fake_img = generator(noise)
discriminator = Discriminator()
discriminator(fake_img)
Network parameter initialization function
# custom weight initialization, applied to the generator and discriminator
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
Create the generator and discriminator objects, the optimizers, the loss function, and the evaluation metric
generator = Generator().to(DEVICE).apply(weights_init)
discriminator = Discriminator().to(DEVICE).apply(weights_init)
optimizer_g = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
loss_fn = nn.BCELoss()
metrics_fn = lambda y_true, y_pred: torch.mean((y_true == torch.where(y_pred >= 0.5, torch.tensor(1., device=DEVICE), torch.tensor(0., device=DEVICE))).to(torch.float32))
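The metrics_fn lambda is fairly dense: it thresholds the discriminator's outputs at 0.5 and measures how often the thresholded prediction matches the labels. For reference, an equivalent but more readable version could look like this (the name accuracy and the threshold argument are mine, not from the original post):

def accuracy(y_true, y_pred, threshold=0.5):
    # fraction of thresholded predictions that match the labels
    preds = (y_pred >= threshold).to(y_true.dtype)
    return (y_true == preds).to(torch.float32).mean()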
Define the training step (the key part)
def train_step(inputs, labels):
    labels = labels.to(torch.float32)
    inputs_g = torch.randn(BATCH_SIZE, NUM_Z, 1, 1, device=DEVICE)
    # inputs_g = torch.randn(BATCH_SIZE, NUM_Z, device=DEVICE)
    outputs_g = generator(inputs_g)

    # fix generator, unfix discriminator
    for parameter in generator.parameters():
        parameter.requires_grad = False
    for parameter in discriminator.parameters():
        parameter.requires_grad = True
    optimizer_d.zero_grad()
    # real images
    labels = torch.ones_like(labels)
    outputs = discriminator(inputs)
    loss_real = loss_fn(outputs, labels)
    metrics_real = metrics_fn(labels, outputs)
    loss_real.backward()
    # fake images
    labels = torch.zeros_like(labels)
    outputs = discriminator(outputs_g.detach())  # note the detach() here
    loss_fake = loss_fn(outputs, labels)
    metrics_fake = metrics_fn(labels, outputs)
    loss_fake.backward()
    loss_d = (loss_real + loss_fake) / 2
    metrics_d = (metrics_real + metrics_fake) / 2
    # loss_d.backward()
    optimizer_d.step()

    # unfix generator, fix discriminator
    for parameter in generator.parameters():
        parameter.requires_grad = True
    for parameter in discriminator.parameters():
        parameter.requires_grad = False
    optimizer_g.zero_grad()
    labels = torch.ones_like(labels)
    outputs = discriminator(outputs_g)
    loss_g = loss_fn(outputs, labels)
    metrics_g = metrics_fn(labels, outputs)
    loss_g.backward()
    optimizer_g.step()

    return loss_d.item(), metrics_d.item(), loss_g.item(), metrics_g.item()
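The detach() flagged in the comment above is what keeps the discriminator update from backpropagating into the generator: it returns the same fake images, but cut off from the generator's computation graph. A tiny illustration using the generator defined above (the variable names here are mine):

z = torch.randn(4, NUM_Z, 1, 1, device=DEVICE)
fake = generator(z)
print(fake.requires_grad, fake.detach().requires_grad)  # True False
# discriminator(fake.detach())  -> gradients stop at the generator (used for the discriminator update)
# discriminator(fake)           -> gradients can flow back into the generator (used for the generator update)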
Test the training step
train_step(img_batch.to(DEVICE), lab_batch.to(DEVICE))
Define the training loop
epochs = 8
loss_d_list, metrics_d_list, loss_g_list, metrics_g_list = [], [], [], []
grid_img_list = []
for epoch in range(epochs):
    loss_d_batch = metrics_d_batch = loss_g_batch = metrics_g_batch = 0.0
    num_batch = 0
    for img_batch, lab_batch in ds_loader:
        img_batch = img_batch.to(DEVICE)
        lab_batch = lab_batch.to(DEVICE)
        loss_d, metrics_d, loss_g, metrics_g = train_step(img_batch, torch.ones_like(lab_batch))
        num_batch += 1
        loss_d_batch, metrics_d_batch = loss_d_batch + loss_d, metrics_d_batch + metrics_d
        loss_g_batch, metrics_g_batch = loss_g_batch + loss_g, metrics_g_batch + metrics_g
    loss_d_batch, metrics_d_batch = loss_d_batch / num_batch, metrics_d_batch / num_batch
    loss_g_batch, metrics_g_batch = loss_g_batch / num_batch, metrics_g_batch / num_batch
    loss_d_list.append(loss_d_batch)
    metrics_d_list.append(metrics_d_batch)
    loss_g_list.append(loss_g_batch)
    metrics_g_list.append(metrics_g_batch)
    print("[%d/%d] loss_discriminator: %.2f, metrics_discriminator: %.2f, loss_generator: %.2f, metrics_generator: %.2f" % (
        epoch + 1, epochs, loss_d_batch, metrics_d_batch, loss_g_batch, metrics_g_batch))
    with torch.no_grad():
        outputs_g = generator(INPUTS_G)
        outputs_d = discriminator(outputs_g)
    grid_img_list.append(torchvision.utils.make_grid(outputs_g.cpu(), nrow=10, normalize=True, pad_value=1))
    plt.figure(figsize=(20, 2), dpi=80)
    for i, (img, lab) in enumerate(zip(outputs_g[:16], outputs_d[:16])):
        plt.subplot(1, 16, i + 1)
        plt.imshow(img.cpu().permute(1, 2, 0), cmap=plt.cm.binary)
        plt.title("%.2f" % lab.cpu().item())
        plt.axis("off")
    plt.tight_layout()
    plt.show()
Plot the losses and evaluation metrics
plt.figure(figsize=(12, 4), dpi=80)
plt.subplot(1, 2, 1)
plt.plot(loss_d_list, label="discriminator_loss")
plt.plot(loss_g_list, label="generator_loss")
plt.title("Loss of discriminator and generator")
plt.xlabel("epochs")
plt.ylabel("loss")
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(metrics_d_list, label="discriminator_metrics")
plt.plot(metrics_g_list, label="generator_metrics")
plt.title("Metrics of discriminator and generator")
plt.xlabel("epochs")
plt.ylabel("metrics")
plt.legend()
plt.show()
Animate the GAN image-generation process
fig = plt.figure(figsize=(10, 10), dpi=80)
plt.axis("off")
imgs = [[plt.imshow(np.transpose(img, (1, 2, 0)), animated=True)] for img in grid_img_list]
ani = animation.ArtistAnimation(fig, imgs, interval=1000, repeat_delay=1000, blit=True)
HTML(ani.to_jshtml())
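If you want to keep the animation outside the notebook, the same ArtistAnimation object can also be saved to disk, for example as a GIF (the file name and fps here are arbitrary choices of mine, and the "pillow" writer needs the Pillow package installed):

ani.save("gan_training.gif", writer="pillow", fps=1)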
Plot real images and GAN-generated images
plt.figure(figsize=(20, 10), dpi=80)
plt.subplot(1, 2, 1)
plt.title("real digits image")
plt.imshow(torchvision.utils.make_grid(img_batch.cpu(), nrow=10, normalize=True, pad_value=1).permute(1, 2, 0))
plt.axis("off")
plt.subplot(1, 2, 2)
plt.title("fake digits image")
plt.imshow(np.transpose(grid_img_list[-1], (1, 2, 0)))
plt.axis("off")
Original post: https://blog.csdn.net/bash_winner/article/details/113998997