欢迎您访问程序员文章站，本站旨在为大家提供分享程序员计算机编程知识！
您现在的位置是: 首页

tensorflow keras 关于CIFAR10数据集 CGAN的研究经验总结

程序员文章站 2022-03-09 13:41:07
...

前言

目前我发现自己的网络在传统的 MNIST 手写数字数据集上表现良好，但将其应用于 CIFAR10 数据集时，生成的图像出现了非常严重的模糊。在尝试了多种传统 GAN 结构之后，我的结论是：传统 GAN 生成图片细节等高频信息的能力非常欠缺。下面总结一下前期的工作；鉴于网上相关的代码比较少，我自己写了一份。

实验的结构

实验中生成器采用了类似语义分割网络的结构（多尺度上采样后再拼接特征），但效果非常不好：虽然能看出物体的大致形状，细节却过于模糊。下面的结构仅可作为训练生成器的反面教材；如果想生成 MNIST 手写数字，或者只是训练判别器，倒是可以拿去用用。

代码

import numpy as np
import matplotlib
from matplotlib import pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, optimizers, losses
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.python.keras import backend as K
from tensorflow.keras.utils import plot_model
from IPython.display import Image

import cv2
import PIL
import json, os
import sys

import labelme
import labelme.utils as utils
import glob
import itertools


class GAN():
    """Conditional GAN (CGAN) for CIFAR10 built with tf.keras.

    The discriminator is a (num_classes + 1)-way softmax classifier: indices
    0..num_classes-1 are the real CIFAR10 classes and index ``num_classes`` is
    an extra "fake" class.  The generator maps label-masked Gaussian noise of
    shape (H/16, W/16, num_classes) to an image through a sigmoid output.
    ``self.model`` chains generator -> discriminator and is used to train the
    generator so that its samples are classified as the conditioning label.
    """

    def __init__(self):
        # Global configuration (paths are machine-specific; adjust as needed).
        self.img_shape = (32, 32, 3)
        self.save_path = r'C:\Users\Administrator\Desktop\photo\CGAN.h5'
        self.img_path = r'C:\Users\Administrator\Desktop\photo'
        self.batch_size = 100
        self.test_size = 200
        self.sample_interval = 1
        self.epoch = 200
        self.num_classes = 10
        # 0 = train from scratch, 1 = resume from self.save_path.
        # Resuming is not recommended: with this code training easily collapses.
        self.train_mode = 0

        if self.train_mode == 0:
            self.generator_model = self.build_generator()
            self.discriminator_model = self.build_discriminator()
            self.model = self.bulid_model()
        else:
            # The combined model built by bulid_model() is
            # [Input, generator, discriminator], so layers[1] / layers[2]
            # are the nested generator / discriminator sub-models.
            self.model = keras.models.load_model(self.save_path)
            self.generator_model = keras.Model(inputs=self.model.layers[1].input,
                                               outputs=self.model.layers[1].output)
            self.discriminator_model = keras.Model(inputs=self.model.layers[2].input,
                                                   outputs=self.model.layers[2].output)

    def _noise_shape(self):
        """Return the per-sample generator input shape (H/16, W/16, num_classes)."""
        return (self.img_shape[0] // 16, self.img_shape[1] // 16, self.num_classes)

    def build_generator(self):
        """Build the generator: (H/16, W/16, num_classes) noise -> image.

        Each stage upsamples the previous feature map by 2 and concatenates it
        with a transposed convolution applied directly to the input noise, so
        every resolution also sees the raw conditioning signal.
        """
        inputs = keras.Input(shape=self._noise_shape())

        # H/16 -> H/8
        up0 = layers.UpSampling2D((2, 2))(inputs)
        x = layers.Conv2D(256, (3, 3), padding='same', activation='relu')(up0)
        x = layers.Conv2D(256, (3, 3), padding='same', activation='relu')(x)
        x = layers.BatchNormalization(momentum=0.8)(x)

        # H/8 -> H/4; stride 4 maps the H/16 input straight to H/4.
        up1 = layers.UpSampling2D((2, 2))(x)
        skip1 = layers.Conv2DTranspose(256, (8, 8), strides=4, padding='same', activation='relu')(inputs)
        x = layers.concatenate([up1, skip1], axis=-1)
        x = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x)
        x = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x)
        x = layers.BatchNormalization(momentum=0.8)(x)

        # H/4 -> H/2; stride 8 maps the H/16 input straight to H/2.
        up2 = layers.UpSampling2D((2, 2))(x)
        skip2 = layers.Conv2DTranspose(128, (16, 16), strides=8, padding='same', activation='relu')(inputs)
        x = layers.concatenate([up2, skip2], axis=-1)
        x = layers.Conv2D(64, (3, 3), padding='same', activation='relu')(x)
        x = layers.Conv2D(64, (3, 3), padding='same', activation='relu')(x)
        x = layers.BatchNormalization(momentum=0.8)(x)

        # H/2 -> H; stride 16 maps the H/16 input straight to H.
        up3 = layers.UpSampling2D((2, 2))(x)
        skip3 = layers.Conv2DTranspose(64, (32, 32), strides=16, padding='same', activation='relu')(inputs)
        x = layers.concatenate([up3, skip3], axis=-1)
        x = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(x)
        x = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(x)
        x = layers.BatchNormalization(momentum=0.8)(x)

        # 1x1 conv to the image channel count; sigmoid keeps pixels in [0, 1]
        # to match the /255-normalised training images.
        output = layers.Conv2D(self.img_shape[2], (1, 1), padding='same', activation='sigmoid')(x)

        model = keras.Model(inputs=inputs, outputs=output, name='generator')
        model.summary()
        return model

    def build_discriminator(self):
        """Build the discriminator: image -> softmax over num_classes + 1 classes.

        The last output index (``num_classes``) is the "fake" class.
        """
        inputs = keras.Input(shape=self.img_shape)

        x = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(inputs)
        x = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(x)
        x = layers.MaxPooling2D(2, 2)(x)

        x = layers.Conv2D(64, (3, 3), padding='same', activation='relu')(x)
        x = layers.Conv2D(64, (3, 3), padding='same', activation='relu')(x)
        x = layers.MaxPooling2D(2, 2)(x)

        x = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x)
        x = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x)
        x = layers.MaxPooling2D(2, 2)(x)

        x = layers.Conv2D(256, (3, 3), padding='same', activation='relu')(x)
        x = layers.Conv2D(256, (3, 3), padding='same', activation='relu')(x)

        x = layers.Flatten()(x)
        output = layers.Dense(self.num_classes + 1, activation='softmax')(x)

        model = keras.Model(inputs=inputs, outputs=output, name='discriminator')
        model.summary()
        return model

    def bulid_model(self):
        """Chain generator -> discriminator into the combined model used to train G.

        (The misspelt name is kept for backward compatibility with callers.)
        """
        inputs = keras.Input(shape=self._noise_shape())
        img = self.generator_model(inputs)
        outputs = self.discriminator_model(img)
        return keras.Model(inputs=inputs, outputs=outputs)

    def load_data(self):
        """Load CIFAR10 scaled to [0, 1].

        Returns:
            (train_images, train_labels, train_idx, test_images, test_labels),
            where train_idx is ``arange(len(train_images))`` used for shuffling.
        """
        (train_images, train_labels), (test_images, test_labels) = keras.datasets.cifar10.load_data()
        train_images, test_images = train_images / 255, test_images / 255

        # Single-channel configs need an explicit channel axis.
        if self.img_shape[2] == 1:
            train_images = np.expand_dims(train_images, axis=-1)
            test_images = np.expand_dims(test_images, axis=-1)

        train_idx = np.arange(0, train_images.shape[0], 1)
        print('img train number:', train_images.shape)
        print('img test number: ', test_images.shape)
        return train_images, train_labels, train_idx, test_images, test_labels

    def noise_generate(self, batch_size, labels):
        """Return label-conditioned Gaussian noise for the generator.

        Args:
            batch_size: number of noise samples.
            labels: integer class ids accepted by ``keras.utils.to_categorical``
                that yield ``batch_size`` one-hot rows (e.g. shape (batch_size,)
                or (batch_size, 1), or a scalar when batch_size == 1).

        Returns:
            Array of shape (batch_size, H/16, W/16, num_classes).  Only the
            channel matching each sample's label carries noise; every other
            channel is zero.
        """
        noise = np.random.normal(0, 1, (batch_size,) + self._noise_shape())
        one_hot = keras.utils.to_categorical(labels, num_classes=self.num_classes)
        one_hot = one_hot.reshape(batch_size, 1, 1, self.num_classes)
        # Broadcasting the one-hot mask over the spatial dims zeroes every
        # channel except the one belonging to the sample's label.
        return noise * one_hot

    def compile(self):
        """Compile the discriminator, then freeze it inside the combined model.

        The discriminator is compiled before ``trainable = False`` is set, so
        its own ``train_on_batch`` still updates it, while the combined model
        (compiled afterwards) only updates the generator.
        """
        # NOTE(review): beta_1 = 1e-5 gives Adam almost no first-moment momentum.
        # The original passed it positionally (Adam(0.0001, 0.00001)); it may
        # have been intended as a decay rate. Confirm before changing.
        self.discriminator_model.compile(loss='categorical_crossentropy',
                                         optimizer=keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.00001),
                                         metrics=['categorical_accuracy'])
        self.discriminator_model.trainable = False
        if self.train_mode == 1:
            # Rebuild the combined model so it picks up the frozen discriminator.
            self.model = self.bulid_model()
            print('continue train')
        self.model.summary()
        self.model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.00001),
                           loss='categorical_crossentropy')

    def train(self):
        """Run the adversarial training loop for ``self.epoch`` epochs.

        Each epoch does the following:
        1. Shuffle the training indices.
        2. Print the discriminator's validation accuracy.
        3. Save a sample grid.
        4. For every full batch, train D on [generated (labelled fake); real
           (labelled with their class)], then train G through the combined
           model with the real labels as targets.
        The combined model is saved at the end of every epoch.
        """
        self.compile()
        train_images, train_labels, train_idx, test_images, test_labels = self.load_data()
        # Generated images are labelled with the extra "fake" class, whose
        # index is num_classes (10 for CIFAR10).
        fake = np.full(self.batch_size, self.num_classes)
        fake = keras.utils.to_categorical(fake, num_classes=self.num_classes + 1)
        step = train_images.shape[0] // self.batch_size  # full batches only
        print('step:', step)

        for epoch in range(self.epoch):
            train_idx = (tf.random.shuffle(train_idx)).numpy()  # reshuffle every epoch
            print('val_acc', self.pred(mode=1, test_images=test_images, test_labels=test_labels))
            if epoch % self.sample_interval == 0:
                self.generate_sample_images(epoch)

            for i in range(step):
                idx = train_idx[i * self.batch_size:(i + 1) * self.batch_size]
                imgs = train_images[idx]
                labels = train_labels[idx]

                # Generate images conditioned on the real batch's labels.
                noise = self.noise_generate(self.batch_size, labels)
                gan_imgs = self.generator_model.predict(noise)

                labels = keras.utils.to_categorical(labels, num_classes=self.num_classes + 1)
                total_imgs = tf.concat((gan_imgs, imgs), axis=0)
                total_labels = tf.concat((fake, labels), axis=0)

                # Discriminator step: fakes -> "fake" class, reals -> their class.
                discriminator_loss = self.discriminator_model.train_on_batch(total_imgs, total_labels)

                # Generator step: push generated images toward the conditioning class.
                generator_loss = self.model.train_on_batch(noise, labels)

                print("epoch:%d step:%d [discriminator_loss: %f, acc: %.2f%%] [generator_loss: %f]" % (
                    epoch, i, discriminator_loss[0], 100 * discriminator_loss[1], generator_loss))

            self.model.save(self.save_path)  # checkpoint every epoch
            print('save model')

    def generate_sample_images(self, epoch):
        """Save a 2x2 grid of generated images with random labels to ``<img_path>/<epoch>.png``."""
        row, col = 2, 2
        labels = np.random.randint(0, self.num_classes, (row * col))

        noise = self.noise_generate(row * col, labels)
        gan_imgs = self.generator_model.predict(noise)
        fig, axs = plt.subplots(row, col)
        idx = 0

        for i in range(row):
            for j in range(col):
                axs[i, j].imshow(gan_imgs[idx, :, :])
                axs[i, j].axis('off')
                idx += 1
        fig.savefig(os.path.join(self.img_path, "%d.png" % epoch))
        plt.close(fig)  # close this figure explicitly rather than the "current" one

    def pred(self, mode, test_images=None, test_labels=None):
        """Use the trained networks.

        Args:
            mode: 0 reloads ``self.save_path`` and shows one generated image per
                class; 1 returns the discriminator's classification accuracy on
                (test_images, test_labels).  Other values do nothing.
            test_images: images for mode 1.
            test_labels: integer labels for mode 1, shape (N,) or (N, 1).

        Returns:
            For mode 1, the fraction of samples whose argmax over the real
            classes (the "fake" column is ignored) equals the label.
            Otherwise None.

        Raises:
            ValueError: in mode 1 when the test data is missing or empty.
        """
        if mode == 0:
            model = keras.models.load_model(self.save_path)
            print('loading model')
            generator = keras.Model(inputs=model.layers[1].input, outputs=model.layers[1].output)
            discriminator = keras.Model(inputs=model.layers[2].input, outputs=model.layers[2].output)
            generator.summary()
            discriminator.summary()
            for label in range(self.num_classes):  # one sample per class
                noise = self.noise_generate(1, label)
                img = np.squeeze(generator.predict([noise]))
                plt.imshow(img)
                plt.show()
        elif mode == 1:
            if test_images is None or test_labels is None:
                raise ValueError('mode 1 requires test_images and test_labels')
            print('testing')
            true_labels = np.asarray(test_labels).reshape(-1)
            total = true_labels.shape[0]
            if total == 0:
                raise ValueError('test_labels is empty')
            correct = 0
            # Iterate in chunks of test_size, including a trailing partial chunk.
            for start in range(0, total, self.test_size):
                stop = start + self.test_size
                probs = np.array(self.discriminator_model(test_images[start:stop]))
                # Ignore the "fake" column so only real classes compete.
                probs[:, self.num_classes] = 0
                predicted = np.argmax(probs, axis=-1)
                correct += int(np.sum(predicted == true_labels[start:stop]))
            return correct / total
        else:
            pass

if __name__ == '__main__':
    # Use a lowercase instance name so the GAN class itself is not shadowed.
    gan = GAN()
    gan.train()
    # Mode 0: reload the saved model and show one generated image per class.
    gan.pred(mode=0)

效果

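虽然生成器失败了，但判别器的效果做得很好：在 CIFAR10 验证集上的分类精度达到了 96%，虽然没有超过最高的 97%，但我认为已经超过很多常规算法了。生成的图片实在不好意思拿出来。

（勘误：上面的 96% 是按 1 − mean(|one-hot 预测 − one-hot 标签|) 在 11 维 one-hot 向量上计算得到的。每个错分样本只被扣 2/11，因此这个数值并不是真正的分类准确率。按此换算，真实准确率约为 1 − (1 − 0.96) × 11 / 2 ≈ 78%。）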

相关标签: GAN