欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

DSGAN退化网络

程序员文章站 2022-03-01 14:58:26
...

非成对的退化

1.基本结构

  • generator

1个conv + 8个resblock + 1个conv

Generator(
  (block_input): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): PReLU(num_parameters=1)
  )
  (res_blocks): ModuleList(
    (0): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (2): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (3): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (4): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (5): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (6): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (7): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (block_output): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
  • discriminator 

DiscriminatorBasic(
  (net): Sequential(
    (0): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2)
    (5): Conv2d(128, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2)
  )
  (gan_net): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))
)

2. 损失函数

损失函数的示意图图下:

DSGAN退化网络

2.1 生成器
(1)感知损失(perceptual_loss)

  • 输入:Xb(经过HR下采样得到的初步的LR)与Xd(经过GAN生成的fake LR img)
  • 目的:使退化后的图像Xd与初始的Xb保持风格上的一致性。
  • 实现:采用VGG来计算

 (2)颜色损失(color loss)

  • 输入:Xb(经过HR下采样得到的初步的LR)与Xd(经过GAN生成的fake LR img)
  • 目的:使退化后的图像Xd与初始的Xb在颜色上保持有一定的相似度,不能因此退化改变了颜色。
  • 实现:首先采用低通滤波,之后再求L1 loss,公式如下。作者认为,低通滤波保存着图像的颜色信息,在代码的实现中采用average pooling,k=5,stirde = 1。作者还说低通的实现可以采用多种方式,不拘泥于average pooling

DSGAN退化网络

(3)GAN损失

  • 输入:真实的LR(z)和Xd(经过GAN生成的fake LR img)
  • 目的:论文中认为生成Xb首先经过了下采样,降采样过程消除了高图像频率,并将低频信息保持在减少的像素数内。这导致了高频特征的丢失,而低频信息,如颜色和背景仍然存在。因此采用高通滤波得到Xd和z的高频信息。对高通滤波的后的图像进行判别。
  • 实现:高频图像就是原图减去低频图像。总的公式如下:

DSGAN退化网络

总的生成器损失:

DSGAN退化网络 

2.2 判别器

  • 输入:真实的LR(z)和Xd(经过GAN生成的fake LR img)
  • 实现:高频图像就是原图减去低频图像。标准交叉熵损失。

DSGAN退化网络

3.训练数据及参数设置

3.1 证件照超分任务

(1)数据

  •   LR:低清的证件照,align到128×128
  •   HR:高清的证件照,align到256×256

将整个图像作为输入,训练dsgan

(2)参数设置

  •   学习率:0.0002
  •   总iter:8w

4. 评价指标

详情见文章:FID评价指标

Frechet Inception 距离得分(Frechet Inception Distance score,FID)是计算真实图像和生成图像的特征向量之间距离的一种度量。

目前人脸的fid为9.8左右