欢迎您访问程序员文章站,本站旨在为大家提供、分享程序员计算机编程知识!
您现在的位置是: 首页

PyTorch 图像风格迁移

程序员文章站 2022-04-09 19:40:12
...

导入必要的包

import torch
from torch import nn
from torchvision import models
from torchvision import transforms
from PIL import Image, ImageFilter
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

定义图片内容的损失,用于约束生成图片的内容

class Content_loss(nn.Module):
    """Pass-through layer that records a content loss.

    ``forward`` returns its input unchanged, so the layer can be inserted
    anywhere in an ``nn.Sequential``. As a side effect it stores the mean
    squared error between the input and the fixed target on ``self.loss``.
    """

    def __init__(self, target):
        super().__init__()
        # The target is a constant reference value, not something to
        # differentiate through, so it is cut out of the autograd graph.
        fixed_target = target.detach()
        self.target = fixed_target

    def forward(self, input):
        mse = F.mse_loss(input, self.target)
        self.loss = mse
        return input

格拉姆矩阵

def gram_matrix(input):
    """Return the normalised Gram matrix of a 4-D feature tensor.

    Args:
        input: tensor of size ``(a, b, c, d)``. It is flattened to
            ``(a * b, c * d)``, i.e. one row per (batch item, feature map)
            pair and one column per spatial position.

    Returns:
        An ``(a * b, a * b)`` tensor holding the inner products between all
        rows, divided by the total number of elements ``a * b * c * d``.

    Raises:
        ValueError: if ``input`` does not have exactly four dimensions.
    """
    a, b, c, d = input.size()

    # ``reshape`` instead of ``view``: ``view`` raises on non-contiguous
    # tensors (e.g. permuted or channels_last activations), whereas
    # ``reshape`` copies only when it has to and is otherwise identical.
    features = input.reshape(a * b, c * d)

    G = torch.mm(features, features.t())

    # Normalise by the element count so the magnitude of the result does not
    # grow with the size of the feature maps.
    return G.div(a * b * c * d)

定义图片风格的损失,用于约束生成图片的风格

class Style_loss(nn.Module):
    """Pass-through layer that records a style loss.

    The target is the Gram matrix of ``target_feature``, computed once at
    construction time. ``forward`` returns its input unchanged and stores
    the mean squared error between the input's Gram matrix and that target
    on ``self.loss``.
    """

    def __init__(self, target_feature):
        super().__init__()
        target_gram = gram_matrix(target_feature)
        # Constant reference value: keep it out of the autograd graph.
        self.target = target_gram.detach()

    def forward(self, input):
        current_gram = gram_matrix(input)
        self.loss = F.mse_loss(current_gram, self.target)
        return input

以 vgg19 作为特征提取器

# Feature-extraction part (`.features`) of a pretrained VGG19, moved to
# `device` and switched to eval mode. It is only used to produce activations.
# NOTE(review): newer torchvision releases prefer a `weights=` argument over
# `pretrained=True` - confirm against the installed version.
vgg = models.vgg19(pretrained=True).features
vgg = vgg.to(device).eval()

# Names of the conv layers (numbered in order of appearance) after which a
# content loss / style loss module is inserted.
content_layers_default = ['conv_4']
style_layers_default = ['conv_{}'.format(n) for n in range(1, 6)]

def get_style_model_and_loss(style_img, content_img):
    """Build the loss-instrumented feature extractor from the global ``vgg``.

    The layers of ``vgg`` are copied in order into a fresh ``nn.Sequential``.
    Directly after every conv layer named in ``content_layers_default`` /
    ``style_layers_default`` a ``Content_loss`` / ``Style_loss`` module is
    inserted, whose target is the activation of ``content_img`` /
    ``style_img`` at that point of the partially built model.

    Construction stops as soon as every requested layer has received its
    loss module: nothing reads the output of later layers, so keeping them
    would only add work to every forward and backward pass.

    Args:
        style_img: image tensor used to compute the style targets.
        content_img: image tensor used to compute the content targets.

    Returns:
        A tuple ``(model, style_loss_list, content_loss_list)`` where the
        lists hold the inserted loss modules in insertion order.
    """
    model = nn.Sequential().to(device)
    style_loss_list, content_loss_list = [], []

    # Conv layers that still need a loss module attached.
    pending = set(content_layers_default) | set(style_layers_default)

    i = 1
    for layer in vgg:
        if isinstance(layer, nn.Conv2d):
            name = 'conv_' + str(i)
            model.add_module(name, layer)

            if name in content_layers_default:
                # Targets are constants, so no autograd graph is needed
                # while computing them.
                with torch.no_grad():
                    target = model(content_img)
                content_loss = Content_loss(target)
                model.add_module('content_loss_' + str(i), content_loss)
                content_loss_list.append(content_loss)

            if name in style_layers_default:
                with torch.no_grad():
                    target = model(style_img)
                style_loss = Style_loss(target)
                model.add_module('style_loss_' + str(i), style_loss)
                style_loss_list.append(style_loss)

            i += 1

            pending.discard(name)
            if not pending:
                # Every requested loss is in place; later layers are unused.
                break

        elif isinstance(layer, nn.MaxPool2d):
            # `i` was already advanced past the preceding conv, so the suffix
            # is one higher than that conv's number (kept for stable names).
            model.add_module('pool_' + str(i), layer)

        elif isinstance(layer, nn.ReLU):
            # Use an out-of-place ReLU instead of the one taken from `vgg`;
            # presumably an in-place ReLU would overwrite the activation the
            # preceding loss module was just fed - TODO confirm.
            model.add_module('relu_' + str(i), nn.ReLU(inplace=False))

    return model, style_loss_list, content_loss_list

读取输入图片

def _load_rgb(path, size=(224, 224)):
    """Open the image at ``path``, force 3-channel RGB and resize to ``size``."""
    # Converting first guarantees an H x W x 3 array later on, even for
    # grayscale, palette or RGBA files. The context manager closes the file
    # once the pixel data has been copied by ``convert``.
    with Image.open(path) as img:
        return img.convert('RGB').resize(size)


def _to_batch_tensor(image):
    """Turn an RGB PIL image into a ``(1, 3, H, W)`` float32 tensor in [0, 1]."""
    # ``np.asarray`` copies only when required. ``np.array(..., copy=False)``
    # raises under NumPy 2 whenever a copy cannot be avoided, which is always
    # the case for this uint8 -> float32 conversion.
    pixels = np.asarray(image, dtype=np.float32)
    chw = pixels.transpose([2, 0, 1]) / 255  # HWC -> CHW, scale to [0, 1]
    return torch.from_numpy(chw).unsqueeze(0).to(device)


style = _load_rgb('fg.jpg')
content = _load_rgb('timg.jpg')

# Show the two inputs side by side.
plt.subplot(1, 2, 1)
plt.imshow(style)
plt.title('style')
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(content)
plt.title('content')
plt.axis('off')
plt.show()

style_img = _to_batch_tensor(style)
content_img = _to_batch_tensor(content)

print(content_img.size(), style_img.size()) # torch.Size([1, 3, 224, 224]) torch.Size([1, 3, 224, 224])

pytorch 图像风格迁移
以 L-BFGS 为优化器,设置内容损失和风格损失的权重之比为 1:1000000(content_weight=1,style_weight=1000000)

def get_input_param_optimizer(input_img):
    """Wrap ``input_img`` as a trainable parameter and create its optimizer.

    Returns:
        A tuple ``(input_param, optimizer)``: an ``nn.Parameter`` built from
        ``input_img.data`` and a ``torch.optim.LBFGS`` instance (default
        settings) that optimises only that parameter.
    """
    trainable_image = nn.Parameter(input_img.data)
    lbfgs = torch.optim.LBFGS(params=[trainable_image])
    return trainable_image, lbfgs

# Multipliers applied to the summed content and style losses before they are
# added into the total loss (content : style = 1 : 1000000).
content_weight, style_weight = 1, 1000000

def run_style_transfer(content_img, style_img, num_epoches=300):
    """Optimise a copy of ``content_img`` towards the style of ``style_img``.

    A loss-instrumented model is built with ``get_style_model_and_loss`` and
    the pixels of a clone of ``content_img`` are optimised with L-BFGS so that
    ``style_weight * style_loss + content_weight * content_loss`` decreases.
    The losses are printed every 50 loss evaluations and a preview of the
    current image is shown each time another 50 evaluations have completed.

    Args:
        content_img: image tensor; provides the content targets and the
            starting point of the optimisation.
        style_img: image tensor; provides the style targets.
        num_epoches: number of loss evaluations after which optimisation
            stops. L-BFGS evaluates the closure several times per ``step``,
            so the final count may overshoot this value.

    Returns:
        The optimised image tensor (``input_param.data``), clamped to [0, 1].
    """
    report_every = 50

    print('building the style transfer model ..')
    model, style_loss_list, content_loss_list = get_style_model_and_loss(style_img,
                                                                         content_img)

    input_img = content_img.clone()
    input_param, optimizer = get_input_param_optimizer(input_img)

    print('optimizing...')
    # One-element list so the closure can update the counter in place.
    epoch = [0]

    def closure():
        # Keep the image inside the valid pixel range before evaluating it.
        input_param.data.clamp_(0, 1)

        optimizer.zero_grad()
        # The forward pass is run only for its side effect: every loss module
        # stores its current value on `.loss`.
        model(input_param)

        style_score = style_weight * sum(sl.loss for sl in style_loss_list)
        content_score = content_weight * sum(cl.loss for cl in content_loss_list)

        loss = style_score + content_score
        loss.backward()

        epoch[0] += 1

        if epoch[0] % report_every == 0:
            # Print the counter value itself, not the list wrapping it.
            print('run {}'.format(epoch[0]))
            print('Style Loss: {:.4f} Content Loss: {:.4f}'.format(style_score.item(), content_score.item()))
            print()

        return loss

    next_preview = report_every
    while epoch[0] < num_epoches:
        optimizer.step(closure)

        input_param.data.clamp_(0, 1)

        # A single `step` advances the counter by several evaluations, so it
        # can jump over a multiple of `report_every`; compare against a
        # threshold rather than testing for an exact multiple.
        if epoch[0] >= next_preview:
            plt.figure(figsize=(6, 8))
            # CHW tensor -> HWC array for matplotlib.
            im = input_param.data.cpu().numpy()[0].transpose([1, 2, 0])
            plt.imshow(im)

            plt.axis('off')
            plt.show()
            next_preview = (epoch[0] // report_every + 1) * report_every

    return input_param.data
# Run the transfer on the two loaded images for 1000 loss evaluations.
output = run_style_transfer(content_img=content_img,
                            style_img=style_img,
                            num_epoches=1000)

pytorch 图像风格迁移
pytorch 图像风格迁移
pytorch 图像风格迁移
pytorch 图像风格迁移
pytorch 图像风格迁移
pytorch 图像风格迁移
pytorch 图像风格迁移
pytorch 图像风格迁移
pytorch 图像风格迁移