TensorFlow Keras: A Summary of CGAN Experiments on the CIFAR10 Dataset
Preface
My network performs well on the classic MNIST handwritten-digit dataset, but when I applied it to CIFAR10 the generated images came out severely blurred. After experimenting with several traditional GAN architectures, my conclusion is that traditional GANs are very weak at generating the high-frequency information that carries image detail. Here I summarize this early work; since there is little example code for this online, I wrote my own.
Experimental structure
The experiment uses a structure similar to a semantic-segmentation network. In short, the results were poor: the rough shape of the object is recognizable, but the details are far too blurry. The architecture below should be taken only as a cautionary example for training a generator. If you just want to generate MNIST handwritten digits, or only want to train the discriminator, it works well enough (see the sketch after the code listing for the MNIST adaptation).
Code
import numpy as np
from matplotlib import pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
class GAN():
    def __init__(self):  # global configuration
        self.img_shape = (32, 32, 3)
        self.save_path = r'C:\Users\Administrator\Desktop\photo\CGAN.h5'
        self.img_path = r'C:\Users\Administrator\Desktop\photo'
        self.batch_size = 100
        self.test_size = 200
        self.sample_interval = 1
        self.epoch = 200
        self.num_classes = 10
        self.train_mode = 0  # 0: train from scratch, 1: resume training (not recommended; resuming with this code tends to collapse)
        # build the GAN models
        if self.train_mode == 0:
            self.generator_model = self.build_generator()
            self.discriminator_model = self.build_discriminator()
            self.model = self.build_model()
        else:
            self.model = keras.models.load_model(self.save_path)
            self.generator_model = keras.Model(inputs=self.model.layers[1].input, outputs=self.model.layers[1].output)
            self.discriminator_model = keras.Model(inputs=self.model.layers[2].input, outputs=self.model.layers[2].output)
    def build_generator(self):  # generator
        input = keras.Input(shape=(int(self.img_shape[0] / 16), int(self.img_shape[1] / 16), self.num_classes))
        c00 = layers.UpSampling2D((2, 2))(input)
        x = layers.Conv2D(256, (3, 3), padding='same', activation='relu')(c00)
        x = layers.Conv2D(256, (3, 3), padding='same', activation='relu')(x)
        x = layers.BatchNormalization(momentum=0.8)(x)
        c10 = layers.UpSampling2D((2, 2))(x)
        c11 = layers.Conv2DTranspose(256, (8, 8), strides=4, padding='same', activation='relu')(input)
        x = layers.concatenate([c10, c11], axis=-1)
        x = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x)
        x = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x)
        x = layers.BatchNormalization(momentum=0.8)(x)
        c20 = layers.UpSampling2D((2, 2))(x)
        c21 = layers.Conv2DTranspose(128, (16, 16), strides=8, padding='same', activation='relu')(input)
        x = layers.concatenate([c20, c21], axis=-1)
        x = layers.Conv2D(64, (3, 3), padding='same', activation='relu')(x)
        x = layers.Conv2D(64, (3, 3), padding='same', activation='relu')(x)
        x = layers.BatchNormalization(momentum=0.8)(x)
        c30 = layers.UpSampling2D((2, 2))(x)
        c31 = layers.Conv2DTranspose(64, (32, 32), strides=16, padding='same', activation='relu')(input)
        x = layers.concatenate([c30, c31], axis=-1)
        x = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(x)
        x = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(x)
        x = layers.BatchNormalization(momentum=0.8)(x)
        output = layers.Conv2D(self.img_shape[2], (1, 1), padding='same', activation='sigmoid')(x)
        model = keras.Model(inputs=input, outputs=output, name='generator')
        model.summary()
        return model
    def build_discriminator(self):  # discriminator: 10 real classes plus 1 "fake" class
        input = keras.Input(shape=self.img_shape)
        x = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(input)
        x = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(x)
        x = layers.MaxPooling2D(2, 2)(x)
        x = layers.Conv2D(64, (3, 3), padding='same', activation='relu')(x)
        x = layers.Conv2D(64, (3, 3), padding='same', activation='relu')(x)
        x = layers.MaxPooling2D(2, 2)(x)
        x = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x)
        x = layers.Conv2D(128, (3, 3), padding='same', activation='relu')(x)
        x = layers.MaxPooling2D(2, 2)(x)
        x = layers.Conv2D(256, (3, 3), padding='same', activation='relu')(x)
        x = layers.Conv2D(256, (3, 3), padding='same', activation='relu')(x)
        x = layers.Flatten()(x)
        output = layers.Dense(self.num_classes + 1, activation='softmax')(x)
        model = keras.Model(inputs=input, outputs=output, name='discriminator')
        model.summary()
        return model
    def build_model(self):  # stack the generator and the discriminator into the combined GAN model
        inputs = keras.Input(shape=(int(self.img_shape[0] / 16), int(self.img_shape[1] / 16), self.num_classes))
        img = self.generator_model(inputs)
        outputs = self.discriminator_model(img)
        model = keras.Model(inputs=inputs, outputs=outputs)
        return model
    def load_data(self):
        (train_images, train_labels), (test_images, test_labels) = keras.datasets.cifar10.load_data()
        # scale to [0, 1] and cast to float32 so real and generated batches share a dtype
        train_images, test_images = (train_images / 255).astype('float32'), (test_images / 255).astype('float32')
        if (self.img_shape[2] == 1):  # only needed for single-channel datasets such as MNIST
            train_images = np.expand_dims(train_images, axis=-1)
            test_images = np.expand_dims(test_images, axis=-1)
        train_idx = np.arange(0, train_images.shape[0], 1)
        print('img train number:', train_images.shape)
        print('img test number: ', test_images.shape)
        return train_images, train_labels, train_idx, test_images, test_labels
    def noise_generate(self, batch_size, labels):  # label-conditioned noise generator
        noise = np.random.normal(0, 1, (batch_size, int(self.img_shape[0] / 16), int(self.img_shape[1] / 16), self.num_classes))  # standard Gaussian noise
        n_labels = keras.utils.to_categorical(labels, num_classes=self.num_classes)
        n_labels = n_labels.reshape(batch_size, 1, 1, self.num_classes)
        noise = noise * n_labels  # after the multiplication, only the channel of the given label keeps its noise
        return noise
    def compile(self):  # compile the models
        self.discriminator_model.compile(loss='categorical_crossentropy',
                                         optimizer=keras.optimizers.Adam(0.0001, 0.00001),
                                         metrics=['categorical_accuracy'])
        self.discriminator_model.trainable = False  # freeze the discriminator inside the combined model
        if self.train_mode == 1:
            self.model = self.build_model()
            print('continue train')
        self.model.summary()
        self.model.compile(optimizer=keras.optimizers.Adam(0.0001, 0.00001), loss='categorical_crossentropy')
    def train(self):
        self.compile()
        train_images, train_labels, train_idx, test_images, test_labels = self.load_data()  # load the data
        fake = np.ones((self.batch_size)) * (self.num_classes)  # label for fake images: the extra class with index num_classes (10)
        fake = keras.utils.to_categorical(fake, num_classes=self.num_classes + 1)
        step = int(train_images.shape[0] / self.batch_size)  # steps per epoch
        print('step:', step)
        for epoch in range(self.epoch):
            train_idx = (tf.random.shuffle(train_idx)).numpy()  # reshuffle once per epoch
            print('val_acc', self.pred(mode=1, test_images=test_images, test_labels=test_labels))
            if epoch % self.sample_interval == 0:  # save sample images
                self.generate_sample_images(epoch)
            for i in range(step):
                idx = train_idx[i * self.batch_size:i * self.batch_size + self.batch_size]  # indices for this batch
                imgs = train_images[idx]  # real images for those indices
                labels = train_labels[idx]
                # --------------------------------------------------------------- label-conditioned Gaussian noise
                noise = self.noise_generate(self.batch_size, labels)
                gan_imgs = self.generator_model.predict(noise)  # generate images from the noise
                # ---------------------------------------------------------------
                labels = keras.utils.to_categorical(labels, num_classes=self.num_classes + 1)  # one-hot labels for the real images
                total_imgs = tf.concat((gan_imgs, imgs), axis=0)
                total_labels = tf.concat((fake, labels), axis=0)
                # ---------------------------------------------- train the discriminator
                discriminator_loss = self.discriminator_model.train_on_batch(total_imgs, total_labels)
                # ----------------------------------------------- train the generator
                generator_loss = self.model.train_on_batch(noise, labels)
                print("epoch:%d step:%d [discriminator_loss: %f, acc: %.2f%%] [generator_loss: %f]" % (
                    epoch, i, discriminator_loss[0], 100 * discriminator_loss[1], generator_loss))
            # print('val_acc', self.pred(mode=1, test_images=test_images, test_labels=test_labels))
            self.model.save(self.save_path)  # save the model every epoch
            print('save model')
    def generate_sample_images(self, epoch):
        row, col = 2, 2  # grid size of the sample sheet
        labels = np.random.randint(0, self.num_classes, (row * col))
        noise = self.noise_generate(row * col, labels)  # conditioned noise
        gan_imgs = self.generator_model.predict(noise)
        fig, axs = plt.subplots(row, col)  # create the figure
        idx = 0
        for i in range(row):
            for j in range(col):
                axs[i, j].imshow(gan_imgs[idx, :, :])  # cmap='gray' for single-channel output
                axs[i, j].axis('off')
                idx += 1
        fig.savefig(self.img_path + "/%d.png" % epoch)
        plt.close()  # close the figure
    def pred(self, mode, test_images=None, test_labels=None):  # how to use the trained network
        if (mode == 0):  # test: reload the saved model and sample one image per class
            model = keras.models.load_model(self.save_path)
            print('loading model')
            generator = keras.Model(inputs=model.layers[1].input, outputs=model.layers[1].output)
            discriminator = keras.Model(inputs=model.layers[2].input, outputs=model.layers[2].output)
            generator.summary()
            discriminator.summary()
            for i in range(10):  # test the generator
                label = i
                noise = self.noise_generate(1, label)
                img = np.squeeze(generator.predict([noise]))
                plt.imshow(img)
                plt.show()
        elif (mode == 1):  # validation: zero the "fake" output and check the discriminator accuracy
            print('testing')
            step = int(test_images.shape[0] / self.test_size)
            val_acc = 0
            for i in range(step):
                pred = self.discriminator_model(test_images[i * self.test_size:(i + 1) * self.test_size])
                pred = pred.numpy()
                pred[:, (self.num_classes)] = 0
                pred = tf.argmax(pred, axis=-1)
                pred = keras.utils.to_categorical(pred, num_classes=self.num_classes + 1)
                labels = keras.utils.to_categorical(test_labels[i * self.test_size:(i + 1) * self.test_size], num_classes=self.num_classes + 1)
                acc = 1 - tf.reduce_mean(tf.abs(pred - labels))
                val_acc += acc
            val_acc = val_acc / step
            return val_acc.numpy()
        else:
            pass
if __name__ == '__main__':
    gan = GAN()
    gan.train()
    gan.pred(mode=0)  # prediction, mode 0
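For the MNIST use mentioned in the structure section, the main changes are to set self.img_shape = (32, 32, 1) in __init__ and to pad the 28x28 digits to 32x32 so the generator's up-sampling path (which starts from H/16 x W/16) still matches the image size. The load_data replacement below is only a minimal sketch of that idea under those assumptions, not part of the original code:
    # Illustrative sketch only: a load_data variant for MNIST.
    # Assumes self.img_shape has been changed to (32, 32, 1) before the models are built.
    def load_data(self):
        (train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
        # pad the 28x28 digits to 32x32 so the four up-sampling stages line up with the image size
        train_images = np.pad(train_images, ((0, 0), (2, 2), (2, 2)), mode='constant')
        test_images = np.pad(test_images, ((0, 0), (2, 2), (2, 2)), mode='constant')
        # scale to [0, 1] and add the channel axis the convolution layers expect
        train_images = (train_images / 255).astype('float32')[..., np.newaxis]
        test_images = (test_images / 255).astype('float32')[..., np.newaxis]
        train_idx = np.arange(0, train_images.shape[0], 1)
        return train_images, train_labels, train_idx, test_images, test_labels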
Results
Although the generator failed, the discriminator turned out very well: its classification accuracy on the CIFAR10 validation set reaches 96%. That does not surpass the top figure of 97%, but I think it already beats many conventional approaches. The generated images, frankly, are too poor to show.
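If you only want to reuse the trained discriminator as a plain CIFAR10 classifier, it can be pulled out of the saved combined model and scored with ordinary top-1 accuracy over the ten real classes. The snippet below is a minimal sketch under that assumption (it presumes a model already trained and saved by the code above; variable names are illustrative):
# Illustrative sketch only: evaluate the saved discriminator as a CIFAR10 classifier.
import numpy as np
from tensorflow import keras

model = keras.models.load_model(r'C:\Users\Administrator\Desktop\photo\CGAN.h5')
discriminator = model.layers[2]  # the nested 'discriminator' sub-model
(_, _), (test_images, test_labels) = keras.datasets.cifar10.load_data()
test_images = (test_images / 255).astype('float32')
probs = discriminator.predict(test_images, batch_size=200)  # shape (10000, 11)
preds = np.argmax(probs[:, :10], axis=-1)                   # drop the extra "fake" class
print('top-1 accuracy:', np.mean(preds == test_labels.squeeze()))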