欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

DCGAN从入门到放弃—生成炮姐头像

程序员文章站 2022-04-16 16:38:59
训练数据集:待审核数据集包括286张头像(像素150*150 300*300),训练时根据实际情况取用。...

问题描述

训练数据集:https://download.csdn.net/download/Megurine_Luka_/13197335

数据集包括286(资源那里打错字了)张头像(像素150*150 300*300),图像有重复,训练时根据实际情况取用。

本项目希望通过训练DCGAN,使其能够生成御坂美琴头像。我真的不是炮厨。

GAN用于生成图像,关于其训练过程,许多博客讲得都十分清楚,这里不再赘述。

关于其正确性的理论证明,这个视频讲得比较清楚:https://www.bilibili.com/video/BV1eE411g7xc?from=search&seid=6958078989660239757

看完证明,似乎觉得GAN无可挑剔了,样本图像的Dx值(即被判别器判定为“真品”的概率,范围0到1,其值越高,说明判别器认为其越接近真品,下同),与生成器生成图像的Dx值,最终必定会收敛于0.5,然后训练结束,生成器就可以产生想要的图像了。

然而这一切才是噩梦的开始(手动滑稽)。

先放一下代码,代码采用Pytorch。

代码

项目结构

__pycache__是运行时生成的,不用管。

DCGAN从入门到放弃—生成炮姐头像

model.py

模型是DCGAN,即判别器采用卷积神经网络,而生成器则采用反卷积。

关于反卷积的理解,详见https://blog.csdn.net/lanadeus/article/details/82534425。这里插个题外话,如果说,卷积运算可以看作乘上一个卷积矩阵(而且是个稀疏矩阵)的话,反卷积就是乘上卷积矩阵的逆矩阵(这也是个稀疏矩阵),反卷积层与相应的(或者说,输入输出同规模的?)卷积层所需训练的参数一样多。

判别器是输入一个144*144像素的图像,有3个RGB通道,输出一个Dx值,生成器是输入一个1*nz的噪声,输出一个144*144的3通道RGB图像。判别器与生成器的结构一一对应。

#coding:utf8
from torch import nn

class NetG(nn.Module):
    '''
    生成器定义
    '''
    def __init__(self, opt):
        super(NetG, self).__init__()
        ngf = opt.ngf  # 生成器feature map数
        
        self.features = nn.Sequential(
            # 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map
            nn.ConvTranspose2d(opt.nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # 上一步的输出形状:(ngf*8) x 4 x 4
            
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 上一步的输出形状: (ngf*4) x 8 x 8
            
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # 上一步的输出形状: (ngf*2) x 16 x 16
            
            nn.ConvTranspose2d(ngf * 2, ngf, 5, 3, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # 上一步的输出形状:(ngf) x 48 x 48
        )
        self.classifier = nn.Sequential(
            nn.ConvTranspose2d(ngf, 3, 5, 3, 1, bias=False),
            nn.Tanh()  # 输出范围 -1~1 故而采用Tanh
            # 输出形状:3 x 144 x 144
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

class NetD(nn.Module):
    '''
    判别器定义
    '''
    def __init__(self, opt):
        super(NetD, self).__init__()
        ndf = opt.ndf
        
        self.features = nn.Sequential(
           # 输入 3 x 144 x 144
            nn.Conv2d(3, ndf, 5, 3, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出 (ndf) * 48 * 48
            
            nn.Conv2d(ndf, ndf * 2, 5, 3, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出 (ndf * 2) x 16 x 16
            
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出 (ndf*4) x 8 x 8
            
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出 (ndf*4) x 4 x 4
            
        )
        self.classifier = nn.Sequential(
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()  # 输出一个数(概率)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

main.py

#coding:utf8
import torch as t
import torchvision as tv
from model import NetG, NetD
from torch.autograd import Variable   
from config import Config
import numpy as np
from PIL import Image
from torchvision.transforms import ToTensor,ToPILImage
import matplotlib.pyplot as plt

#设置参数
opt = Config()
#定义网络
netg, netd = NetG(opt), NetD(opt)

to_tensor=ToTensor()
to_pil=ToPILImage()

#图像还原(消去标准化的影响),没有此步,绘制出的图像的颜色不太正常
def trans(img):
    img=img/2+0.5
    return img

def shapeout(output):
    sf=np.zeros(opt.batch_size)
    for i in range(opt.batch_size):
        sf[i]=output[i][0][0]
    return sf

#数据预处理
def datasets():
    transforms = tv.transforms.Compose([
                    #将图像修正为144*144像素的
                    tv.transforms.Scale(opt.image_size),
                    tv.transforms.CenterCrop(opt.image_size),
                    #将各个像素值限定在[0,1]的范围内
                    tv.transforms.ToTensor(),
                    #中心标准化,经此处理的像素值将限定在[-1,1]内
                    tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                        ])
    dataset = tv.datasets.ImageFolder(opt.data_path,transform=transforms)
    dataloader = t.utils.data.DataLoader(dataset,
                                         batch_size = opt.batch_size,
                                         shuffle = True,
                                         num_workers= opt.num_workers,
                                         drop_last=True
                                         )
    return dataloader

def train(dataloader):
    
    #定义优化器
    optimizer_g = t.optim.Adam(netg.parameters(),opt.lr1,betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(),opt.lr2,betas=(opt.beta1, 0.999))  
    #optimizer_g = t.optim.RMSprop(params=netg.parameters(),lr=opt.lr1,weight_decay=1e-9)
    #optimizer_d = t.optim.RMSprop(params=netg.parameters(),lr=opt.lr1,weight_decay=1e-9)
    
    #损失函数
    criterion = t.nn.BCELoss()

    # 真图片label为1,假图片label为0
    true_labels = Variable(t.ones(opt.batch_size))
    fake_labels = Variable(t.zeros(opt.batch_size))
    # noises为生成网络的输入
    noises = Variable(t.randn(opt.batch_size,opt.nz,1,1))
             
    for epoch in range(opt.max_epoch):
        
        print(epoch)
        for i,(img,_) in enumerate(dataloader):
            
            #真实图片
            real_img = Variable(img)
            
            #print(i)
            #训练判别器
            if epoch%1==0:
                #将真图片判别为正确
                output1 = netd(real_img)
                error_d_real = criterion(output1,true_labels)
                #生成噪声
                noises.data.copy_(t.randn(opt.batch_size,opt.nz,1,1))
                #将噪声投喂到生成器,生成图像
                fake_img = netg(noises)
                #将生成器生成的图像判别为错误
                output2 = netd(fake_img.detach())
                error_d_fake = criterion(output2,fake_labels)
            
                optimizer_d.zero_grad()     #梯度清零
                d_loss=error_d_fake+error_d_real
                d_loss.backward()
    
                optimizer_d.step()
                
                #输出dx值
                print('真图')
                print(shapeout(output1))
                print('假图')
                print(shapeout(output2))
            
            #训练生成器
            if epoch%1==0:
                
                #利用噪声生成图像
                noises.data.copy_(t.randn(opt.batch_size,opt.nz,1,1))
                fake_img = netg(noises)
                #投喂到判别器
                output = netd(fake_img)
                error_g = criterion(output,true_labels)
                #梯度清零
                optimizer_g.zero_grad()
                error_g.backward()
                optimizer_g.step()

            
def generate():
    '''
    随机生成动漫头像,并根据netd的分数选择较好的
    '''
    
    #验证模式    
    netg.eval()
    netd.eval()
    
    noises = t.randn(opt.gen_num,opt.nz,1,1).normal_(opt.gen_mean,opt.gen_std)
    noises = Variable(noises, volatile=True)
        
    # 生成图片
    fake_img = netg(noises)
    c=1
    for i in range(opt.gen_num):
        img=fake_img.data[i].data
        img=trans(img)
        img=to_pil(img)
        #绘制图像
        plt.figure(c)
        plt.imshow(img)
        c=c+1

if __name__ == '__main__':
    #加载数据
    dataloader=datasets()    
    #加载模型
    #netd.load_state_dict(t.load(opt.model_path1)) 
    #netg.load_state_dict(t.load(opt.model_path2))
    #训练模型
    train(dataloader)
    #保存模型
    t.save(netd.state_dict(),opt.model_path3)
    t.save(netg.state_dict(),opt.model_path4)
    #测试
    generate()

config.py

这里说一下,学习率并不是一成不变的,需要根据实际情况调整。如果感觉判别器过强,则应削减其学习率。甚至可以在训练若干epoch后,保存模型,然后改变学习率,加载模型继续训练。而ngf与ndf参数的意义,了解不多,说不清楚。我看别人都写的64,我为了加快训练速度改成了32,好像也没什么影响。

class Config(object):
    #data_path = 'D:/TheMoth/keduoli/'  # 数据集存放路径
    data_path = 'D:/TheMoth/misaka/'
    
    num_workers = 1  # 多进程加载数据所用的进程数
    image_size = 144  # 图片尺寸
    batch_size = 32   #每批图片数
    max_epoch = 2500 
    lr1 = 2e-4  # 生成器的学习率
    lr2 = 5e-5 # 判别器的学习率
    beta1=0.5  # Adam优化器的beta1参数
    nz=40  # 噪声维度
    ngf = 32  # 生成器feature map数
    ndf = 32  # 判别器feature map数
    
    #模型加载、保存路径
    model_path1='D:/TheMoth/model/a1.pth'
    model_path2='D:/TheMoth/model/a2.pth'
    model_path3='D:/TheMoth/model/a1.pth'
    model_path4='D:/TheMoth/model/a2.pth'

    d_every = 1  # 每1个batch训练一次判别器
    g_every = 1  # 每1个batch训练一次生成器

    gen_num = 64 # 测试生成64张 
    gen_mean = 0  # 噪声的均值
    gen_std = 1  #噪声的方差

结果讨论

GAN的理论证明只能说,图像生成器在理论上是有可能实现的,(我感觉)它并没有还原GAN的训练过程。

GAN的成功训练,是建立在生成器、判别器能够相互对抗、齐头并进的基础之上的,在每次训练中,判别器甄别真品(样本图像)与赝品(生成器产生的图像)的能力都会上升一点,这样就导致了生成器再次生成的图像Dx值降低,生成器就会根据这个结果,调整参数,使得自己的Dx值试图高一些,接下来,判别器面临新的赝品,则需提高自己鉴别真品与新赝品的能力,如此往复。最终,所有的Dx值都会收敛于0.5,判别器无法分辨真品与赝品。被GAN♂挺了。

事实上,判别器跟生成器并不会乖乖地像预想的那样变化,它们随着训练过程的进行,总会出现一方过强而一方过弱的情况。而这也会导致训练无法进行下去。

比如说,如果判别器过强,那么它就会轻而易举地区分真品与赝品,样本的Dx值接近于1,而生成图像的Dx值则始终接近于0,这样,无论生成器怎样调整,判别器都不为所动(因为判别结果接近目标结果,误差小,通过反向传播导致的权值波动也会很小)。而缺乏指导的生成器,看着自己每次低低的Dx值,也会无所适从(这种情况下,由于结果几乎全错,误差大,反向传播导致的权值波动也超大,训练逐渐盲目),整个训练过程陷入停滞。甚至会产生模型的蜕化。

如果判别器过弱,那么,无论生成器生成多么离谱的图片,判别器都会打出较高的分数,这样训练也会陷入停滞,原因同上。

就是说,GAN训练的基础,是生成器、判别器达到某种“稳态”,知乎上讲是“纳什均衡”zhihu.com/question/304164323/answer/803644490,这种脆弱的均衡一旦被打破,模型将会受到难以恢复的损害。但生成器和判别器将会怎样发展下去,你知道么?神经网络本就是有点类似黑盒的存在(判别器可能看着还“顺眼”一点,生成器就很离谱)。实际上维持“纳什均衡”也是现在GAN所面临的重大难题。

而我遇到的问题就是,判别器总是过强,导致训练不得不终止,下面是生成图像中,效果最好的两张。扯这么多屁话就是给自己的失败找理由。

插曲:室友看了我生成器的生成结果,惊呼:我去这该不会是炮姐吧。我想,既然人肉眼能看分辨出这是个什么鬼东西来,这一星期的努力也没白费,心里也释然了。

DCGAN从入门到放弃—生成炮姐头像

DCGAN从入门到放弃—生成炮姐头像

不难发现,这两张图像其实最像数据集中出现最多的两张(数据集中有重复,我把比较清晰、有代表性的图片多复制了几次),数据集体量越小,生成的图像理论上就越专一。但这并不可行,因为如果数据集中全是一样的图像,那么判别器就会很轻易地训练完成,训练戛然而止(亲测)。

通过参考别人的项目,我发现他们的数据集体量超大,而且一般是加强判别器,比如说每训练五次判别器,才训练一次生成器。而非像我一样,通过削减学习率,抑制判别器。我觉得数据集越大,判别器就越难以训练,而生成器与判别器的失衡也许就是数据集太小所致。而这也是今后的改进方向吧,若有进展将更新。(截图工作是真的烦)心累,快被GAN吐血了。

如有错误请指正。以上观点只是我个人的胡诌八扯。

本文地址:https://blog.csdn.net/Megurine_Luka_/article/details/110246661