
Using an Adversarial Loss to Improve the Visual Quality of Images


Discriminator

To use an adversarial loss, the first step is naturally to build a discriminator. Here is the most conventional and simplest one:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Discriminator(nn.Module):
    def __init__(self, input_nc=3, ngf=32):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, normalize=True):
            """Returns the downsampling layers of one discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(input_nc, ngf, normalize=False),
            *discriminator_block(ngf, ngf*2),
            *discriminator_block(ngf*2, ngf*4),
            # *discriminator_block(ngf*4, ngf*8),  # uncomment if a deeper discriminator is needed
            # nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(ngf*4, 1, 4, padding=1)
        )

    def forward(self, img):
        x = self.model(img)
        # Average the patch-wise scores into a single realness score per image
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)
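
A quick sanity check of the shape contract (the 3-channel 256×256 input and the batch size of 4 below are arbitrary assumptions, not part of the original snippet):

disc = Discriminator(input_nc=3, ngf=32)
dummy = torch.randn(4, 3, 256, 256)   # (batch, channels, H, W)
out = disc(dummy)
print(out.shape)                      # torch.Size([4, 1]): one realness score per image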

ReplayBuffer

The replay buffer keeps a history of up to 50 previously generated images; with 50% probability it hands back an older sample in place of a fresh one, so the discriminator is trained against a mix of past and current generator outputs rather than only the latest batch (the same image-buffer trick used in CycleGAN-style training):

import random

import torch
from torch.autograd import Variable


class ReplayBuffer():
    def __init__(self, max_size=50):
        assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Push a batch of generated images and return a batch drawn from the history."""
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                # Buffer not full yet: store the new image and return it directly
                self.data.append(element)
                to_return.append(element)
            else:
                if random.uniform(0, 1) > 0.5:
                    # With 50% probability, return an older image and replace it with the new one
                    i = random.randint(0, self.max_size - 1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    # Otherwise return the newly generated image as-is
                    to_return.append(element)
        return Variable(torch.cat(to_return))
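
A minimal usage sketch (the tensor shapes are placeholders chosen for illustration):

buffer = ReplayBuffer(max_size=50)

# Pretend these are generator outputs from two consecutive iterations
fake_1 = torch.randn(4, 1, 64, 64)
fake_2 = torch.randn(4, 1, 64, 64)

mixed_1 = buffer.push_and_pop(fake_1)   # buffer not yet full: returns fake_1 unchanged
mixed_2 = buffer.push_and_pop(fake_2)   # once 50 images are stored, older samples may be swapped in
print(mixed_1.shape, mixed_2.shape)     # both torch.Size([4, 1, 64, 64])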

Training procedure

import torch
import torch.optim as optim
from torch.autograd import Variable

# Target labels for the LSGAN-style MSE criterion: 1 for real, 0 for fake
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
target_real = Variable(Tensor(cfg.TRAIN.batch_size, 1).fill_(1.0), requires_grad=False)
target_fake = Variable(Tensor(cfg.TRAIN.batch_size, 1).fill_(0.0), requires_grad=False)
fake_buffer = ReplayBuffer()

net_D = Discriminator(input_nc=1, ngf=32).to(torch.device('cuda:0'))

# `model` is the image-enhancement generator defined elsewhere;
# the generator and the discriminator each get their own optimizer
optimizer_G = optim.Adam(model.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_D = optim.Adam(net_D.parameters(), lr=0.0001, betas=(0.5, 0.999))
criterion_GAN = torch.nn.MSELoss()

for iters in range(1, total_iterations):
    input, target = train_provider.next()

    # ---- Update the generator ----
    optimizer_G.zero_grad()
    pred = model(input)

    # Generator loss: the discriminator should classify the prediction as real
    D_pred = net_D(pred)
    loss_G = criterion_GAN(D_pred, target_real)
    loss_G.backward()
    optimizer_G.step()

    # ---- Update the discriminator ----
    optimizer_D.zero_grad()
    # Real loss
    pred_real = net_D(target)
    loss_D_real = criterion_GAN(pred_real, target_real)
    # Fake loss, drawn from the replay buffer and detached from the generator graph
    pred = fake_buffer.push_and_pop(pred)
    pred_fake = net_D(pred.detach())
    loss_D_fake = criterion_GAN(pred_fake, target_fake)
    # Total discriminator loss
    loss_ad = (loss_D_real + loss_D_fake) * 0.5
    loss_ad.backward()
    optimizer_D.step()
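
In practice the adversarial term is rarely used on its own when the goal is visual quality. A common variant, sketched below under the assumption of an L1 content loss and a hand-tuned weight lambda_adv (neither appears in the original snippet), blends the two so the output stays faithful to the target while the GAN term sharpens texture:

# Sketch only: combined generator objective, reusing the names from the loop above
criterion_pixel = torch.nn.L1Loss()
lambda_adv = 0.01        # assumed weight; tune so the GAN term does not overpower fidelity

optimizer_G.zero_grad()
pred = model(input)
loss_pixel = criterion_pixel(pred, target)            # reconstruction / content term
loss_adv = criterion_GAN(net_D(pred), target_real)    # adversarial term, as above
loss_G = loss_pixel + lambda_adv * loss_adv
loss_G.backward()
optimizer_G.step()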