DSGAN退化网络
程序员文章站
2022-03-01 14:58:26
...
非成对的退化
1.基本结构
-
generator
1个conv + 8个resblock + 1个conv
Generator(
(block_input): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): PReLU(num_parameters=1)
)
(res_blocks): ModuleList(
(0): ResidualBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(prelu): PReLU(num_parameters=1)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(1): ResidualBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(prelu): PReLU(num_parameters=1)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(2): ResidualBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(prelu): PReLU(num_parameters=1)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(3): ResidualBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(prelu): PReLU(num_parameters=1)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(4): ResidualBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(prelu): PReLU(num_parameters=1)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(5): ResidualBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(prelu): PReLU(num_parameters=1)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(6): ResidualBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(prelu): PReLU(num_parameters=1)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(7): ResidualBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(prelu): PReLU(num_parameters=1)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
(block_output): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
-
discriminator
DiscriminatorBasic(
(net): Sequential(
(0): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(1): LeakyReLU(negative_slope=0.2)
(2): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2)
(5): Conv2d(128, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2)
)
(gan_net): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))
)
2. 损失函数
损失函数的示意图图下:
2.1 生成器
(1)感知损失(perceptual_loss)
- 输入:Xb(经过HR下采样得到的初步的LR)与Xd(经过GAN生成的fake LR img)
- 目的:使退化后的图像Xd与初始的Xb保持风格上的一致性。
- 实现:采用VGG来计算
(2)颜色损失(color loss)
- 输入:Xb(经过HR下采样得到的初步的LR)与Xd(经过GAN生成的fake LR img)
- 目的:使退化后的图像Xd与初始的Xb在颜色上保持有一定的相似度,不能因此退化改变了颜色。
- 实现:首先采用低通滤波,之后再求L1 loss,公式如下。作者认为,低通滤波保存着图像的颜色信息,在代码的实现中采用average pooling,k=5,stirde = 1。作者还说低通的实现可以采用多种方式,不拘泥于average pooling
(3)GAN损失
- 输入:真实的LR(z)和Xd(经过GAN生成的fake LR img)
- 目的:论文中认为生成Xb首先经过了下采样,降采样过程消除了高图像频率,并将低频信息保持在减少的像素数内。这导致了高频特征的丢失,而低频信息,如颜色和背景仍然存在。因此采用高通滤波得到Xd和z的高频信息。对高通滤波的后的图像进行判别。
- 实现:高频图像就是原图减去低频图像。总的公式如下:
总的生成器损失:
2.2 判别器
- 输入:真实的LR(z)和Xd(经过GAN生成的fake LR img)
- 实现:高频图像就是原图减去低频图像。标准交叉熵损失。
3.训练数据及参数设置
3.1 证件照超分任务
(1)数据
- LR:低清的证件照,align到128×128
- HR:高清的证件照,align到256×256
将整个图像作为输入,训练dsgan
(2)参数设置
- 学习率:0.0002
- 总iter:8w
4. 评价指标
详情见文章:FID评价指标
Frechet Inception 距离得分(Frechet Inception Distance score,FID)是计算真实图像和生成图像的特征向量之间距离的一种度量。
目前人脸的fid为9.8左右
上一篇: 2022年解决Coursera的课程视频无法观看问题
下一篇: bootstrap下拉多选框实现