欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Generative Adversarial Networks

程序员文章站 2023-12-24 20:52:57
...

Generative Adversarial Networks

上一篇讲述了VAEs(变分自编码器),那么这次继续学习一下另一个生成模型——GANs。这里建议如果没有看VAEs的请点击传送门:,因为有所关联,所以如果直接看这篇的话,开头会有点奇怪。

从VAEs继续。

如果我们并不想明确模型的密度分布,我们只把注意点放在抽样生成一个新的实例呢?

但问题是,我们无法直接从复杂的分布中抽样。解决办法就是,我们从一个简单的噪声分布中抽样,然后基于这个噪声去训练一个分布。

Generative Adversarial Networks

那如何去使得生成的实例跟我们输入的实例尽可能相近呢?

加入一个鉴别器!鉴别器的功能就是来区分输入判别网络的实例是真实的样本,还是由生成网络生成的样本。因此,很显然两个网络之间形成了一种很特殊的彼此竞争的关系,生成网络尽可能想生成“骗过”鉴别器的实例,而鉴别器又尽可能想“不被骗”,识别出哪些实例是被生成网络生成的(假的),哪些实例是真实的。

Generative Adversarial Networks

来看个例子

生成器:尽力提高数据的虚假性。

鉴别器:尽力辨别真假。

首先,最初的生成器是由噪声抽样形成的,因此他的点的分布是随机的;而判别器有一个很明确的判别标准,真实点的分布被鉴别为1,生成点被判别为0。

Generative Adversarial Networks

下一步,生成器发现并没有很好地欺骗判别器,因此在下一次迭代的时候,生成器生成的点逐步向判别器判别为真实的点靠近。

Generative Adversarial Networks

下一步,轮到了判别器,由于生成器所生成的点逐步接近真实点,因此判别器可能有点难度去区分。

Generative Adversarial Networks

再下一步,生成器生成的点更接近真实点,甚至已经有几个生成点与真实点几乎重合,这对接下来判别器的判别会带来极大的难度。

Generative Adversarial Networks

如此迭代下去。。。

通过这个过程,我们可以发现,训练这样的两个模型的大方法就是:单独交替迭代训练。什么意思?因为是2个网络,不好一起训练,所以才去交替迭代训练。

如何训练一个GANs

训练GANs需要连续经过一个最大最小优化问题,其实也就是对应上面描述的迭代优化的过程。

Generative Adversarial Networks

鉴别器希望最大化目标,且D(x)趋向于1,D(G(x))趋向于0;而生成器希望最小化目标,且D(G(x))趋向于1。

优化D:

可以看到,优化D的时候,也就是判别网络,其实没有生成网络什么事,后面的G(z)这里就相当于已经得到的假样本。优化D的公式的第一项,使得真样本x输入的时候,得到的结果越大越好,可以理解,因为需要真样本的预测结果越接近于1越好嘛。对于假样本,需要优化是的其结果越小越好,也就是D(G(z))越小越好,因为它的标签为0。但是呢第一项是越大,第二项是越小,这不矛盾了,所以呢把第二项改成1-D(G(z)),这样就是越大越好,两者合起来就是越大越好。

Generative Adversarial Networks

优化G:

那么同样在优化G的时候,这个时候没有真样本什么事,所以把第一项直接去掉了。这个时候只有假样本,但是我们说这个时候是希望假样本的标签是1的,所以是D(G(z))越大越好,但是呢为了统一成1-D(G(z))的形式,那么只能是最小化1-D(G(z)),本质上没有区别,只是为了形式的统一。之后这两个优化模型可以合并起来写,就变成了最开始的那个最大最小目标函数了。

Generative Adversarial Networks

GANs强大的功能

这张图表明的是GAN的生成网络如何一步步从均匀分布学习到高斯分布的。原始数据x服从正太分布,这个过程你也没告诉生成网络说你得用高斯分布来学习,但是生成网络学习到了。假设你改一下x的分布,不管什么分布,生成网络可能也能学到。这就是GAN可以自动学习真实数据的分布的强大之处。

Generative Adversarial Networks

再来看一个,下面是一个改变人脸的过程,最上面一行是输入的图像,通过GANs,我们可以看到对应每一列生成的图像。

Generative Adversarial Networks

代码

from __future__ import absolute_import, division, print_function, unicode_literals
from tensorflow.keras import layers
from IPython import display
import tensorflow as tf
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import time
#数据数量
BUFFER_SIZE = 2656
#批大小
BATCH_SIZE = 16
#迭代次数
EPOCHS = 800
#G的测试噪音大小
noise_dim = 100
#G的测试噪音集大小
num_example_to_G = 4
#dropout_rate
DROPOUT = 0.3
#G的测试集
seed = tf.random.normal([num_example_to_G, noise_dim])
#导入数据集
data = np.load('./paints200.npy')
data = data.astype('float32')
data = (data - 127.5) / 127.5
train_dataset = tf.data.Dataset.from_tensor_slices(data).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
#生成器与判别器
def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(25*25*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Reshape((25, 25, 256)))
    assert model.output_shape == (None, 25, 25, 256) 
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 50, 50, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 100, 100, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2DTranspose(3, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 200, 200, 3)
    return model
def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[200, 200, 3]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(DROPOUT))
    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(DROPOUT))
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(DROPOUT))
    model.add(layers.Flatten())
    model.add(layers.Dense(1))
    return model
generator = tf.keras.models.load_model('epoch7_Gmodel.h5')
discriminator = tf.keras.models.load_model('epoch7_Dmodel.h5')
#损失值
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss
def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)
#优化器
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
#训练
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
    for epoch in range(epochs):
        start = time.time()
        for image_batch in dataset:
            train_step(image_batch)
            display.clear_output(wait=True)
            generate_and_save_images(generator, epoch+1, seed)
            """
            if (epoch + 1) % 100 == 0:
                generator.save('epoch'+epoch+'_Gmodel.h5')
                discriminator.save('epoch'+epoch+'_Dmodel.h5')
                """
            print('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
        generator.save('epoch'+str(epoch+1)+'_Gmodel.h5')
        discriminator.save('epoch'+str(epoch+1)+'_Dmodel.h5')
        display.clear_output(wait=True)
        generate_and_save_images(generator, epochs, seed)
def generate_and_save_images(model, epoch, test_input):
    predictions = model(test_input, training=False)
    fig = plt.figure(figsize=(2,2))
    for i in range(predictions.shape[0]):
        plt.subplot(2, 2, i+1)
        plt.imshow((predictions[i, :, :, :]*127.5+127.5)/255)
        plt.axis('off')
    if not os.path.exists('picture'):
        os.mkdir('picture')
    plt.savefig('./picture/image_at_epoch_{:04d}.png'.format(epoch))
    #plt.show()
train(train_dataset, EPOCHS)
generator.save('newGmodel.h5')
discriminator.save('newDmodel.h5')

输出结果由gif展现:

Generative Adversarial Networks

上一篇:

下一篇: