
Generative Adversarial Network (GAN) Source Code in PyTorch

This post walks through a complete PyTorch script that trains a GAN on MNIST: the generator is a convolutional autoencoder, and the discriminator is a small CNN binary classifier. Each batch runs three updates in turn: an autoencoder reconstruction step, a discriminator step, and an adversarial generator step.

import torch
import torch.nn as nn
import torch.utils.data as Data
from torchvision.datasets import MNIST
from torchvision import transforms as tfs
import matplotlib.pyplot as plt
import numpy as np

# Load the data
train_data = MNIST(root='./mnist/', train=True, transform=tfs.ToTensor(), download=True)  # 60,000 training images
print(train_data.data.size())     # torch.Size([60000, 28, 28])
print(train_data.targets.size())  # torch.Size([60000])
plt.imshow(train_data.data[0].numpy())  # show the first image (grayscale data, rendered in color by matplotlib's default colormap)
#plt.show()

train_loader = Data.DataLoader(dataset=train_data, batch_size=100, shuffle=True)  # batch the data and shuffle it
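One caveat worth flagging before the models: the generator below ends in Tanh, whose output lies in [-1, 1], while ToTensor scales pixels to [0, 1], so the reconstruction target and the generator's output range do not quite match. A minimal sketch of one common fix, rescaling the inputs to [-1, 1] (the 0.5/0.5 constants are my assumption, not part of the original script):

# Sketch (assumption, not in the original): map MNIST pixels from
# [0, 1] to [-1, 1] so they share the range of the generator's Tanh output.
transform = tfs.Compose([
    tfs.ToTensor(),                 # scales pixels to [0, 1]
    tfs.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5  ->  [-1, 1]
])
train_data = MNIST(root='./mnist/', train=True, transform=transform, download=True)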

# Define the generator (a convolutional autoencoder)
class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=3, padding=1),  # (b, 16, 10, 10)
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),  # (b, 16, 5, 5)
            nn.Conv2d(16, 8, kernel_size=3, stride=2, padding=1),  # (b, 8, 3, 3)
            nn.ReLU(),
            nn.MaxPool2d(2, stride=1)  # (b, 8, 2, 2)
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 16, kernel_size=3, stride=2),  # (b, 16, 5, 5)
            nn.ReLU(),
            nn.ConvTranspose2d(16, 8, kernel_size=5, stride=3, padding=1),  # (b, 8, 15, 15)
            nn.ReLU(),
            nn.ConvTranspose2d(8, 1, kernel_size=2, stride=2, padding=1),  # (b, 1, 28, 28)
            nn.Tanh()
        )

    def forward(self, x):
        encode = self.encoder(x)
        decode = self.decoder(encode)
        return encode, decode
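The shape comments above can be verified with a quick sanity check (not part of the original script) by pushing a dummy batch through the network:

g = generator()
dummy = torch.randn(4, 1, 28, 28)  # a fake batch of 4 MNIST-sized images
encoded, decoded = g(dummy)
print(encoded.shape)   # torch.Size([4, 8, 2, 2])
print(decoded.shape)   # torch.Size([4, 1, 28, 28])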

# Define the discriminator
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 5, padding=2),  # (b, 32, 28, 28)
            nn.LeakyReLU(0.2),
            nn.AvgPool2d(2, stride=2),       # (b, 32, 14, 14)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, padding=2),  # (b, 64, 14, 14)
            nn.LeakyReLU(0.2),
            nn.AvgPool2d(2, stride=2)         # (b, 64, 7, 7)
        )

        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # flatten to (b, 64*7*7)
        x = self.fc(x)
        return x
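Again a small check, not in the original: the discriminator maps a batch of images to one probability per image, kept inside (0, 1) by the final Sigmoid:

d = discriminator()
scores = d(torch.randn(4, 1, 28, 28))
print(scores.shape)                              # torch.Size([4, 1])
print(scores.min().item(), scores.max().item())  # both within (0, 1)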

# Instantiate the networks, losses, and optimizers
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gen = generator().to(device)
dis = discriminator().to(device)
loss_func = nn.BCELoss()    # binary cross-entropy for the real/fake classification
a_loss_func = nn.MSELoss()  # mean squared error for the autoencoder reconstruction
d_optimizer = torch.optim.Adam(dis.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(gen.parameters(), lr=0.0003)
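BCELoss computes -[y*log(p) + (1-y)*log(1-p)] averaged over the batch. The discriminator minimizes it with labels 1 for real and 0 for fake images; the generator minimizes it with its fakes scored against the label 1. A tiny illustrative check of the formula (values are made up):

criterion = nn.BCELoss()
pred = torch.tensor([[0.9], [0.1]])    # predicted probabilities
target = torch.tensor([[1.0], [0.0]])  # ground-truth labels
print(criterion(pred, target).item())  # -(log 0.9 + log 0.9) / 2 ≈ 0.1054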

# Training
for epoch in range(15):
    D_loss = 0
    G_loss = 0
    A_loss = 0
    for step, (img, label) in enumerate(train_loader):  # enumerate yields the step index and the batch
        size = img.shape[0]  # batch size (100)
        real_img = img.to(device)

        # Train the autoencoder: reconstruct the real images
        encoded, decoded = gen(real_img)
        a_loss = a_loss_func(decoded, real_img)
        g_optimizer.zero_grad()
        a_loss.backward()
        g_optimizer.step()
        A_loss += a_loss.item()

        # Train the discriminator
        real_label = torch.ones(size, 1, device=device)    # real images are labelled 1
        false_label = torch.zeros(size, 1, device=device)  # fake images are labelled 0
        real_out = dis(real_img)
        d_loss_real = loss_func(real_out, real_label)  # loss on real images

        encoded, false_img = gen(real_img)   # produce fake images
        false_out = dis(false_img.detach())  # detach so this step does not backprop into the generator
        d_loss_false = loss_func(false_out, false_label)  # loss on fake images

        d_loss = d_loss_real + d_loss_false  # total loss over real and fake images
        d_optimizer.zero_grad()  # clear gradients
        d_loss.backward()        # backpropagate
        d_optimizer.step()       # update the discriminator parameters
        D_loss += d_loss.item()

        # Train the generator
        encoded, decoded = gen(real_img)  # produce fake images
        output = dis(decoded)             # feed the fake images to the discriminator
        g_loss = loss_func(output, real_label)  # generator loss: fake images scored against the real label
        g_optimizer.zero_grad()  # clear gradients
        g_loss.backward()        # backpropagate
        g_optimizer.step()       # update the generator parameters
        G_loss += g_loss.item()

        if step % 100 == 0:  # report every 100 steps
            print('Epoch: ', epoch, '| d_loss: %.4f' % d_loss.item(),
                  '| g_loss: %.4f' % g_loss.item(), '| a_loss: %.4f' % a_loss.item())
    print('epoch: {}, D_Loss: {:.6f}, G_Loss: {:.6f}, A_Loss: {:.6f}'
          .format(epoch, D_loss / len(train_loader), G_loss / len(train_loader), A_loss / len(train_loader)))
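The original script never persists the trained networks. A minimal sketch of saving them after training (the file names are my choice, not from the original):

# Sketch (not in the original): save the trained weights for later reuse.
torch.save(gen.state_dict(), 'gan_generator.pth')
torch.save(dis.state_dict(), 'gan_discriminator.pth')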

# Create the figure: 2 rows of 10 subplots (originals on top, reconstructions below)
f, a = plt.subplots(2, 10, figsize=(10, 2))
plt.ion()

# Top row: the first ten original images
view_data = train_data.data[:10].view(-1, 1, 28, 28).float() / 255.
view_data = view_data.to(device)

for i in range(10):
    a[0][i].imshow(np.reshape(view_data.detach().cpu().numpy()[i], (28, 28)))
    a[0][i].set_xticks(())
    a[0][i].set_yticks(())

# Bottom row: the generator's reconstructions of the same images
encoded_data, decoded_data = gen(view_data)
for i in range(10):
    a[1][i].clear()
    a[1][i].imshow(np.reshape(decoded_data.detach().cpu().numpy()[i], (28, 28)))
    a[1][i].set_xticks(())
    a[1][i].set_yticks(())
plt.draw()
plt.pause(0.05)  # pause for 0.05 s
plt.ioff()
plt.show()
