欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

图像风格迁移

程序员文章站 2022-04-09 19:40:06
...

一.图像迁移

这里我们需要两张输入图像,一张是内容图像,另一张是样式图像,我们将使用神经网络修改内容图像使其在样式上接近样式图像。
图像风格迁移

1.1 方法

假设.如果选取的预训练的神经网络含有3个卷积层,其中第二层输出图像的内容特征,而第一层和第三层的输出被作为图像的样式特征。接下来,我们通过**正向传播(实线箭头方向)计算样式迁移的损失函数,并通过反向传播(虚线箭头方向)**迭代模型参数,即不断更新合成图像。样式迁移常用的损失函数由3部分组成: 内容损失(content loss) 使合成图像与内容图像在内容特征上接近,样式损失(style loss) 令合成图像与样式图像在样式特征上接近,而 总变差损失(total variation loss) 则有助于减少合成图像中的噪点。
图像风格迁移

1.2 实践

1.2.1 抽取特征

VGG-19模型来抽取图像特征
图像风格迁移

pretrained_net = torchvision.models.vgg19(pretrained=False)

为了抽取图像的内容特征样式特征,我们可以选择VGG网络中某些层的输出。一般来说,越靠近输入层的输出越容易抽取图像的细节信息,反之则越容易抽取图像的全局信息。为了避免合成图像过多保留内容图像的细节,我们选择VGG较靠近输出的层,也称内容层,来输出图像的内容特征。我们还从VGG中选择不同层的输出来匹配局部和全局的样式,这些层也叫样式层

style_layers, content_layers = [0, 5, 10, 19, 28], [25]
def extract_features(X, content_layers, style_layers):
    contents = []
    styles = []
    for i in range(len(net)):
        X = net[i](X)
        if i in style_layers:
            styles.append(X)
        if i in content_layers:
            contents.append(X)
    return contents, styles

1.2.2 定义损失函数

//内容损失
def content_loss(Y_hat, Y):
    return F.mse_loss(Y_hat, Y)
//总变差损失
def tv_loss(Y_hat):
    return 0.5 * (F.l1_loss(Y_hat[:, :, 1:, :], Y_hat[:, :, :-1, :]) + 
                  F.l1_loss(Y_hat[:, :, :, 1:], Y_hat[:, :, :, :-1]))

样式损失
假设该输出的样本数为1,通道数为 c ,高和宽分别为 h 和 w ,我们可以把输出变换成 c 行 hw 列的矩阵 X 。矩阵 X 可以看作是由 c 个长度为 hw 的向量 x1,…,xc 组成的。其中向量 xi 代表了通道 ii 上的样式特征。这些向量的格拉姆矩阵(Gram matrix) XX⊤∈Rc×c 中 i 行 j 列的元素 xij 即向量 xi 与 xj 的内积,它表达了通道 i 和通道 j 上样式特征的相关性。我们用这样的格拉姆矩阵表达样式层输出的样式

def gram(X):
    num_channels, n = X.shape[1], X.shape[2] * X.shape[3]
    X = X.view(num_channels, n)
    return torch.matmul(X, X.t()) / (num_channels * n)
def style_loss(Y_hat, gram_Y):
    return F.mse_loss(gram(Y_hat), gram_Y)

损失函数
样式迁移的损失函数即内容损失、样式损失和总变差损失的加权和。通过调节这些权值超参数,我们可以权衡合成图像在保留内容、迁移样式以及降噪三方面的相对重要性。

content_weight, style_weight, tv_weight = 1, 1e3, 10

def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):
    # 分别计算内容损失、样式损失和总变差损失
    contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(
        contents_Y_hat, contents_Y)]
    styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(
        styles_Y_hat, styles_Y_gram)]
    tv_l = tv_loss(X) * tv_weight
    # 对所有损失求和
    l = sum(styles_l) + sum(contents_l) + tv_l
    return contents_l, styles_l, tv_l, l

1.2.3 训练

def train(X, contents_Y, styles_Y, device, lr, max_epochs, lr_decay_epoch):
    print("training on ", device)
    X, styles_Y_gram, optimizer = get_inits(X, device, lr, styles_Y)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, lr_decay_epoch, gamma=0.1)
    for i in range(max_epochs):
        start = time.time()
        
        contents_Y_hat, styles_Y_hat = extract_features(
                X, content_layers, style_layers)
        contents_l, styles_l, tv_l, l = compute_loss(
                X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)
        
        optimizer.zero_grad()
        l.backward(retain_graph = True)
        optimizer.step()
        scheduler.step()
        
        if i % 50 == 0 and i != 0:
            print('epoch %3d, content loss %.2f, style loss %.2f, '
                  'TV loss %.2f, %.2f sec'
                  % (i, sum(contents_l).item(), sum(styles_l).item(), tv_l.item(),
                     time.time() - start))
    return X.detach()
相关标签: 动手学深度学习