
GAN: Generative Adversarial Networks

Posted on 程序员文章站 (Programmer Article Site), 2023-02-21 23:33:16
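  1. After experimenting with GANs for a while, one lesson stands out: when plain networks (for example, fully connected networks) serve as the generator and discriminator, use BatchNormalization (batch normalization) layers; otherwise it is hard to get good results. In addition, the generator's last layer needs a bounded activation. tanh() is recommended, although sigmoid also works. The practical difference is the output range, which must match how the real images are scaled: tanh() produces values in (-1, 1), so the real images should be normalized to [-1, 1], while sigmoid produces values in (0, 1), which matches the [0, 1] range produced by ToTensor(). The implementation below uses sigmoid.
  2. Below is a PyTorch implementation of a GAN with a convolutional generator and discriminator.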
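The range argument in the first point can be checked with a few lines of plain Python. This is a minimal sketch: the helper names sigmoid and to_tanh_range are hypothetical, and to_tanh_range only mimics the per-pixel arithmetic of torchvision.transforms.Normalize(mean=0.5, std=0.5).

```python
import math

def sigmoid(x: float) -> float:
    """Logistic sigmoid: maps any real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

def to_tanh_range(pixel: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Same arithmetic as Normalize(mean=0.5, std=0.5): maps [0, 1] onto [-1, 1]."""
    return (pixel - mean) / std

# Pre-activation values the generator's last layer might produce
pre_activations = [-10.0, -1.0, 0.0, 1.0, 10.0]
sig_out = [sigmoid(x) for x in pre_activations]
tanh_out = [math.tanh(x) for x in pre_activations]

# ToTensor() yields pixels in [0, 1], which matches the sigmoid range directly;
# after Normalize(0.5, 0.5) the pixels span [-1, 1], which matches tanh.
real_pixels = [0.0, 0.5, 1.0]
normalized = [to_tanh_range(p) for p in real_pixels]

print("sigmoid:", [round(v, 3) for v in sig_out])
print("tanh:   ", [round(v, 3) for v in tanh_out])
print("normalized real pixels:", normalized)  # [-1.0, 0.0, 1.0]
```

Whichever activation you pick, the discriminator should never be able to tell real from fake simply because the two live in different value ranges.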
  1. Import the required libraries

    import numpy as np
    import torch
    import torchvision
    import torch.nn as nn
    import torch.nn.functional as F
    
    import matplotlib.pyplot as plt
    from matplotlib import animation
    from IPython.display import HTML
    
  2. Set the constants used throughout

    BATCH_SIZE = 100
    IMG_CHANNELS = 1
    NUM_Z = 100
    NUM_GENERATOR_FEATURES = 64
    NUM_DISCRIMINATOR_FEATURES = 64
    
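    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Fixed noise, reused after every epoch to visualize the generator's progress
    INPUTS_G = torch.randn(BATCH_SIZE, NUM_Z, 1, 1, device=DEVICE)
    # For a fully connected generator the noise would be 2-D instead:
    # INPUTS_G = torch.randn(BATCH_SIZE, NUM_Z, device=DEVICE)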
    
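INPUTS_G is drawn once and never resampled, so the images plotted after each epoch in step 12 all come from the same latent vectors and can be compared directly. The sketch below illustrates that idea with a seeded torch.Generator; the seed value 42 and the variable names are assumptions for illustration and are not part of the original code.

```python
import torch

NUM_Z, BATCH_SIZE = 100, 100

# Fixed noise drawn once from a seeded random generator (seed 42 is arbitrary)
fixed_noise = torch.randn(BATCH_SIZE, NUM_Z, 1, 1, generator=torch.Generator().manual_seed(42))

# Drawing again with the same seed reproduces exactly the same latent vectors
same_noise = torch.randn(BATCH_SIZE, NUM_Z, 1, 1, generator=torch.Generator().manual_seed(42))

# Fresh noise, like the inputs_g drawn inside train_step, changes on every call
fresh_a = torch.randn(BATCH_SIZE, NUM_Z, 1, 1)
fresh_b = torch.randn(BATCH_SIZE, NUM_Z, 1, 1)

print(torch.equal(fixed_noise, same_noise))  # True
print(torch.equal(fresh_a, fresh_b))         # False
```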
  3. Load the MNIST dataset

    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ])
    
    # ds = torchvision.datasets.cifar.CIFAR10(root="data", train=True, transform=transform,  download=True)
    ds = torchvision.datasets.mnist.MNIST(root="data", train=True, transform=transform, download=True)
    ds_loader = torch.utils.data.DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)
    
  4. Inspect the data

    img_batch, lab_batch = next(iter(ds_loader))
    img_batch.shape, lab_batch.shape  # (torch.Size([100, 1, 28, 28]), torch.Size([100]))
    
  5. Plot a batch of dataset images

    plt.figure(figsize=(8, 8), dpi=80)
    plt.imshow(torchvision.utils.make_grid(img_batch, nrow=10, padding=2, pad_value=1, normalize=True).permute(1, 2, 0))
    plt.tight_layout()
    plt.axis("off")
    
  6. Define the generator and discriminator

    class Generator(nn.Module):
        def __init__(self):
            super(Generator, self).__init__()
            # ConvTranspose2d output size: (in - 1) * stride - 2 * padding + kernel_size
            self.main = nn.Sequential(
                # 100 x 1 x 1 --> 512 x 4 x 4
                nn.ConvTranspose2d(NUM_Z, NUM_GENERATOR_FEATURES * 8, 4, 1, 0, bias=False),
                nn.BatchNorm2d(NUM_GENERATOR_FEATURES * 8),
                nn.ReLU(True),
                # 512 x 4 x 4 --> 256 x 8 x 8
                nn.ConvTranspose2d(NUM_GENERATOR_FEATURES * 8, NUM_GENERATOR_FEATURES * 4, 4, 2, 1, bias=False),
                nn.BatchNorm2d(NUM_GENERATOR_FEATURES * 4),
                nn.ReLU(True),
                # 256 x 8 x 8 --> 128 x 16 x 16
                nn.ConvTranspose2d(NUM_GENERATOR_FEATURES * 4, NUM_GENERATOR_FEATURES * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(NUM_GENERATOR_FEATURES * 2),
                nn.ReLU(True),
                # 128 x 16 x 16 --> 64 x 14 x 14 (a 1 x 1 kernel with padding 1 shrinks the map)
                nn.ConvTranspose2d(NUM_GENERATOR_FEATURES * 2, NUM_GENERATOR_FEATURES * 1, 1, 1, 1, bias=False),
                nn.BatchNorm2d(NUM_GENERATOR_FEATURES * 1),
                nn.ReLU(True),
                # 64 x 14 x 14 --> 1 x 28 x 28; sigmoid keeps pixels in (0, 1), matching ToTensor()
                nn.ConvTranspose2d(NUM_GENERATOR_FEATURES * 1, IMG_CHANNELS, 2, 2, 0, bias=False),
                nn.Sigmoid(),
            )
            
        def forward(self, x):
            return self.main(x)
    
        
    class Discriminator(nn.Module):
        def __init__(self):
            super(Discriminator, self).__init__()
            
            self.main = nn.Sequential(
                # 1 x 28 x 28  --> 256 x 14 x 14 
                nn.Conv2d(IMG_CHANNELS, NUM_DISCRIMINATOR_FEATURES * 4, 4, 2, 1, bias=False),
                nn.BatchNorm2d(NUM_DISCRIMINATOR_FEATURES * 4),
                nn.LeakyReLU(0.2, inplace=True),
                # 256 x 14 x 14  --> 128 x 7 x 7
                nn.Conv2d(NUM_DISCRIMINATOR_FEATURES * 4, NUM_DISCRIMINATOR_FEATURES * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(NUM_DISCRIMINATOR_FEATURES * 2),
                nn.LeakyReLU(0.2, inplace=True),
                # 128 x 7 x 7 --> 64 x 3 x 3
                nn.Conv2d(NUM_DISCRIMINATOR_FEATURES * 2, NUM_DISCRIMINATOR_FEATURES * 1, 4, 2, 1, bias=False),
                nn.BatchNorm2d(NUM_DISCRIMINATOR_FEATURES * 1),
                nn.LeakyReLU(0.2, inplace=True),
                # 64 x 3 x 3 --> 1 x 1 x 1
                nn.Conv2d(NUM_DISCRIMINATOR_FEATURES * 1, 1, 3, 1, 0, bias=False),
                nn.Sigmoid()
            )
        
        def forward(self, x):
            return self.main(x).view(-1)
    
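The shape comments in both classes follow from the output-size formulas given in the PyTorch documentation for nn.Conv2d and nn.ConvTranspose2d (with dilation 1). A small plain-Python sketch that replays the (kernel, stride, padding) settings used above; the helper names conv_out and conv_transpose_out are hypothetical.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of nn.Conv2d with dilation=1."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel, stride, padding, output_padding=0):
    """Spatial output size of nn.ConvTranspose2d with dilation=1."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

# (kernel, stride, padding) of each ConvTranspose2d layer in Generator
gen_layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (1, 1, 1), (2, 2, 0)]
# (kernel, stride, padding) of each Conv2d layer in Discriminator
disc_layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (3, 1, 0)]

size, gen_sizes = 1, []          # the generator starts from a 1 x 1 noise "image"
for k, s, p in gen_layers:
    size = conv_transpose_out(size, k, s, p)
    gen_sizes.append(size)

size, disc_sizes = 28, []        # the discriminator starts from a 28 x 28 image
for k, s, p in disc_layers:
    size = conv_out(size, k, s, p)
    disc_sizes.append(size)

print(gen_sizes)   # [4, 8, 16, 14, 28]
print(disc_sizes)  # [14, 7, 3, 1]
```

Replaying the sizes this way is a quick check before changing a kernel or stride: the generator must end at 28 to match MNIST, and the discriminator must end at 1 so that view(-1) yields one score per image.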
  7. Test the models

    noise = torch.randn(BATCH_SIZE, NUM_Z, 1, 1)
    generator = Generator()
    fake_img = generator(noise)      # shape: (100, 1, 28, 28)
    discriminator = Discriminator()
    discriminator(fake_img)          # shape: (100,), one probability per image
    
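The introduction recommends batch normalization when the generator and discriminator are plain fully connected networks, and the commented-out 2-D INPUTS_G line hints at such a variant. What follows is a hypothetical sketch of it, not the author's code: the class names MLPGenerator and MLPDiscriminator and the hidden-layer widths are assumptions. It ends in nn.Tanh(), so the real images would need to be normalized to [-1, 1].

```python
import torch
import torch.nn as nn

NUM_Z = 100
IMG_SHAPE = (1, 28, 28)   # MNIST: channels x height x width
IMG_PIXELS = 1 * 28 * 28

class MLPGenerator(nn.Module):
    """Hypothetical fully connected generator: (Linear -> BatchNorm1d -> ReLU) blocks, tanh output."""
    def __init__(self, hidden=(256, 512, 1024)):
        super().__init__()
        layers, in_features = [], NUM_Z
        for out_features in hidden:
            layers += [nn.Linear(in_features, out_features),
                       nn.BatchNorm1d(out_features),   # batch normalization after each hidden layer
                       nn.ReLU(True)]
            in_features = out_features
        layers += [nn.Linear(in_features, IMG_PIXELS), nn.Tanh()]  # pixels in [-1, 1]
        self.main = nn.Sequential(*layers)

    def forward(self, z):
        # z: (batch, NUM_Z) -> (batch, 1, 28, 28)
        return self.main(z).view(-1, *IMG_SHAPE)

class MLPDiscriminator(nn.Module):
    """Hypothetical fully connected discriminator: flattens the image, outputs one probability."""
    def __init__(self, hidden=(512, 256)):
        super().__init__()
        layers, in_features = [nn.Flatten()], IMG_PIXELS
        for out_features in hidden:
            layers += [nn.Linear(in_features, out_features),
                       nn.BatchNorm1d(out_features),
                       nn.LeakyReLU(0.2, inplace=True)]
            in_features = out_features
        layers += [nn.Linear(in_features, 1), nn.Sigmoid()]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x).view(-1)

z = torch.randn(100, NUM_Z)   # 2-D noise, as in the commented-out INPUTS_G line
fake = MLPGenerator()(z)
scores = MLPDiscriminator()(fake)
print(fake.shape, scores.shape)  # torch.Size([100, 1, 28, 28]) torch.Size([100])
```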
  8. Weight-initialization function

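    # Custom DCGAN-style weight initialization, applied to the generator and discriminator below.
    # Matching 'Conv' in the class name covers both Conv2d and ConvTranspose2d.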
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)
    
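Module.apply(fn) calls fn on every submodule and on the module itself, which is how weights_init reaches each layer. This sketch checks the effect on a small hypothetical layer stack: convolution weights end up with a standard deviation close to 0.02, BatchNorm weights near 1 and BatchNorm biases at 0. The layer sizes and the seed are arbitrary assumptions.

```python
import torch
import torch.nn as nn

def weights_init(m):
    # Same logic as the weights_init function above
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

torch.manual_seed(0)
model = nn.Sequential(
    nn.ConvTranspose2d(100, 64, 4, 1, 0, bias=False),  # class name contains 'Conv'
    nn.BatchNorm2d(64),                                # class name contains 'BatchNorm'
    nn.ReLU(True),                                     # matched by neither, left untouched
)
model.apply(weights_init)

conv, bn = model[0], model[1]
print(conv.weight.std().item())    # close to 0.02
print(bn.weight.mean().item())     # close to 1.0
print(bn.bias.abs().max().item())  # 0.0
```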
  9. Create the generator and discriminator, their optimizers, the loss function, and the evaluation metric

    generator = Generator().to(DEVICE).apply(weights_init)
    discriminator = Discriminator().to(DEVICE).apply(weights_init)
    optimizer_g = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    loss_fn = nn.BCELoss()
    def metrics_fn(y_true, y_pred):
        # Accuracy: share of outputs on the correct side of the 0.5 threshold
        return ((y_pred >= 0.5).float() == y_true).float().mean()
    
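nn.BCELoss with its default reduction computes the mean of -(y * log(p) + (1 - y) * log(1 - p)), and metrics_fn is plain thresholded accuracy. The plain-Python sketch below spells out both and shows how the labels are used in train_step: the discriminator is trained toward 1 on real images and 0 on fakes, while the generator is trained toward 1 on its fakes (the non-saturating generator loss). The discriminator outputs are made-up numbers, and unlike nn.BCELoss this sketch does not clamp the logarithms at -100.

```python
import math

def bce(y_true, y_pred):
    """Mean binary cross-entropy, as nn.BCELoss computes with reduction='mean' (no clamping here)."""
    return -sum(t * math.log(p) + (1 - t) * math.log(1 - p)
                for t, p in zip(y_true, y_pred)) / len(y_true)

def accuracy(y_true, y_pred, threshold=0.5):
    """Share of predictions on the correct side of the threshold, like metrics_fn."""
    return sum((p >= threshold) == bool(t) for t, p in zip(y_true, y_pred)) / len(y_true)

d_real = [0.9, 0.8]   # made-up discriminator outputs on real images
d_fake = [0.2, 0.6]   # made-up discriminator outputs on generated images

loss_d = (bce([1, 1], d_real) + bce([0, 0], d_fake)) / 2   # D: real -> 1, fake -> 0
loss_g = bce([1, 1], d_fake)                                # G: wants D(fake) -> 1
acc_d = (accuracy([1, 1], d_real) + accuracy([0, 0], d_fake)) / 2

print(round(loss_d, 4), round(loss_g, 4), acc_d)
```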
  10. Define the training step (the key part)

    def train_step(inputs, labels):
        labels = labels.to(torch.float32)
        # Use the actual batch size so that a smaller final batch also works
        inputs_g = torch.randn(inputs.size(0), NUM_Z, 1, 1, device=DEVICE)
        # inputs_g = torch.randn(inputs.size(0), NUM_Z, device=DEVICE)
        outputs_g = generator(inputs_g)
        
        # Discriminator step: freeze the generator, unfreeze the discriminator
        generator.requires_grad_(False)
        discriminator.requires_grad_(True)
        optimizer_d.zero_grad()
        
        # real images should be classified as 1
        labels = torch.ones_like(labels)
        outputs = discriminator(inputs)
        loss_real = loss_fn(outputs, labels)
        metrics_real = metrics_fn(labels, outputs)
        loss_real.backward()
        
        # fake images should be classified as 0
        labels = torch.zeros_like(labels)
        outputs = discriminator(outputs_g.detach())  # detach(): no gradient flows back into the generator here
        loss_fake = loss_fn(outputs, labels)
        metrics_fake = metrics_fn(labels, outputs)
        loss_fake.backward()
        
        # Both gradients are already accumulated; loss_d is only for reporting
        # (alternatively, drop the two backward() calls above and call loss_d.backward() once)
        loss_d = (loss_real + loss_fake) / 2
        metrics_d = (metrics_real + metrics_fake) / 2
        optimizer_d.step()
        
        # Generator step: unfreeze the generator, freeze the discriminator
        generator.requires_grad_(True)
        discriminator.requires_grad_(False)
        optimizer_g.zero_grad()
        
        # the generator wants its fakes classified as real (label 1)
        labels = torch.ones_like(labels)
        outputs = discriminator(outputs_g)
        loss_g = loss_fn(outputs, labels)
        metrics_g = metrics_fn(labels, outputs)
        loss_g.backward()
        optimizer_g.step()
        
        return loss_d.item(), metrics_d.item(), loss_g.item(), metrics_g.item()
    
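The two mechanisms in train_step can be seen in isolation on toy modules: detach() keeps the discriminator loss from reaching the generator, and requires_grad_(False) on the discriminator keeps the generator loss from filling the discriminator's gradients. A minimal sketch; G and D here are hypothetical single-layer stand-ins, not the models defined above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
G = nn.Linear(4, 3)                                # hypothetical toy "generator"
D = nn.Sequential(nn.Linear(3, 1), nn.Sigmoid())   # hypothetical toy "discriminator"
loss_fn = nn.BCELoss()

fake = G(torch.randn(8, 4))

# Discriminator step: detach() cuts the graph at the generator's output
loss_fake = loss_fn(D(fake.detach()).view(-1), torch.zeros(8))
loss_fake.backward()
g_untouched = G.weight.grad is None          # True: no gradient reached G
d_has_grad = D[0].weight.grad is not None    # True: D received its gradient

# Generator step: freeze D so backward only fills G's gradients
D.zero_grad(set_to_none=True)
D.requires_grad_(False)
loss_g = loss_fn(D(fake).view(-1), torch.ones(8))  # label 1: G wants its fakes judged real
loss_g.backward()
D.requires_grad_(True)                       # unfreeze for the next discriminator step

print(g_untouched, d_has_grad)                              # True True
print(G.weight.grad is not None, D[0].weight.grad is None)  # True True
```

Freezing D is mainly a saving: without it, D would also collect gradients during the generator step, which is harmless here only because optimizer_d.zero_grad() clears them before the next discriminator update.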
  11. Test the training step (this call already performs one update of both networks)

    train_step(img_batch.to(DEVICE), lab_batch.to(DEVICE))
    
  12. Define the training loop

    epochs = 8
    loss_d_list, metrics_d_list, loss_g_list, metrics_g_list = [], [], [], []
    grid_img_list = []
    
    for epoch in range(epochs):
        
        loss_d_batch = metrics_d_batch = loss_g_batch = metrics_g_batch = .0
        num_batch = 0
        for img_batch, lab_batch in ds_loader:
            img_batch = img_batch.to(DEVICE)
            lab_batch = lab_batch.to(DEVICE)
            loss_d, metrics_d, loss_g, metrics_g = train_step(img_batch, torch.ones_like(lab_batch))
            num_batch += 1
            loss_d_batch, metrics_d_batch = loss_d_batch + loss_d, metrics_d_batch + metrics_d
            loss_g_batch, metrics_g_batch = loss_g_batch + loss_g, metrics_g_batch + metrics_g
            
        loss_d_batch, metrics_d_batch = loss_d_batch / num_batch, metrics_d_batch / num_batch
        loss_g_batch, metrics_g_batch = loss_g_batch / num_batch, metrics_g_batch / num_batch
        
        loss_d_list.append(loss_d_batch)
        metrics_d_list.append(metrics_d_batch)
        loss_g_list.append(loss_g_batch)
        metrics_g_list.append(metrics_g_batch)
        
        print("[%d/%d] loss_discriminator: %.2f, metrics_discriminator: %.2f, loss_generator: %.2f, metrics_generator: %.2f" % (
            epoch + 1, epochs, loss_d_batch, metrics_d_batch, loss_g_batch, metrics_g_batch))
        
        # Generate images from the fixed noise INPUTS_G; each title is the discriminator's score
        with torch.no_grad():
            outputs_g = generator(INPUTS_G)
            outputs_d = discriminator(outputs_g)
            
            grid_img_list.append(torchvision.utils.make_grid(outputs_g.cpu(), nrow=10, normalize=True, pad_value=1))
    
            plt.figure(figsize=(20, 2), dpi=80)
            for i, (img, score) in enumerate(zip(outputs_g[:16], outputs_d[:16])):
                plt.subplot(1, 16, i+1)
                plt.imshow(img.cpu().permute(1, 2, 0), cmap=plt.cm.binary)
                plt.title("%.2f" % score.cpu().item())
                plt.axis("off")
            plt.tight_layout()
            plt.show()
    
  13. Plot the losses and evaluation metrics

    plt.figure(figsize=(12, 4), dpi=80)
    
    plt.subplot(1, 2, 1)
    plt.plot(loss_d_list, label="discriminator_loss")
    plt.plot(loss_g_list, label="generator_loss")
    plt.title("Loss of discriminator and generator")
    plt.xlabel("epochs")
    plt.ylabel("loss")
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(metrics_d_list, label="discriminator_metrics")
    plt.plot(metrics_g_list, label="generator_metrics")
    plt.title("Metrics of discriminator and generator")
    plt.xlabel("epochs")
    plt.ylabel("metrics")
    plt.legend()
    
    plt.show()
    


  14. Animate the GAN image-generation process

    fig = plt.figure(figsize=(10, 10), dpi=80)
    plt.axis("off")
    
    imgs = [[plt.imshow(np.transpose(img, (1, 2, 0)), animated=True)] for img in grid_img_list]
    ani = animation.ArtistAnimation(fig, imgs, interval=1000, repeat_delay=1000, blit=True)
    
    HTML(ani.to_jshtml())
    
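ArtistAnimation takes a list of frames, where each frame is a list of artists drawn together; that is why the list comprehension above wraps each imshow() call in its own one-element list. A self-contained sketch, assuming random arrays as stand-ins for grid_img_list so it runs without training, and the off-screen Agg backend so it also runs outside a notebook:

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend; in a notebook the default backend works too
import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np

# Random arrays standing in for grid_img_list: three C x H x W "grids" with values in [0, 1)
rng = np.random.default_rng(0)
frames = [rng.random((3, 32, 32)) for _ in range(3)]

fig = plt.figure(figsize=(4, 4), dpi=80)
plt.axis("off")
# One list of artists per frame; imshow() expects H x W x C, hence the transpose
artists = [[plt.imshow(np.transpose(f, (1, 2, 0)), animated=True)] for f in frames]
ani = animation.ArtistAnimation(fig, artists, interval=1000, repeat_delay=1000, blit=True)

html = ani.to_jshtml()  # HTML + JavaScript player with the frames embedded
plt.close(fig)          # avoid also displaying the last frame as a static figure
```

In Jupyter, HTML(html) from IPython.display renders the player, exactly as in the step above.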
  15. Plot real images next to GAN-generated images

    plt.figure(figsize=(20, 10), dpi=80)
    
    plt.subplot(1, 2, 1)
    plt.title("real digits image")
    plt.imshow(torchvision.utils.make_grid(img_batch.cpu(), nrow=10, normalize=True, pad_value=1).permute(1, 2, 0))
    plt.axis("off")
    
    plt.subplot(1, 2, 2)
    plt.title("fake digits image")
    plt.imshow(np.transpose(grid_img_list[-1], (1, 2, 0)))
    plt.axis("off")
    


Original article: https://blog.csdn.net/bash_winner/article/details/113998997