欢迎您访问程序员文章站!本站旨在为大家提供、分享程序员计算机编程知识!
您现在的位置是: 首页

pytorch 实现 Style Transfer

程序员文章站 2022-07-14 20:26:31
...

pytorch 实现 Style Transfer


设 CNN 中第 $l$ 层风格图片、内容图片、生成图片的 feature map 分别为 $S^l$、$C^l$、$G^l$。

第 $l$ 层内容损失函数定义为 $L_C^l=\sum\limits_{ij}(C^l_{ij}-G^l_{ij})^2$。

对于 feature map $F$,定义 $F^{(kk')}$ 为第 $k$ 通道和第 $k'$ 通道 feature map 的内积,则第 $l$ 层风格损失函数定义为 $L_S^l=\frac{1}{(n_Hn_Wn_C)^2}\sum\limits_{kk'}(S^{l(kk')}-G^{l(kk')})^2$。

CNN 网络采用 VGG19,优化算法采用 L-BFGS(拟牛顿法)。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.models as models
import copy

# Run on the GPU when one is available. The CPU path uses a smaller working
# resolution, presumably to keep the optimization tractable there.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
imsize = 512 if torch.cuda.is_available() else 128

# Resize the shorter side to ``imsize`` and then center-crop to an
# ``imsize`` x ``imsize`` square. Resizing to ``imsize`` (not a fixed 512)
# makes the CPU path downscale the whole image rather than crop a small
# central patch out of a 512-pixel image.
transform = transforms.Compose([
    transforms.Resize(imsize),
    transforms.CenterCrop(imsize),
    transforms.ToTensor(),
])

def image_loader(image_name):
    """Load an image file as a batched float tensor on ``device``.

    The image is converted to RGB, passed through the module-level
    ``transform`` and given a leading batch dimension of size 1.

    Args:
        image_name: Path to the image file.

    Returns:
        A ``torch.float`` tensor on ``device`` with a leading batch axis.
    """
    # ``Image.open`` keeps the underlying file open until the image is
    # closed; the context manager releases it once the pixels are converted.
    with Image.open(image_name) as raw:
        image = raw.convert('RGB')
    tensor = transform(image).unsqueeze(0)
    return tensor.to(device, torch.float)

style_img = image_loader("./data/images/the_starry_night.jpg")
content_img = image_loader("./data/images/STJU.jpg")
# The rest of the script expects both images to share one shape. Use an
# explicit check rather than ``assert`` so it still runs under ``python -O``.
if style_img.size() != content_img.size():
    raise ValueError(
        "style and content images must have the same size, got {} and {}".format(
            tuple(style_img.size()), tuple(content_img.size())))

def trans(tensor):
    """Convert a batched (1, C, H, W) image tensor to an (H, W, C) NumPy array.

    The tensor is detached from the autograd graph and copied to the CPU, so
    it can be passed in directly even if it requires grad.

    Args:
        tensor: Image tensor, assumed to have a leading batch dimension of
            size 1 (it is removed with ``squeeze(0)``).

    Returns:
        A NumPy array with the channel axis moved last, as matplotlib's
        ``imshow`` expects.
    """
    # clone() so the returned array never aliases the caller's tensor memory.
    image = tensor.detach().cpu().clone()
    image = image.squeeze(0)
    return image.numpy().transpose((1, 2, 0))

class ContentLoss(nn.Module):
    """Pass-through layer that records the content loss against a fixed target.

    The module returns its input unchanged, so it can be spliced into an
    ``nn.Sequential``. After every forward pass the MSE between the input
    and the stored target is available as ``self.loss``.
    """

    def __init__(self, target):
        super().__init__()
        # Detach so the target is a constant, not part of any autograd graph.
        self.target = target.detach()

    def forward(self, input):
        mse = F.mse_loss(input, self.target)
        self.loss = mse
        return input

def gram_matrix(input):
    """Return the normalized Gram matrix of a 4-D tensor.

    The (a, b, c, d) input is flattened to an (a*b, c*d) matrix, with one row
    per (batch, channel) pair. The result is the matrix of inner products
    between all rows, divided by the total element count a*b*c*d.
    """
    batch, channels, height, width = input.size()
    n_rows = batch * channels
    flat = input.view(n_rows, height * width)
    gram = flat @ flat.t()
    return gram / (n_rows * height * width)

class StyleLoss(nn.Module):
    """Pass-through layer that records the style loss against a fixed target.

    The target's Gram matrix is computed once, at construction. On each
    forward pass, the MSE between the input's Gram matrix and that target is
    stored on ``self.loss``, and the input is returned unchanged.
    """

    def __init__(self, target):
        super().__init__()
        target_gram = gram_matrix(target)
        # Detach so the target Gram matrix is treated as a constant.
        self.target = target_gram.detach()

    def forward(self, input):
        self.loss = F.mse_loss(gram_matrix(input), self.target)
        return input

# Pretrained VGG19 convolutional trunk (``.features``), in eval mode.
# torchvision >= 0.13 replaced ``pretrained=True`` with the ``weights`` enum and
# deprecated the old flag. Fall back to the flag on releases that lack
# ``VGG19_Weights``. IMAGENET1K_V1 is the checkpoint ``pretrained=True`` loads.
try:
    _vgg19_weights = models.VGG19_Weights.IMAGENET1K_V1
except AttributeError:  # torchvision < 0.13
    _vgg19 = models.vgg19(pretrained=True)
else:
    _vgg19 = models.vgg19(weights=_vgg19_weights)
cnn = _vgg19.features.to(device).eval()

# Per-channel mean/std used by torchvision's ImageNet-pretrained models;
# images are normalized with these before entering the network.
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)

class Normalization(nn.Module):
    """Normalize an image batch channel-wise: ``(img - mean) / std``.

    ``mean`` and ``std`` are reshaped to (C, 1, 1) so they broadcast across
    the spatial dimensions of the input. They are registered as
    non-persistent buffers, so ``.to(device)`` / ``.to(dtype)`` on this module
    (or on an ``nn.Sequential`` containing it) moves them too. They are not
    added to ``state_dict``. With plain attributes, ``.to`` silently left
    them behind.
    """

    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        self.register_buffer('mean', mean.view(-1, 1, 1), persistent=False)
        self.register_buffer('std', std.view(-1, 1, 1), persistent=False)

    def forward(self, img):
        return (img - self.mean) / self.std

# Names of the layers after which a ContentLoss / StyleLoss probe is inserted.
# The names follow the ``conv_<k>`` scheme (k = 1-based conv index).
content_layers_default = ['conv_4']
style_layers_default = ['conv_{}'.format(k) for k in range(1, 6)]

def get_style_model_and_losses(cnn, normalization_mean, normalization_std, style_img, content_img, content_layers=content_layers_default, style_layers=style_layers_default):
    """Build a truncated copy of ``cnn`` with content/style loss probes inserted.

    The layers of ``cnn`` are copied, in order, into a new ``nn.Sequential``
    that starts with a ``Normalization`` layer. Layers are named
    ``conv_i`` / ``relu_i`` / ``pool_i`` / ``bn_i``, where ``i`` counts the
    convolutions seen so far. A probe is added after every layer whose name
    appears in ``content_layers`` or ``style_layers``:

    * content layers get a ``ContentLoss``, with its target taken from the
      partial model's output on ``content_img``;
    * style layers get a ``StyleLoss``, with its target taken from the
      partial model's output on ``style_img``.

    Everything after the last probe is dropped.

    Args:
        cnn: Module whose ``children()`` are Conv2d / ReLU / MaxPool2d /
            BatchNorm2d layers. It is deep-copied, never modified.
        normalization_mean: Per-channel mean for the leading
            ``Normalization`` layer.
        normalization_std: Per-channel std for the leading
            ``Normalization`` layer.
        style_img: Image batch used to compute the style targets.
        content_img: Image batch used to compute the content targets.
        content_layers: Layer names after which a ``ContentLoss`` is inserted.
        style_layers: Layer names after which a ``StyleLoss`` is inserted.

    Returns:
        ``(model, style_losses, content_losses)``: the truncated model and
        the lists of inserted ``StyleLoss`` / ``ContentLoss`` modules. Their
        ``.loss`` attributes are refreshed on every forward pass of ``model``.

    Raises:
        RuntimeError: If ``cnn`` contains a layer of any other type.
    """
    cnn = copy.deepcopy(cnn)
    normalization = Normalization(normalization_mean, normalization_std).to(device)

    content_losses = []
    style_losses = []

    model = nn.Sequential(normalization)

    i = 0  # number of Conv2d layers seen so far; used to name every layer
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = 'conv_{}'.format(i)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(i)
            # A loss probe may sit right before this ReLU and keep its input
            # in the autograd graph. An in-place ReLU would modify that
            # tensor, and autograd would reject it at backward time.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(i)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(i)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

        model.add_module(name, layer)

        if name in content_layers:
            target = model(content_img).detach()
            content_loss = ContentLoss(target)
            model.add_module("content_loss_{}".format(i), content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            target_feature = model(style_img).detach()
            style_loss = StyleLoss(target_feature)
            model.add_module("style_loss_{}".format(i), style_loss)
            style_losses.append(style_loss)

    # Layers after the last probe cannot influence any loss, so drop them.
    # With no probes at all, only the leading Normalization layer is kept.
    last_probe = max(
        (idx for idx, module in enumerate(model)
         if isinstance(module, (ContentLoss, StyleLoss))),
        default=0,
    )
    model = model[:last_probe + 1]

    return model, style_losses, content_losses

# The optimization starts from a copy of the content image, so the in-place
# updates applied to ``input_img`` never touch ``content_img`` itself.
input_img = torch.clone(content_img)


def get_input_optimizer(input_img):
    """Return an L-BFGS optimizer whose only parameter is ``input_img``.

    ``input_img`` is switched to ``requires_grad=True`` in place, so the
    image pixels themselves are what the optimizer updates.
    """
    input_img.requires_grad_()
    return optim.LBFGS([input_img])

def run_style_tranfer(cnn, normalization_mean, normalization_std, content_img, style_img, input_img, num_steps=300, style_weight=10000000000, content_weight=1):
    """Optimize ``input_img`` in place towards ``content_img``'s content and ``style_img``'s style.

    Builds the probed model with ``get_style_model_and_losses`` and then
    runs L-BFGS on the pixels of ``input_img``. The objective is
    ``style_weight * sum(style losses) + content_weight * sum(content losses)``.
    Pixel values are clamped to [0, 1] before each evaluation and once at
    the end.

    Note:
        ``num_steps`` counts closure evaluations. L-BFGS may evaluate the
        closure several times per ``optimizer.step``, so the final count can
        exceed ``num_steps``.

    Returns:
        ``input_img`` (the same tensor object), after optimization.
    """
    print('Building the style transfer model..')

    model, style_losses, content_losses = get_style_model_and_losses(cnn, normalization_mean, normalization_std, style_img, content_img)

    # Only the image is optimized. Freezing the network avoids computing, and
    # endlessly accumulating, gradients for every network weight on each
    # backward pass.
    model.eval()
    model.requires_grad_(False)

    optimizer = get_input_optimizer(input_img)

    print('Optimizing..')
    step = 0

    def closure():
        nonlocal step
        # Keep pixels in the valid [0, 1] range. Use no_grad because
        # input_img is a leaf tensor that requires grad.
        with torch.no_grad():
            input_img.clamp_(0, 1)

        optimizer.zero_grad()
        # The forward pass refreshes every probe's ``.loss`` attribute.
        model(input_img)

        style_score = 0.0
        content_score = 0.0
        for sl in style_losses:
            style_score += sl.loss
        for cl in content_losses:
            content_score += cl.loss

        style_score *= style_weight
        content_score *= content_weight

        loss = style_score + content_score
        loss.backward()

        step += 1
        if step % 50 == 0:
            print("run {}:".format(step))
            print('Style Loss : {:.4f} Content Loss: {:.4f}'.format(
                style_score.item(), content_score.item()))
            print()

        return loss

    while step < num_steps:
        optimizer.step(closure)

    # The last L-BFGS update may have pushed pixels out of range.
    with torch.no_grad():
        input_img.clamp_(0, 1)

    return input_img

output = run_style_tranfer(cnn, cnn_normalization_mean, cnn_normalization_std, content_img, style_img, input_img)

# Convert the batched tensors to channel-last arrays that matplotlib can show.
style_img = trans(style_img)
content_img = trans(content_img)
output = trans(output.detach())

_, ax = plt.subplots(1, 3)
ax[0].imshow(style_img)
ax[0].set_title('Style Image')
ax[1].imshow(content_img)
ax[1].set_title('Content Image')
ax[2].imshow(output)
# The third panel is the style-transferred result (it is not a transpose).
ax[2].set_title('Transferred Image')
plt.show()

迭代300次效果如下:
pytorch 实现 Style Transfer