
CycleGAN Image Style Transfer

程序员文章站 (Programmer Article Station), 2022-04-09 20:25:16
...

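Introduction

CycleGAN (Cycle-Consistent Generative Adversarial Networks) is a deep learning network for unpaired image-to-image translation: it learns to map images from one domain to another without needing paired training examples.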

Results

[Figure: sample CycleGAN style-transfer results]

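Network structure

Generator $G: X \to Y$
Generator $F: Y \to X$
Discriminator $D_X$: judges whether an input image is a real image from domain X
Discriminator $D_Y$: judges whether an input image is a real image from domain Y

Forward cycle-consistency loss: $x \to G(x) \to F(G(x)) \approx x$
Backward cycle-consistency loss: $y \to F(y) \to G(F(y)) \approx y$
[Figure: CycleGAN network structure]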
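As a toy illustration of the two cycles, the sketch below uses hypothetical number-to-number functions as stand-ins for the generators (they are not neural networks). The cycle gap is zero when F exactly inverts G and positive otherwise, which is what the cycle-consistency loss penalizes.

```python
def forward_cycle_gap(x, G, F):
    """|F(G(x)) - x|: distance between x and its forward-cycle reconstruction."""
    return abs(F(G(x)) - x)


def backward_cycle_gap(y, G, F):
    """|G(F(y)) - y|: distance between y and its backward-cycle reconstruction."""
    return abs(G(F(y)) - y)


# Hypothetical toy mappings on numbers, standing in for X -> Y and Y -> X
def G(x):
    return 2.0 * x + 1.0


def F_good(y):
    return (y - 1.0) / 2.0    # exact inverse of G


def F_bad(y):
    return y / 2.0            # not an inverse of G


print(forward_cycle_gap(3.0, G, F_good))    # 0.0
print(forward_cycle_gap(3.0, G, F_bad))     # 0.5
print(backward_cycle_gap(-4.0, G, F_good))  # 0.0
print(backward_cycle_gap(-4.0, G, F_bad))   # 1.0
```

In CycleGAN the same idea is applied to images, with the gap measured as an L1 distance averaged over a batch.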
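Discriminator code

from keras.layers import Input, Conv2D, LeakyReLU
from keras.models import Model

def conv2d(layer_input, filters, f_size=4):
    # Downsampling block: 4x4 stride-2 Conv2D + LeakyReLU. The layer summary
    # below shows no normalization layers, so none is applied here.
    d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
    return LeakyReLU(alpha=0.2)(d)  # slope 0.2 is an assumption

def build_discriminator(self):  # method of the CycleGAN class
    img = Input(shape=self.img_shape)  # 128,128,3
    d1 = conv2d(img, 64)               # 64,64,64
    d2 = conv2d(d1, 128)               # 32,32,128
    d3 = conv2d(d2, 256)               # 16,16,256
    d4 = conv2d(d3, 512)               # 8,8,512
    # 8,8,1: one real/fake score per image patch
    validity = Conv2D(1, kernel_size=3, strides=1, padding='same')(d4)
    return Model(img, validity)

# In __init__: one discriminator per domain (d_A = D_X, d_B = D_Y)
self.d_A = self.build_discriminator()
self.d_B = self.build_discriminator()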
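With padding='same', a Keras Conv2D produces an output of spatial size ceil(input / stride), which is how the shape comments in the discriminator code follow from a 128×128 input. The sketch below traces those shapes; the input size is taken from the layer summary, and the function names are illustrative only.

```python
import math


def same_conv_out(size, stride):
    """Spatial size after a Conv2D with padding='same': ceil(size / stride)."""
    return math.ceil(size / stride)


def discriminator_shapes(img_size=128, filters=(64, 128, 256, 512)):
    shapes = []
    size = img_size
    for f in filters:                  # four stride-2 downsampling blocks
        size = same_conv_out(size, 2)
        shapes.append((size, size, f))
    size = same_conv_out(size, 1)      # final stride-1 conv to one channel
    shapes.append((size, size, 1))
    return shapes


print(discriminator_shapes())
# [(64, 64, 64), (32, 32, 128), (16, 16, 256), (8, 8, 512), (8, 8, 1)]
```

The final (8, 8, 1) shape is why the training targets `valid` and `fake` must have shape (batch_size, 8, 8, 1) rather than (batch_size, 1): each of the 64 outputs scores one patch of the input image.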

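Generator code

from keras.layers import Input

def build_generator(self):  # method of the CycleGAN class
    # ResNet-style generator; get_resnet is the author's helper (not shown),
    # whose layers are listed in the generator summary below
    return get_resnet(self.img_rows, self.img_cols, self.channels)

# In __init__: G = g_AB (A -> B) and F = g_BA (B -> A)
self.g_AB = self.build_generator()
self.g_BA = self.build_generator()

img_A = Input(shape=self.img_shape)  # real image from domain A (X)
img_B = Input(shape=self.img_shape)  # real image from domain B (Y)

# Translate each image to the other domain
fake_B = self.g_AB(img_A)
fake_A = self.g_BA(img_B)

# Translate back to the original domain (cycle consistency)
reconstr_A = self.g_BA(fake_B)
reconstr_B = self.g_AB(fake_A)

# Identity mapping: an image already in the target domain should stay unchanged
img_A_id = self.g_BA(img_A)
img_B_id = self.g_AB(img_B)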
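Unlike the discriminator code, which uses padding='same', the generator in the layer summary pads explicitly with ZeroPadding2D and then applies unpadded ('valid') convolutions. The sketch below traces the spatial size through that stack with (size + 2·pad − kernel) // stride + 1. Kernel sizes, strides, and paddings are read off the summary; the function names are hypothetical.

```python
def valid_conv_out(size, kernel, stride=1, pad=0):
    """Spatial size after ZeroPadding2D(pad) followed by a 'valid' Conv2D."""
    return (size + 2 * pad - kernel) // stride + 1


def generator_trace(size=128, n_res=9):
    trace = []
    size = valid_conv_out(size, 7, pad=3)          # 7x7 conv, 64 filters
    trace.append(size)
    for _ in range(2):                             # two 3x3 stride-2 downsampling convs
        size = valid_conv_out(size, 3, stride=2, pad=1)
        trace.append(size)
    for _ in range(n_res):                         # each residual block: two padded 3x3 convs
        size = valid_conv_out(valid_conv_out(size, 3, pad=1), 3, pad=1)
    trace.append(size)
    for _ in range(2):                             # UpSampling2D (x2), then a padded 3x3 conv
        size = valid_conv_out(size * 2, 3, pad=1)
        trace.append(size)
    size = valid_conv_out(size, 7, pad=3)          # final 7x7 conv to 3 channels
    trace.append(size)
    return trace


print(generator_trace())  # [128, 64, 32, 32, 64, 128, 128]
```

Because the residual blocks keep the 32×32 size, their outputs can be added to their inputs (the Add layers in the summary), and the output image has the same size as the input.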

Loss functions

1. Optimization objective
$$G^*, F^* = \arg\min_{G,F}\max_{D_X,D_Y}\mathcal{L}(G,F,D_X,D_Y)$$
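2. Total loss
$$\mathcal{L}(G,F,D_X,D_Y) = \mathcal{L}_{GAN}(G,D_Y,X,Y) + \mathcal{L}_{GAN}(F,D_X,Y,X) + \lambda\,\mathcal{L}_{cyc}(G,F)$$
where $\lambda$ controls the relative importance of the adversarial and cycle-consistency objectives.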
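Numerically the total loss is a weighted sum. A minimal sketch, assuming the three terms have already been computed; it uses λ = 10, the value reported in the CycleGAN paper, but treat λ as a tunable hyperparameter.

```python
def total_loss(l_gan_g, l_gan_f, l_cyc, lam=10.0):
    """L(G, F, D_X, D_Y) = L_GAN(G, D_Y, X, Y) + L_GAN(F, D_X, Y, X) + lam * L_cyc(G, F).

    The generators minimize this value; the discriminators maximize it, which only
    affects the two adversarial terms because L_cyc does not depend on D_X or D_Y.
    """
    return l_gan_g + l_gan_f + lam * l_cyc


# Made-up values for the three terms
print(total_loss(0.5, 0.25, 0.1))  # 1.75
```

A larger λ makes the generators favor faithful reconstruction over fooling the discriminators.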
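3. Adversarial loss $\mathcal{L}_{GAN}$
The adversarial loss is applied to both mappings: $G$ is trained against $D_Y$, and $F$ against $D_X$. It pushes the translated images toward the distribution of real images in the target domain. For $G$:
$$\mathcal{L}_{GAN}(G,D_Y,X,Y) = \mathbb{E}_{y\sim p_{data}(y)}[\log D_Y(y)] + \mathbb{E}_{x\sim p_{data}(x)}[\log(1 - D_Y(G(x)))]$$
$\mathcal{L}_{GAN}(F,D_X,Y,X)$ is defined the same way with the roles of X and Y swapped.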
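The sketch below evaluates $\mathcal{L}_{GAN}$ from discriminator outputs with NumPy. The CycleGAN paper replaces the negative log-likelihood with a least-squares loss in practice for more stable training, so both forms are shown. The `valid`/`fake` targets in the training code fit the least-squares form if the models are compiled with MSE, which is an assumption since the compile calls are not shown; the discriminator outputs here are made up.

```python
import numpy as np


def gan_loss_nll(d_real, d_fake, eps=1e-12):
    """Paper form: mean log D_Y(y) + mean log(1 - D_Y(G(x))).

    d_real: D_Y outputs on real images y; d_fake: D_Y outputs on G(x); values in (0, 1).
    D_Y tries to maximize this value, G tries to minimize it.
    """
    return np.mean(np.log(d_real + eps)) + np.mean(np.log(1.0 - d_fake + eps))


def lsgan_losses(d_real, d_fake):
    """Least-squares variant: D_Y is pushed toward 1 on real and 0 on generated
    images; G is pushed to make D_Y output 1 on its generated images."""
    d_loss = np.mean((d_real - 1.0) ** 2) + np.mean(d_fake ** 2)
    g_loss = np.mean((d_fake - 1.0) ** 2)
    return d_loss, g_loss


# Made-up discriminator outputs for a batch of two images
d_real = np.array([0.9, 0.8])
d_fake = np.array([0.1, 0.2])
print(gan_loss_nll(d_real, d_fake))
print(lsgan_losses(d_real, d_fake))
```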
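4. Cycle-consistency loss $\mathcal{L}_{cyc}$
Keeps the reconstructed image $F(G(x))$ close to the input image $x$ and, symmetrically, $G(F(y))$ close to $y$:
$$\mathcal{L}_{cyc}(G,F) = \mathbb{E}_{x\sim p_{data}(x)}[\lVert F(G(x)) - x\rVert_1] + \mathbb{E}_{y\sim p_{data}(y)}[\lVert G(F(y)) - y\rVert_1]$$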
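A NumPy sketch of $\mathcal{L}_{cyc}$ on a batch. It averages the absolute error per pixel (what Keras calls 'mae'), which differs from the summed L1 norm only by a constant factor. The argument names mirror `reconstr_A` / `reconstr_B` from the generator code, and the toy arrays are made up.

```python
import numpy as np


def cycle_consistency_loss(real_A, reconstr_A, real_B, reconstr_B):
    """Batch estimate of L_cyc: mean |F(G(x)) - x| + mean |G(F(y)) - y|."""
    forward = np.mean(np.abs(reconstr_A - real_A))   # x -> G(x) -> F(G(x))
    backward = np.mean(np.abs(reconstr_B - real_B))  # y -> F(y) -> G(F(y))
    return forward + backward


# Made-up batch: one 2x2 RGB image per domain
real_A = np.zeros((1, 2, 2, 3))
reconstr_A = real_A + 0.1        # forward cycle is off by 0.1 everywhere
real_B = np.ones((1, 2, 2, 3))
reconstr_B = real_B.copy()       # backward cycle reconstructs perfectly
print(cycle_consistency_loss(real_A, reconstr_A, real_B, reconstr_B))
```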

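Training procedure

The generators and the discriminators are trained alternately: in each iteration the two generators are updated through a combined model, and then the two discriminators are updated on real and translated images.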

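Generator training

# Freeze the discriminators so the combined model updates only the generators
self.d_A.trainable = False
self.d_B.trainable = False
valid_A = self.d_A(fake_A)  # D_X's score for the translated image F(y)
valid_B = self.d_B(fake_B)  # D_Y's score for the translated image G(x)
self.combined = Model(inputs=[img_A, img_B],
                      outputs=[valid_A, valid_B,        # adversarial terms
                               reconstr_A, reconstr_B,  # cycle-consistency terms
                               img_A_id, img_B_id])     # identity terms
# (compile self.combined with one loss per output before training)

# Targets: "real" for the adversarial outputs, the input images for the rest
g_loss = self.combined.train_on_batch([imgs_A, imgs_B],
                                      [valid, valid,
                                       imgs_A, imgs_B,
                                       imgs_A, imgs_B])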
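When a multi-output Keras model is compiled with `loss_weights`, the scalar it minimizes is the weighted sum of the per-output losses. The sketch below computes that sum for the six outputs of `self.combined`. The weights (λ_cyc = 10 for each cycle term, 0.1·λ_cyc for each identity term) are assumed values, not taken from the original code, and the per-output losses are made up.

```python
# Assumed weights for the six outputs of self.combined, in order:
# [valid_A, valid_B, reconstr_A, reconstr_B, img_A_id, img_B_id]
LAMBDA_CYCLE = 10.0              # weight of each cycle-consistency term (assumption)
LAMBDA_ID = 0.1 * LAMBDA_CYCLE   # weight of each identity term (assumption)
LOSS_WEIGHTS = [1.0, 1.0, LAMBDA_CYCLE, LAMBDA_CYCLE, LAMBDA_ID, LAMBDA_ID]


def combined_generator_loss(output_losses, weights=LOSS_WEIGHTS):
    """Weighted sum of the six per-output losses: the scalar a model compiled
    with loss_weights=weights minimizes."""
    if len(output_losses) != len(weights):
        raise ValueError("expected one loss per model output")
    return sum(w * loss for w, loss in zip(weights, output_losses))


# Made-up per-output losses: two adversarial, two cycle, two identity
print(combined_generator_loss([0.5, 0.5, 0.1, 0.1, 0.2, 0.2]))
```

With these weights a cycle-consistency error counts ten times as much as an adversarial error of the same size.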

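Discriminator training

import numpy as np

# Patch-level targets matching the discriminator output (8, 8, 1);
# batch_size is the training batch size
valid = np.ones((batch_size, 8, 8, 1))
fake = np.zeros((batch_size, 8, 8, 1))

# Translate the current batch with the current generators
fake_B = self.g_AB.predict(imgs_A)
fake_A = self.g_BA.predict(imgs_B)

# D_A (= D_X): real A images -> valid, translated images F(y) -> fake
dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

# D_B (= D_Y): real B images -> valid, translated images G(x) -> fake
dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

# Overall discriminator loss: mean of the two discriminators' losses
d_loss = 0.5 * np.add(dA_loss, dB_loss)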
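Assuming the discriminators are compiled with an MSE loss (the compile call is not shown), each discriminator's loss in the code above is half the sum of its error on real images against all-ones targets and on translated images against all-zeros targets. A NumPy sketch with made-up outputs of the 8×8×1 patch shape:

```python
import numpy as np


def mse(pred, target):
    return float(np.mean((pred - target) ** 2))


def discriminator_loss(d_on_real, d_on_fake):
    """Mirrors dA_loss / dB_loss: 0.5 * (MSE on real vs. ones + MSE on fake vs. zeros)."""
    valid = np.ones_like(d_on_real)
    fake = np.zeros_like(d_on_fake)
    return 0.5 * (mse(d_on_real, valid) + mse(d_on_fake, fake))


# Made-up patch outputs with shape (batch, 8, 8, 1)
batch_size = 2
dA_loss = discriminator_loss(np.full((batch_size, 8, 8, 1), 0.75),
                             np.full((batch_size, 8, 8, 1), 0.25))
dB_loss = discriminator_loss(np.full((batch_size, 8, 8, 1), 0.5),
                             np.full((batch_size, 8, 8, 1), 0.5))
d_loss = 0.5 * (dA_loss + dB_loss)
print(dA_loss, dB_loss, d_loss)  # 0.0625 0.25 0.15625
```

A discriminator that outputs 0.5 everywhere (unable to tell real from fake) scores 0.25 here, while one that separates the two more confidently scores lower.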

Network parameters (Keras model summaries)

Discriminator: downsampling

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         (None, 128, 128, 3)       0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 64, 64, 64)        3136      
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 64, 64, 64)        0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 32, 32, 128)       131200    
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 32, 32, 128)       0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 16, 16, 256)       524544    
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 16, 16, 256)       0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 8, 8, 512)         2097664   
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU)    (None, 8, 8, 512)         0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 8, 8, 1)           4609      
=================================================================
Total params: 2,761,153
Trainable params: 2,761,153
Non-trainable params: 0
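Each Param # entry above follows from kernel_h × kernel_w × in_channels × out_channels + out_channels (one bias per filter); LeakyReLU layers have no parameters. The counts imply 4×4 kernels for the four downsampling convolutions and a 3×3 kernel for the last one. A quick check, with illustrative names:

```python
def conv_params(kernel, in_ch, out_ch):
    """Parameters of a square Conv2D layer with bias."""
    return kernel * kernel * in_ch * out_ch + out_ch


# (kernel, in_channels, out_channels) for the five Conv2D layers in the summary
disc_layers = [(4, 3, 64), (4, 64, 128), (4, 128, 256), (4, 256, 512), (3, 512, 1)]
counts = [conv_params(k, i, o) for k, i, o in disc_layers]
print(counts)       # [3136, 131200, 524544, 2097664, 4609]
print(sum(counts))  # 2761153
```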

Generator: downsampling, 9 residual blocks, upsampling

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_3 (InputLayer)            (None, 128, 128, 3)  0                                            
__________________________________________________________________________________________________
zero_padding2d_1 (ZeroPadding2D (None, 134, 134, 3)  0           input_3[0][0]                    
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 128, 128, 64) 9472        zero_padding2d_1[0][0]           
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 128, 128, 64) 0           conv2d_11[0][0]                  
__________________________________________________________________________________________________
zero_padding2d_2 (ZeroPadding2D (None, 130, 130, 64) 0           activation_1[0][0]               
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 64, 64, 128)  73856       zero_padding2d_2[0][0]           
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 64, 64, 128)  0           conv2d_12[0][0]                  
__________________________________________________________________________________________________
zero_padding2d_3 (ZeroPadding2D (None, 66, 66, 128)  0           activation_2[0][0]               
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 32, 32, 256)  295168      zero_padding2d_3[0][0]           
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 32, 32, 256)  0           conv2d_13[0][0]                  
__________________________________________________________________________________________________
zero_padding2d_4 (ZeroPadding2D (None, 34, 34, 256)  0           activation_3[0][0]               
__________________________________________________________________________________________________
res0_branch2a (Conv2D)          (None, 32, 32, 256)  590080      zero_padding2d_4[0][0]           
__________________________________________________________________________________________________
activation_4 (Activation)       (None, 32, 32, 256)  0           res0_branch2a[0][0]              
__________________________________________________________________________________________________
zero_padding2d_5 (ZeroPadding2D (None, 34, 34, 256)  0           activation_4[0][0]               
__________________________________________________________________________________________________
res0_branch2c (Conv2D)          (None, 32, 32, 256)  590080      zero_padding2d_5[0][0]           
__________________________________________________________________________________________________
add_1 (Add)                     (None, 32, 32, 256)  0           res0_branch2c[0][0]              
                                                                 activation_3[0][0]               
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 32, 32, 256)  0           add_1[0][0]                      
__________________________________________________________________________________________________
zero_padding2d_6 (ZeroPadding2D (None, 34, 34, 256)  0           activation_5[0][0]               
__________________________________________________________________________________________________
res1_branch2a (Conv2D)          (None, 32, 32, 256)  590080      zero_padding2d_6[0][0]           
__________________________________________________________________________________________________
activation_6 (Activation)       (None, 32, 32, 256)  0           res1_branch2a[0][0]              
__________________________________________________________________________________________________
zero_padding2d_7 (ZeroPadding2D (None, 34, 34, 256)  0           activation_6[0][0]               
__________________________________________________________________________________________________
res1_branch2c (Conv2D)          (None, 32, 32, 256)  590080      zero_padding2d_7[0][0]           
__________________________________________________________________________________________________
add_2 (Add)                     (None, 32, 32, 256)  0           res1_branch2c[0][0]              
                                                                 activation_5[0][0]               
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 32, 32, 256)  0           add_2[0][0]                      
__________________________________________________________________________________________________
zero_padding2d_8 (ZeroPadding2D (None, 34, 34, 256)  0           activation_7[0][0]               
__________________________________________________________________________________________________
res2_branch2a (Conv2D)          (None, 32, 32, 256)  590080      zero_padding2d_8[0][0]           
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 32, 32, 256)  0           res2_branch2a[0][0]              
__________________________________________________________________________________________________
zero_padding2d_9 (ZeroPadding2D (None, 34, 34, 256)  0           activation_8[0][0]               
__________________________________________________________________________________________________
res2_branch2c (Conv2D)          (None, 32, 32, 256)  590080      zero_padding2d_9[0][0]           
__________________________________________________________________________________________________
add_3 (Add)                     (None, 32, 32, 256)  0           res2_branch2c[0][0]              
                                                                 activation_7[0][0]               
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 32, 32, 256)  0           add_3[0][0]                      
__________________________________________________________________________________________________
zero_padding2d_10 (ZeroPadding2 (None, 34, 34, 256)  0           activation_9[0][0]               
__________________________________________________________________________________________________
res3_branch2a (Conv2D)          (None, 32, 32, 256)  590080      zero_padding2d_10[0][0]          
__________________________________________________________________________________________________
activation_10 (Activation)      (None, 32, 32, 256)  0           res3_branch2a[0][0]              
__________________________________________________________________________________________________
zero_padding2d_11 (ZeroPadding2 (None, 34, 34, 256)  0           activation_10[0][0]              
__________________________________________________________________________________________________
res3_branch2c (Conv2D)          (None, 32, 32, 256)  590080      zero_padding2d_11[0][0]          
__________________________________________________________________________________________________
add_4 (Add)                     (None, 32, 32, 256)  0           res3_branch2c[0][0]              
                                                                 activation_9[0][0]               
__________________________________________________________________________________________________
activation_11 (Activation)      (None, 32, 32, 256)  0           add_4[0][0]                      
__________________________________________________________________________________________________
zero_padding2d_12 (ZeroPadding2 (None, 34, 34, 256)  0           activation_11[0][0]              
__________________________________________________________________________________________________
res4_branch2a (Conv2D)          (None, 32, 32, 256)  590080      zero_padding2d_12[0][0]          
__________________________________________________________________________________________________
activation_12 (Activation)      (None, 32, 32, 256)  0           res4_branch2a[0][0]              
__________________________________________________________________________________________________
zero_padding2d_13 (ZeroPadding2 (None, 34, 34, 256)  0           activation_12[0][0]              
__________________________________________________________________________________________________
res4_branch2c (Conv2D)          (None, 32, 32, 256)  590080      zero_padding2d_13[0][0]          
__________________________________________________________________________________________________
add_5 (Add)                     (None, 32, 32, 256)  0           res4_branch2c[0][0]              
                                                                 activation_11[0][0]              
__________________________________________________________________________________________________
activation_13 (Activation)      (None, 32, 32, 256)  0           add_5[0][0]                      
__________________________________________________________________________________________________
zero_padding2d_14 (ZeroPadding2 (None, 34, 34, 256)  0           activation_13[0][0]              
__________________________________________________________________________________________________
res5_branch2a (Conv2D)          (None, 32, 32, 256)  590080      zero_padding2d_14[0][0]          
__________________________________________________________________________________________________
activation_14 (Activation)      (None, 32, 32, 256)  0           res5_branch2a[0][0]              
__________________________________________________________________________________________________
zero_padding2d_15 (ZeroPadding2 (None, 34, 34, 256)  0           activation_14[0][0]              
__________________________________________________________________________________________________
res5_branch2c (Conv2D)          (None, 32, 32, 256)  590080      zero_padding2d_15[0][0]          
__________________________________________________________________________________________________
add_6 (Add)                     (None, 32, 32, 256)  0           res5_branch2c[0][0]              
                                                                 activation_13[0][0]              
__________________________________________________________________________________________________
activation_15 (Activation)      (None, 32, 32, 256)  0           add_6[0][0]                      
__________________________________________________________________________________________________
zero_padding2d_16 (ZeroPadding2 (None, 34, 34, 256)  0           activation_15[0][0]              
__________________________________________________________________________________________________
res6_branch2a (Conv2D)          (None, 32, 32, 256)  590080      zero_padding2d_16[0][0]          
__________________________________________________________________________________________________
activation_16 (Activation)      (None, 32, 32, 256)  0           res6_branch2a[0][0]              
__________________________________________________________________________________________________
zero_padding2d_17 (ZeroPadding2 (None, 34, 34, 256)  0           activation_16[0][0]              
__________________________________________________________________________________________________
res6_branch2c (Conv2D)          (None, 32, 32, 256)  590080      zero_padding2d_17[0][0]          
__________________________________________________________________________________________________
add_7 (Add)                     (None, 32, 32, 256)  0           res6_branch2c[0][0]              
                                                                 activation_15[0][0]              
__________________________________________________________________________________________________
activation_17 (Activation)      (None, 32, 32, 256)  0           add_7[0][0]                      
__________________________________________________________________________________________________
zero_padding2d_18 (ZeroPadding2 (None, 34, 34, 256)  0           activation_17[0][0]              
__________________________________________________________________________________________________
res7_branch2a (Conv2D)          (None, 32, 32, 256)  590080      zero_padding2d_18[0][0]          
__________________________________________________________________________________________________
activation_18 (Activation)      (None, 32, 32, 256)  0           res7_branch2a[0][0]              
__________________________________________________________________________________________________
zero_padding2d_19 (ZeroPadding2 (None, 34, 34, 256)  0           activation_18[0][0]              
__________________________________________________________________________________________________
res7_branch2c (Conv2D)          (None, 32, 32, 256)  590080      zero_padding2d_19[0][0]          
__________________________________________________________________________________________________
add_8 (Add)                     (None, 32, 32, 256)  0           res7_branch2c[0][0]              
                                                                 activation_17[0][0]              
__________________________________________________________________________________________________
activation_19 (Activation)      (None, 32, 32, 256)  0           add_8[0][0]                      
__________________________________________________________________________________________________
zero_padding2d_20 (ZeroPadding2 (None, 34, 34, 256)  0           activation_19[0][0]              
__________________________________________________________________________________________________
res8_branch2a (Conv2D)          (None, 32, 32, 256)  590080      zero_padding2d_20[0][0]          
__________________________________________________________________________________________________
activation_20 (Activation)      (None, 32, 32, 256)  0           res8_branch2a[0][0]              
__________________________________________________________________________________________________
zero_padding2d_21 (ZeroPadding2 (None, 34, 34, 256)  0           activation_20[0][0]              
__________________________________________________________________________________________________
res8_branch2c (Conv2D)          (None, 32, 32, 256)  590080      zero_padding2d_21[0][0]          
__________________________________________________________________________________________________
add_9 (Add)                     (None, 32, 32, 256)  0           res8_branch2c[0][0]              
                                                                 activation_19[0][0]              
__________________________________________________________________________________________________
activation_21 (Activation)      (None, 32, 32, 256)  0           add_9[0][0]                      
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 64, 64, 256)  0           activation_21[0][0]              
__________________________________________________________________________________________________
zero_padding2d_22 (ZeroPadding2 (None, 66, 66, 256)  0           up_sampling2d_1[0][0]            
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 64, 64, 128)  295040      zero_padding2d_22[0][0]          
__________________________________________________________________________________________________
activation_22 (Activation)      (None, 64, 64, 128)  0           conv2d_14[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D)  (None, 128, 128, 128 0           activation_22[0][0]              
__________________________________________________________________________________________________
zero_padding2d_23 (ZeroPadding2 (None, 130, 130, 128 0           up_sampling2d_2[0][0]            
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 128, 128, 64) 73792       zero_padding2d_23[0][0]          
__________________________________________________________________________________________________
activation_23 (Activation)      (None, 128, 128, 64) 0           conv2d_15[0][0]                  
__________________________________________________________________________________________________
zero_padding2d_24 (ZeroPadding2 (None, 134, 134, 64) 0           activation_23[0][0]              
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 128, 128, 3)  9411        zero_padding2d_24[0][0]          
__________________________________________________________________________________________________
activation_24 (Activation)      (None, 128, 128, 3)  0           conv2d_16[0][0]                  
==================================================================================================
Total params: 11,378,179
Trainable params: 11,378,179
Non-trainable params: 0
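Applying the Conv2D parameter formula (kernel_h × kernel_w × in_channels × out_channels + out_channels) to the generator reproduces its total; ZeroPadding2D, Activation, Add, and UpSampling2D layers contribute no parameters. The kernel sizes (7×7 at both ends, 3×3 elsewhere) are inferred from the counts, and the function names are illustrative.

```python
def conv_params(kernel, in_ch, out_ch):
    """Parameters of a square Conv2D layer with bias."""
    return kernel * kernel * in_ch * out_ch + out_ch


def generator_params(n_res=9):
    total = conv_params(7, 3, 64)                    # conv2d_11
    total += conv_params(3, 64, 128)                 # conv2d_12 (downsampling)
    total += conv_params(3, 128, 256)                # conv2d_13 (downsampling)
    total += n_res * 2 * conv_params(3, 256, 256)    # res*_branch2a and res*_branch2c
    total += conv_params(3, 256, 128)                # conv2d_14 (after upsampling)
    total += conv_params(3, 128, 64)                 # conv2d_15 (after upsampling)
    total += conv_params(7, 64, 3)                   # conv2d_16 (output)
    return total


print(generator_params())  # 11378179
```

The nine residual blocks alone account for 18 × 590,080 = 10,621,440 parameters, about 93% of the total.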