欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

风格迁移StyleTransfer和Pytorch实现

程序员文章站 2024-03-21 08:07:22
...

风格迁移及Pytorch实现

风格迁移,就是利用算法学习一幅画的风格,然后再把这种风格应用到另外一张图片上。

本篇文章会介绍其原理,并使用Pytorch实现。

风格迁移StyleTransfer和Pytorch实现

在卷积中,浅层特征越具体,深层特征则越抽象);从风格角度来说,浅层特征则记录着颜色纹理等信息,而深层特征则会记录更高级的信息。

主要方式则是,随机一张图片,通过优化内容损失和风格损失,改变该图,使其内容接近内容图片,风格上接近风格图片。

内容损失:直接计算特征图的欧式距离

风格损失:计算特征图的格拉姆矩阵的欧式距离

格拉姆矩阵的计算方式:

def get_gram_matrix(f_map):
    n, c, h, w = f_map.shape
    if n == 1:
        f_map = f_map.reshape(c, h * w)
        gram_matrix = torch.mm(f_map, f_map.t())
        return gram_matrix
    else:
        raise ValueError('批次应该为1,但是传入的不为1')

将特征图reshape,将宽高的维度合在一起,然后计算其与自身转置的矩阵乘法即可。

迁移出预先训练好的VGG19的模型。并输出五个不同维度的特征图。

from torchvision.models import vgg19
from torch import nn
from torchvision.utils import save_image
import torch
import cv2


class VGG19(nn.Module):
    def __init__(self):
        super(VGG19, self).__init__()
        a = vgg19(True)
        a = a.features
        self.layer1 = a[:4]
        self.layer2 = a[4:9]
        self.layer3 = a[9:18]
        self.layer4 = a[18:27]
        self.layer5 = a[27:36]

    def forward(self, input_):
        out1 = self.layer1(input_)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)
        out5 = self.layer5(out4)
        return out1, out2, out3, out4, out5

将图片直接定义为网络参数,来训练它。这里直接从原始内容图训练,也可以使用白噪声。

class GNet(nn.Module):
    def __init__(self, image):
        super(GNet, self).__init__()
        self.image_g = nn.Parameter(image.detach().clone())
        # self.image_g = nn.Parameter(torch.rand(image.shape))  # 也可以初始化一张白噪声训练 

    def forward(self):
        return self.image_g.clamp(0, 1)  # 为了限定数值范围。

定义加载图片函数:

def load_image(path):
    image = cv2.imread(path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = torch.from_numpy(image).float() / 255
    image = image.permute(2, 0, 1).unsqueeze(0)
    return image

需要使用图片需要保持形状一致

首先加载内容图片风格图片,再实例化VGG19网络图片,图片直接从原内容图开始训练。

实例化优化器损失函数

image_content = load_image('c.jpg').cuda()
image_style = load_image('s.jpg').cuda()
net = VGG19().cuda()
g_net = GNet(image_content).cuda()
optimizer = torch.optim.Adam(g_net.parameters())
loss_func = nn.MSELoss().cuda()

计算风格图片的输入VGG19的输出,并得到其格拉姆矩阵

s1, s2, s3, s4, s5 = net(image_style)
s1 = get_gram_matrix(s1).detach().clone()
s2 = get_gram_matrix(s2).detach().clone()
s3 = get_gram_matrix(s3).detach().clone()
s4 = get_gram_matrix(s4).detach().clone()
s5 = get_gram_matrix(s5).detach().clone()

计算内容图片输入VGG19的输出

c1, c2, c3, c4, c5 = net(image_content)
c1 = c1.detach().clone()
c2 = c2.detach().clone()
c3 = c3.detach().clone()
c4 = c4.detach().clone()
c5 = c5.detach().clone()

训练该图片。

i = 0
while True:
    """生成图片,计算损失"""
    image = g_net()
    out1, out2, out3, out4, out5 = net(image)

    """计算分格损失"""
    loss_s1 = loss_func(get_gram_matrix(out1), s1)
    loss_s2 = loss_func(get_gram_matrix(out2), s2)
    loss_s3 = loss_func(get_gram_matrix(out3), s3)
    loss_s4 = loss_func(get_gram_matrix(out4), s4)
    loss_s5 = loss_func(get_gram_matrix(out5), s5)
    loss_s = 0.1*loss_s1 + 0.1*loss_s2 + 0.6*loss_s3 + 0.1*loss_s4 + 0.1*loss_s5

    """计算内容损失"""
    loss_c1 = loss_func(out1, c1)
    loss_c2 = loss_func(out2, c2)
    loss_c3 = loss_func(out3, c3)
    loss_c4 = loss_func(out4, c4)
    loss_c5 = loss_func(out5, c5)
    loss_c = 0.05 * loss_c1 + 0.05 * loss_c2 + 0.15 * loss_c3 + 0.3 * loss_c4 + 0.45 * loss_c5

    """总损失"""
    loss = 0.5*loss_c + 0.5*loss_s

    """清空梯度、计算梯度、更新参数"""
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(i, loss.item(), loss_c.item(), loss_s.item())
    if i % 1000 == 0:
        save_image(image, f'{i}.jpg', padding=0, normalize=True, range=(0, 1))
    i += 1

分别计算风格损失和内容损失,然后求得总损失,优化该损失。

基本迭代一千次即可出效果。

内容图片为:

风格迁移StyleTransfer和Pytorch实现

几个图片的效果展示:

风格图片 生成图片
风格迁移StyleTransfer和Pytorch实现 风格迁移StyleTransfer和Pytorch实现
/ 风格迁移StyleTransfer和Pytorch实现
/ 风格迁移StyleTransfer和Pytorch实现
/ 风格迁移StyleTransfer和Pytorch实现
风格迁移StyleTransfer和Pytorch实现 风格迁移StyleTransfer和Pytorch实现
/ 风格迁移StyleTransfer和Pytorch实现
风格迁移StyleTransfer和Pytorch实现 /> 风格迁移StyleTransfer和Pytorch实现
风格迁移StyleTransfer和Pytorch实现 风格迁移StyleTransfer和Pytorch实现
/ 风格迁移StyleTransfer和Pytorch实现
风格迁移StyleTransfer和Pytorch实现 风格迁移StyleTransfer和Pytorch实现
风格迁移StyleTransfer和Pytorch实现 风格迁移StyleTransfer和Pytorch实现
风格迁移StyleTransfer和Pytorch实现 风格迁移StyleTransfer和Pytorch实现
风格迁移StyleTransfer和Pytorch实现 风格迁移StyleTransfer和Pytorch实现

调整各个损失不同的比例系数,能够达到不同的效果。可酌情尝试。