欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

【pytorch基础笔记五】基于条件GAN的色彩填充

程序员文章站 2022-05-20 19:34:41
...

【参考文献】
【1】《深入浅出GAN生成对抗网络》7.3 ColorGAN实现

训练过程如下图所示。我个人对CGAN的理解主要有两点:
1、生成器网络模型实际上是一个编解码器;
2、条件因素相当于对某一类特征进行了强调。
PS:从实验结果看效果并不理想,可能是训练数据比较少的原因,也可能是我对模型的理解还有偏差。
【pytorch基础笔记五】基于条件GAN的色彩填充

【pytorch基础笔记五】基于条件GAN的色彩填充


from random import randint
import numpy as np 
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image
import shutil
import cv2
import random
from PIL import Image


# Side length (in pixels) used by the Resize/CenterCrop preprocessing and by to_img().
image_size = 128
# Number of samples requested from the data helpers for each training step.
batch_size = 1


"""
生成器模型 
e1: 第一次卷积输出,输入为 边缘图 + 噪声图 
e2: 第二次卷积输出,输入为 e1
e3: 第三次卷积输出,输入为 e2
e4: 第四次卷积输出,输入为 e3
e5: 第五次卷积输出,输入为 e4 

d4: 第一次反卷积,输入为 e5
d4 = d4 + e4
d5: 第二次反卷积,输入为 d4
d5 = d5 + e3 即考虑e3的特征
d6: 第三次反卷积,输入为 d5
d6 = d6 + e2
d7: 第四次反卷积,输入为 d6
d7 = d7 + e1
d8: 最后一次反卷积
tanh(d8)
"""

base_feature = 64  # channel width of the first conv stage; deeper stages use multiples of it


class Generator(nn.Module):
    """Encoder-decoder generator with additive skip connections.

    Five stride-2 convolutions (``e1``..``e5``) halve the spatial size at
    every stage; five stride-2 transposed convolutions (``d4``..``d8``)
    double it again.  The outputs of ``d4``..``d7`` are summed element-wise
    with the encoder outputs ``e4``..``e1`` before being fed to the next
    decoder stage.  The last stage ends in ``tanh``, so the output lies in
    [-1, 1] and has the same shape as the 3-channel input.

    The input height/width must be divisible by 32, otherwise the encoder
    and decoder feature maps cannot be added.
    """

    @staticmethod
    def _down(in_channels, out_channels):
        """Build one encoder stage: 4x4 stride-2 conv + LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

    @staticmethod
    def _up(in_channels, out_channels):
        """Build one decoder stage: 4x4 stride-2 transposed conv + BN + ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def __init__(self):
        super().__init__()
        width = base_feature

        # Encoder: 3 -> w -> 2w -> 4w -> 8w -> 8w channels.
        self.e1_layour = self._down(3, width)
        self.e2_layour = self._down(width, width * 2)
        self.e3_layour = self._down(width * 2, width * 4)
        self.e4_layour = self._down(width * 4, width * 8)
        self.e5_layour = self._down(width * 8, width * 8)

        # Decoder mirrors the encoder so each output matches a skip tensor.
        self.d4_layour = self._up(width * 8, width * 8)
        self.d5_layour = self._up(width * 8, width * 4)
        self.d6_layour = self._up(width * 4, width * 2)
        self.d7_layour = self._up(width * 2, width)
        # Final stage maps back to 3 channels and squashes with tanh (no BN).
        self.d8_layour = nn.Sequential(
            nn.ConvTranspose2d(width, 3,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run the encoder, then the decoder with skip additions.

        Args:
            x: input tensor with 3 channels.

        Returns:
            Tensor of the same shape as ``x`` with values in [-1, 1].
        """
        encoder_stages = (self.e1_layour, self.e2_layour,
                          self.e3_layour, self.e4_layour)
        decoder_stages = (self.d4_layour, self.d5_layour,
                          self.d6_layour, self.d7_layour)

        # Keep e1..e4 so they can be added back in reverse order.
        skips = []
        hidden = x
        for stage in encoder_stages:
            hidden = stage(hidden)
            skips.append(hidden)
        hidden = self.e5_layour(hidden)

        for stage in decoder_stages:
            hidden = stage(hidden) + skips.pop()

        return self.d8_layour(hidden)


"""
判别器模型 为单纯的卷积网络
torch.nn.Conv2d(in_channels, out_channels, 
kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
这里备注一下 out_channels
输出通道数代表输出的特征数量,某种意义上对应卷积核的数量,即一个卷积核对应一类特征

这里注意: 
h_out = (h_in + 2 * padding - dilation*(kernel_size - 1) - 1)/ stride  + 1
输出最后两个维度必须是 1 1

(128 + 2 - 3 - 1)/2 + 1 = 64
(64   - 2)/2 + 1 = 32
(32 - 2)/2 + 1   = 16
(16 - 2)/2 + 1   = 8
(8 - 2)/2 + 1    = 4
(4 - 4)/1 + 1    = 1

"""
class Discriminator(nn.Module):
    """Plain convolutional discriminator ending in a sigmoid score.

    Five 4x4 stride-2 convolutions halve the spatial size each time; a
    final 4x4 stride-1 convolution without padding reduces a 4x4 map to a
    single value.  For a 128x128 input the output is therefore one
    probability per sample with trailing dimensions 1x1.
    """

    def __init__(self):
        super().__init__()
        widths = [base_feature, base_feature * 2, base_feature * 4,
                  base_feature * 8, base_feature * 8]

        # First stage has no batch norm.
        layers = [
            nn.Conv2d(3, widths[0], kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]
        # Remaining down-sampling stages: conv -> batch norm -> LeakyReLU.
        for in_channels, out_channels in zip(widths, widths[1:]):
            layers += [
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
            ]
        # Collapse the remaining 4x4 map to one score and squash to (0, 1).
        layers += [
            nn.Conv2d(widths[-1], 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        ]
        self.dis = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid score map produced by the conv stack for ``x``."""
        return self.dis(x)

    
"""
实例化网络
"""
# Optimisation hyper-parameters for the two networks.
d_learning_rate = 3e-4  # 3e-4
g_learning_rate = 3e-4
optim_betas     = (0.9, 0.999)
criterion       = nn.BCELoss()  # binary cross-entropy on the discriminator's sigmoid output
G = Generator()
D = Discriminator()

# Move both networks to the GPU when one is present.
if torch.cuda.is_available():
    print("use cuda")
    D = D.cuda()
    G = G.cuda()


# Each optimizer uses its own learning rate (the generator previously reused
# the discriminator's) and the shared Adam betas declared above.
g_optimizer = optim.Adam(G.parameters(), lr=g_learning_rate, betas=optim_betas)
d_optimizer = optim.Adam(D.parameters(), lr=d_learning_rate, betas=optim_betas)

    
"""
预处理
1. 从\CelebA\数据集中选取200张人像图
2. 取180张作为训练集,取20张作为测试集
3. 对训练集的人像进行边缘检测生成边缘图
4. 对训练集的人像进行模糊处理生成噪声图(训练的噪声图应该每轮更新)
5. 对测试集的人像进行边缘检测生成边缘图
6. 对测试集的人像进行模糊处理生成噪声图
"""
# Source folder holding the selected CelebA images, and the root folder that
# receives the generated train/test sub-folders.  Raw strings keep the Windows
# backslashes literal instead of relying on unrecognised escape sequences.
ori_file_path = r'E:\dataset\CelebA\Img\label1'
tar_file_path = r'E:\dataset\colorgan'

def _copy_ori_file(filetype, filename):
    """Copy one source image unchanged into ``<tar_file_path>\\<filetype>\\ori``.

    Args:
        filetype: name of the dataset split sub-folder (the callers pass
            ``'train'`` or ``'test'``).
        filename: file name of the image inside ``ori_file_path``.

    The file name is used as-is; it is no longer split on ``'.'``, which
    raised ``ValueError`` for names without exactly one dot.
    """
    old_file = ori_file_path + "\\" + filename
    new_file = tar_file_path + "\\" + filetype + "\\ori\\" + filename
    shutil.copyfile(old_file, new_file)

def _copy_edge_file(filetype, filename):
    """Write an edge map of one source image to ``<tar_file_path>\\<filetype>\\edge``.

    The image is converted to grayscale and binarised with a mean adaptive
    threshold (block size 7, constant 5); the result is saved under the
    same file name.

    Args:
        filetype: name of the dataset split sub-folder (the callers pass
            ``'train'`` or ``'test'``).
        filename: file name of the image inside ``ori_file_path``.

    Raises:
        FileNotFoundError: if OpenCV cannot read the source image.
    """
    old_file = ori_file_path + "\\" + filename
    img = cv2.imread(old_file)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError("cannot read image: " + old_file)
    img_gray  = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    base_edge = cv2.adaptiveThreshold(img_gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C,
                                      cv2.THRESH_BINARY, blockSize=7, C=5)
    new_file = tar_file_path + "\\" + filetype + "\\edge\\" + filename
    cv2.imwrite(new_file, base_edge)

def _copy_blur_file(filetype, filename):
    """Write a blurred "noise" version of one image to ``<tar_file_path>\\<filetype>\\blur``.

    Five 50x50 white patches are painted at random positions, then the
    whole image is smoothed with a 100x100 box blur and saved under the
    same file name.

    Args:
        filetype: name of the dataset split sub-folder.
        filename: file name of the image inside ``ori_file_path``.

    Raises:
        FileNotFoundError: if OpenCV cannot read the source image.
    """
    old_file = ori_file_path + "\\" + filename
    img = cv2.imread(old_file)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError("cannot read image: " + old_file)
    # Reverses the channel order of every pixel (BGR -> RGB).
    # NOTE(review): cv2.imwrite below treats the array as BGR again, so the
    # saved file has red and blue swapped -- confirm this is intended.
    img = np.fliplr(img.reshape(-1, 3)).reshape(img.shape)
    for _ in range(5):
        # Offsets beyond the image bounds are clipped by slicing, so a patch
        # may be partially or completely outside the picture.
        randx = randint(0, 205)
        randy = randint(0, 205)
        img[randx:randx+50, randy:randy+50] = 255
    blur = cv2.blur(img, (100, 100))
    new_file = tar_file_path + "\\" + filetype + "\\blur\\" + filename
    cv2.imwrite(new_file, blur)


def _do_preprocess(train_count=2, test_count=1):
    """Populate the train/test folders from the files in ``ori_file_path``.

    The first ``train_count`` entries returned by ``os.listdir`` go to the
    ``train`` split and the next ``test_count`` entries go to the ``test``
    split.  For every selected file both the untouched original and its
    edge map are written.  Blurred noise images are not stored because they
    are regenerated on every training step.

    Previously the test branch was ``elif index < 1`` after ``if index < 2``
    and therefore could never run, leaving the test split empty.

    Args:
        train_count: number of files copied into the ``train`` split.
        test_count: number of files copied into the ``test`` split.
    """
    ori_files = os.listdir(ori_file_path)

    for index, filename in enumerate(ori_files[:train_count + test_count]):
        filetype = 'train' if index < train_count else 'test'
        _copy_ori_file(filetype, filename)
        _copy_edge_file(filetype, filename)

#_do_preprocess()
"""
数据加载 
""" 
# Shared preprocessing: resize so the shorter side equals image_size,
# centre-crop to a square, convert to a tensor and map each channel
# from [0, 1] to [-1, 1] (mean 0.5, std 0.5).
transform = transforms.Compose([
    transforms.Resize(size=image_size),
    transforms.CenterCrop(size=image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])

def _get_fake_imgs(batch_size = 64):
    """Build a batch of generator inputs (edge map blended with a noise image).

    Steps for each sampled file:
      1. pick a random edge image from the ``train\\edge`` folder;
      2. load the matching original from ``train\\ori``, paint five random
         50x50 white patches on it and box-blur it into a noise image;
      3. blend edge (weight 0.8) and noise (weight 0.2);
      4. apply the shared ``transform`` (resize, centre crop, to tensor,
         normalise).

    The white patches are now painted inside the loop; before, only the
    last of the five random positions was used.

    Args:
        batch_size: number of distinct files to sample.

    Returns:
        Tensor of the stacked transformed images, one per sampled file.

    Raises:
        ValueError: from ``random.sample`` if ``batch_size`` exceeds the
            number of edge files.
        FileNotFoundError: if an edge image or its original cannot be read.
    """
    result = []

    edge_dir = tar_file_path + "\\train\\edge\\"
    ori_dir  = tar_file_path + "\\train\\ori\\"
    edge_list = os.listdir(edge_dir)
    # Random, non-repeating selection of file indices.
    for index in random.sample(range(len(edge_list)), batch_size):
        filename = edge_list[index]
        edge_img = cv2.imread(edge_dir + filename)
        ori_img  = cv2.imread(ori_dir + filename)
        if edge_img is None or ori_img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError("cannot read training image pair: " + filename)
        # OpenCV loads BGR; PIL (used below) expects RGB.
        ori_img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
        for _ in range(5):
            randx = randint(0, 205)
            randy = randint(0, 205)
            ori_img[randx:randx+50, randy:randy+50] = 255
        blur_img = cv2.blur(ori_img, (100, 100))

        img_combine = cv2.addWeighted(edge_img, 0.8, blur_img, 0.2, 0)
        pil_img_combine = Image.fromarray(img_combine)  # transform expects a PIL image
        res_img = transform(pil_img_combine)
        result.append(torch.unsqueeze(res_img, 0))

    return torch.cat(result, dim=0)
        
    
def _get_real_imgs(batch_size = 64):
    """Load a random batch of original training images as normalised tensors.

    Images are read with OpenCV (BGR) and converted to RGB before being
    handed to PIL, matching what ``_get_train_image`` does; without the
    conversion red and blue were swapped.

    Args:
        batch_size: number of distinct files to sample from ``train\\ori``.

    Returns:
        Tensor of the stacked transformed images, one per sampled file.

    Raises:
        ValueError: from ``random.sample`` if ``batch_size`` exceeds the
            number of files.
        FileNotFoundError: if a sampled image cannot be read.
    """
    result = []
    ori_dir = tar_file_path + "\\train\\ori\\"
    ori_list = os.listdir(ori_dir)
    for index in random.sample(range(len(ori_list)), batch_size):
        ori_file = ori_dir + ori_list[index]
        ori_img  = cv2.imread(ori_file)
        if ori_img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError("cannot read image: " + ori_file)
        ori_img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
        res_img = transform(Image.fromarray(ori_img))
        result.append(torch.unsqueeze(res_img, 0))
    return torch.cat(result, dim=0)

def _get_train_image(batch_size = 64):
    """Sample aligned original / combined / edge tensors for one training step.

    Derived from ``_get_fake_imgs`` and ``_get_real_imgs`` so that, within
    one step, every generator input corresponds to the real image it was
    built from.

    For each sampled file the original is loaded once and converted to RGB.
    A copy receives five random 10x10 white patches and a 100x100 box blur,
    and is blended 50/50 with the edge map to form the generator input.
    The white patches are now painted inside the loop; before, only the
    last of the five random positions was used.

    Args:
        batch_size: number of distinct files to sample.

    Returns:
        Tuple ``(originals, combined, edges)`` of tensors, each stacking
        one transformed image per sampled file in the same order.

    Raises:
        ValueError: from ``random.sample`` if ``batch_size`` exceeds the
            number of edge files.
        FileNotFoundError: if an edge image or its original cannot be read.
    """
    fake_result = []
    ori_result  = []
    edge_result = []

    edge_dir = tar_file_path + "\\train\\edge\\"
    ori_dir  = tar_file_path + "\\train\\ori\\"
    edge_list = os.listdir(edge_dir)
    # Random, non-repeating selection of file indices.
    for index in random.sample(range(len(edge_list)), batch_size):
        filename = edge_list[index]
        edge_img = cv2.imread(edge_dir + filename)
        ori_bgr  = cv2.imread(ori_dir + filename)
        if edge_img is None or ori_bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError("cannot read training image pair: " + filename)
        # OpenCV loads BGR; PIL (used below) expects RGB.
        ori_rgb = cv2.cvtColor(ori_bgr, cv2.COLOR_BGR2RGB)

        # Work on a copy so the clean original can still be returned.
        noisy = ori_rgb.copy()
        for _ in range(5):
            randx = randint(0, 205)
            randy = randint(0, 205)
            noisy[randx:randx+10, randy:randy+10] = 255
        blur_img = cv2.blur(noisy, (100, 100))

        img_combine = cv2.addWeighted(edge_img, 0.5, blur_img, 0.5, 0)
        fake_result.append(torch.unsqueeze(transform(Image.fromarray(img_combine)), 0))
        ori_result.append(torch.unsqueeze(transform(Image.fromarray(ori_rgb)), 0))
        edge_result.append(torch.unsqueeze(transform(Image.fromarray(edge_img)), 0))

    return torch.cat(ori_result, dim=0), torch.cat(fake_result, dim=0), torch.cat(edge_result, dim=0)
    

    
def to_img(x):
    """Map a tensor from [-1, 1] to [0, 1] and shape it as an image batch.

    Values are shifted and scaled, clamped to [0, 1], then viewed as
    ``(-1, 3, image_size, image_size)``.
    """
    rescaled = ((x + 1) * 0.5).clamp(0, 1)
    return rescaled.view(-1, 3, image_size, image_size)
"""
CGAN训练逻辑:
【1】《深入浅出GAN生成对抗网络》7.3 ColorGAN实现

1. 获取线条图像 edge_image;
2. 获取模糊图像 blur_image
3. 生成合并图像 combine_image  = edge_image + blur_image
4. 生成图像为   generate_image = G(combine_image)
5. real_to_d = real_image + combine_image
6. fake_to_d = generate_image + combine_image
7. 对判别器优化:real_to_d 判真 fake_to_d 判假
8. fake_to_g = generate_image_1 + combine_image
9. 对生成器优化:fake_to_g 判真

备注:
20200723调整:
20200724调整:cv2 未转rgb

"""

def _show_test_process_data(image):
    """Debug helper: print and display the first image of a tensor batch.

    Args:
        image: batch tensor indexed as ``image[0]`` with channels first;
            assumed to be normalised to [-1, 1] like the output of
            ``transform`` -- TODO confirm for other callers.

    The tensor is detached and moved to the CPU first so the helper also
    works for CUDA tensors and tensors that require grad.  The raw values
    are printed once, then rescaled to [0, 1] for display, because
    ``imshow`` clips float RGB data outside that range.
    """
    data = image[0].detach().cpu().numpy()
    data = data.transpose((1, 2, 0))  # channels-first -> channels-last for imshow
    print(data)
    plt.imshow(((data + 1) * 0.5).clip(0, 1))
    

num_epochs = 5000  # number of training iterations (one random batch each)

# Train on whichever device the networks were placed on, so the loop also
# runs on CPU-only machines instead of calling .cuda() unconditionally.
device = next(G.parameters()).device
# save_image does not create missing folders.
os.makedirs('./img/cgan', exist_ok=True)

for epoch in range(num_epochs):

    # Step 1: train the discriminator.
    real_image, combine_image, edge_image = _get_train_image(batch_size)
    real_image    = real_image.to(device)
    combine_image = combine_image.to(device)

    # Float fill values: BCELoss needs floating-point targets.
    real_label = torch.full((batch_size,), 1.0, device=device)
    fake_label = torch.full((batch_size,), 0.0, device=device)

    # Detached so the discriminator update does not back-propagate into G.
    generate_image = G(combine_image).detach()

    # The condition is injected by averaging it with the real / generated image.
    real_to_d = torch.add(combine_image, real_image) / 2
    fake_to_d = torch.add(combine_image, generate_image) / 2

    # Flatten D's output so its shape matches the (batch_size,) labels.
    d_real_decision = D(real_to_d).view(-1)
    d_fake_decision = D(fake_to_d).view(-1)

    d_real_loss = criterion(d_real_decision, real_label)
    d_fake_loss = criterion(d_fake_decision, fake_label)
    d_loss = d_real_loss + d_fake_loss
    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    # Step 2: train the generator to make D classify its output as real.
    generate_image_1 = G(combine_image)
    fake_to_g        = torch.add(generate_image_1, combine_image) / 2
    g_real_decision  = D(fake_to_g).view(-1)
    g_fake_loss      = criterion(g_real_decision, real_label)
    g_optimizer.zero_grad()
    g_fake_loss.backward()
    g_optimizer.step()

    if epoch % 500 == 0:
        print("Epoch[{}],g_fake_loss:{:.6f} ,d_loss:{:.6f}"
              .format(epoch, g_fake_loss.item(), d_loss.item()))
        output = to_img(generate_image_1.detach())
        save_image(output, './img/cgan/test_' + str(epoch) + '.png')

        # Optional qualitative check on a fixed image; skipped when the
        # file is missing (cv2.imread returns None in that case).
        test_img = cv2.imread('./img/cgan/0000001.jpg')
        if test_img is not None:
            test_img   = cv2.cvtColor(test_img, cv2.COLOR_BGR2RGB)
            test_input = transform(Image.fromarray(test_img))
            data = torch.unsqueeze(test_input, 0).to(device)
            with torch.no_grad():  # inference only, no graph needed
                generate_res = G(data)
            save_image(to_img(generate_res), './img/cgan/result.jpg')