PyTorch 图像风格迁移
程序员文章站
2022-04-09 19:40:12
...
导入必要的包
import torch
from torch import nn
from torchvision import models
from torchvision import transforms
from PIL import Image, ImageFilter
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
# Jupyter: render matplotlib figures inline. This was written as the IPython
# magic ``%matplotlib inline``, which is a SyntaxError under the plain Python
# interpreter; invoking the magic through the shell object keeps notebooks
# behaving the same while letting the file also run as an ordinary script.
try:
    get_ipython().run_line_magic('matplotlib', 'inline')  # noqa: F821
except (NameError, AttributeError):
    pass  # not running inside IPython / Jupyter

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
定义图片内容的损失,用于约束生成图片的内容
class Content_loss(nn.Module):
    """Pass-through layer that measures distance to a fixed content target.

    ``forward`` returns its input unchanged, so the module can be inserted into
    an ``nn.Sequential`` feature extractor. As a side effect it stores the
    mean-squared error between the input and the target in ``self.loss``.
    """

    def __init__(self, target):
        super().__init__()
        # Detach the target from the autograd graph that produced it: it is a
        # constant reference, not something to differentiate through.
        self.target = target.detach()

    def forward(self, input):
        distance = F.mse_loss(input, self.target)
        self.loss = distance
        return input
格拉姆矩阵（Gram matrix）：特征图两两之间的内积，用作图片风格的表示
def gram_matrix(input):
    """Return the normalised Gram matrix of a 4-D feature tensor.

    The tensor is unpacked as (batch, channels, height, width). Every
    (batch, channel) pair is flattened into one row, and the pairwise row inner
    products are divided by the total element count. With batch size 1 this is
    the (channels x channels) correlation matrix between feature maps.
    """
    n_batch, n_chan, height, width = input.size()
    total = n_batch * n_chan * height * width
    rows = input.view(n_batch * n_chan, height * width)
    return (rows @ rows.t()).div(total)
定义图片风格的损失,用于约束生成图片的风格
class Style_loss(nn.Module):
    """Pass-through layer that scores its input against a fixed style target.

    The target is the Gram matrix of ``target_feature``, detached so it acts as
    a constant. ``forward`` stores the mean-squared error between the Gram
    matrix of its input and that target in ``self.loss``, then returns the
    input unchanged so the module can sit inside an ``nn.Sequential``.
    """

    def __init__(self, target_feature):
        super().__init__()
        target_gram = gram_matrix(target_feature)
        self.target = target_gram.detach()

    def forward(self, input):
        input_gram = gram_matrix(input)
        self.loss = F.mse_loss(input_gram, self.target)
        return input
以预训练的 VGG19 作为特征提取器
# Convolutional trunk of an ImageNet-pretrained VGG19, used purely as a frozen
# feature extractor.
try:
    vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
except AttributeError:
    # torchvision < 0.13 has no ``VGG19_Weights``; ``pretrained=True`` loads the
    # same IMAGENET1K_V1 weights there (and is deprecated in newer releases).
    vgg = models.vgg19(pretrained=True).features
vgg = vgg.to(device).eval()
# Only the generated image is optimised. Freezing the network stops autograd
# from computing (and accumulating) gradients for every VGG weight each step.
vgg.requires_grad_(False)

# Layers whose activations drive the content and style losses.
content_layers_default = ['conv_4']
style_layers_default = ['conv_' + str(i) for i in range(1, 6)]
class _ImageNetNormalization(nn.Module):
    """Normalise a [0, 1] RGB batch with the ImageNet statistics torchvision's pretrained models expect."""

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        super().__init__()
        # (C, 1, 1) buffers broadcast over (B, C, H, W) and follow .to(device).
        self.register_buffer('mean', torch.tensor(mean).view(-1, 1, 1))
        self.register_buffer('std', torch.tensor(std).view(-1, 1, 1))

    def forward(self, img):
        return (img - self.mean) / self.std


def get_style_model_and_loss(style_img, content_img, cnn=None,
                             content_layers=None, style_layers=None):
    """Build a truncated feature extractor with content/style loss probes inserted.

    The network is: ImageNet normalisation, then the layers of ``cnn`` (default:
    the module-level ``vgg``) renamed ``conv_k`` / ``relu_k`` / ``pool_k``, where
    ``k`` counts conv layers seen so far. A ``Content_loss`` probe follows each
    layer named in ``content_layers`` (default ``content_layers_default``), and a
    ``Style_loss`` probe follows each layer named in ``style_layers`` (default
    ``style_layers_default``). Probe targets come from running ``content_img`` /
    ``style_img`` through the network built so far. Layers after the last probe
    are dropped because they cannot affect any loss.

    Returns:
        (model, style_loss_list, content_loss_list). Each forward pass of
        ``model`` refreshes ``.loss`` on every probe in the two lists.

    Raises:
        RuntimeError: if ``cnn`` contains a layer type other than Conv2d, ReLU
            or MaxPool2d.
    """
    if cnn is None:
        cnn = vgg
    if content_layers is None:
        content_layers = content_layers_default
    if style_layers is None:
        style_layers = style_layers_default

    model = nn.Sequential().to(device)
    # torchvision's pretrained weights expect mean/std-normalised input, while
    # the images here are plain [0, 1] RGB.
    model.add_module('normalization', _ImageNetNormalization().to(device))

    style_loss_list, content_loss_list = [], []
    i = 0
    for layer in cnn:
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = 'conv_' + str(i)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_' + str(i)
            # VGG's ReLUs are in-place; that would overwrite activations the loss
            # probes saved for backward, so use an out-of-place replacement.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_' + str(i)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
        model.add_module(name, layer)

        if name in content_layers:
            with torch.no_grad():
                target = model(content_img)
            content_loss = Content_loss(target)
            model.add_module('content_loss_' + str(i), content_loss)
            content_loss_list.append(content_loss)
        if name in style_layers:
            with torch.no_grad():
                target_feature = model(style_img)
            style_loss = Style_loss(target_feature)
            model.add_module('style_loss_' + str(i), style_loss)
            style_loss_list.append(style_loss)

    # Everything after the last probe is wasted compute on every optimisation step.
    last = max((idx for idx, module in enumerate(model)
                if isinstance(module, (Content_loss, Style_loss))),
               default=len(model) - 1)
    model = model[:last + 1]
    return model, style_loss_list, content_loss_list
读取输入图片
def _load_rgb(path, size=(224, 224)):
    """Open the image file at *path* as 3-channel RGB, resized to *size* (width, height)."""
    with Image.open(path) as img:
        # convert('RGB') loads the pixels (so closing the file afterwards is safe)
        # and turns greyscale / RGBA / palette images into the 3 channels VGG expects.
        return img.convert('RGB').resize(size)


def _to_batch_tensor(img):
    """Turn an RGB PIL image into a (1, 3, H, W) float32 tensor scaled to [0, 1] on ``device``."""
    # np.asarray copies when it must. np.array(..., copy=False) raises ValueError
    # on NumPy >= 2 whenever a copy is unavoidable, and uint8 -> float32 always is.
    pixels = np.asarray(img, dtype=np.float32).transpose([2, 0, 1]) / 255
    return torch.from_numpy(pixels).unsqueeze(0).to(device)


style = _load_rgb('fg.jpg')
content = _load_rgb('timg.jpg')

for position, (picture, caption) in enumerate([(style, 'style'), (content, 'content')], start=1):
    plt.subplot(1, 2, position)
    plt.imshow(picture)
    plt.title(caption)
    plt.axis('off')
plt.show()

style_img = _to_batch_tensor(style)
content_img = _to_batch_tensor(content)
print(content_img.size(), style_img.size())  # torch.Size([1, 3, 224, 224]) torch.Size([1, 3, 224, 224])
以 L-BFGS 为优化器,设置内容损失和风格损失的权重之比为 1:1000000(即内容损失权重为 1,风格损失权重为 1000000)
def get_input_param_optimizer(input_img):
    """Wrap *input_img* as a trainable parameter and pair it with an L-BFGS optimizer.

    The parameter shares storage with ``input_img`` (via ``.data``), so the
    optimizer updates that tensor's pixels in place.

    Returns:
        (parameter, optimizer)
    """
    param = nn.Parameter(input_img.data)
    return param, torch.optim.LBFGS([param])
# Relative weighting of the two objectives: the style term is scaled up by 10**6.
content_weight = 1
style_weight = 10 ** 6
def _show_tensor_image(img):
    """Plot a (1, 3, H, W) image tensor (values expected in [0, 1]) with matplotlib."""
    plt.figure(figsize=(6, 8))
    plt.imshow(img.cpu().numpy()[0].transpose([1, 2, 0]))
    plt.axis('off')
    plt.show()


def run_style_transfer(content_img, style_img, num_epoches=300, report_every=50):
    """Optimise a copy of *content_img* toward the style of *style_img*.

    ``num_epoches`` bounds the number of loss evaluations (closure calls), not
    L-BFGS steps. A single ``optimizer.step`` may evaluate the closure several
    times, so the final count can overshoot ``num_epoches`` slightly.

    Every ``report_every`` evaluations the weighted style/content losses are
    printed. After each step in which a new multiple of ``report_every`` was
    passed, the current image is displayed.

    Returns:
        The optimised image tensor, clamped to [0, 1].

    Raises:
        ValueError: if ``report_every`` is not positive.
    """
    if report_every <= 0:
        raise ValueError('report_every must be a positive integer')
    print('building the style transfer model ..')
    model, style_loss_list, content_loss_list = get_style_model_and_loss(style_img,
                                                                         content_img)
    # Start the optimisation from a copy of the content image.
    input_img = content_img.clone()
    input_param, optimizer = get_input_param_optimizer(input_img)
    print('optimizing...')
    # Mutable cell so the closure can update the shared evaluation counter.
    epoch = [0]

    def closure():
        # Keep pixels in the valid [0, 1] range before every evaluation.
        input_param.data.clamp_(0, 1)
        model(input_param)  # the forward pass refreshes .loss on every probe
        optimizer.zero_grad()
        style_score = sum(sl.loss for sl in style_loss_list) * style_weight
        content_score = sum(cl.loss for cl in content_loss_list) * content_weight
        loss = style_score + content_score
        loss.backward()
        epoch[0] += 1
        if epoch[0] % report_every == 0:
            print('run {}'.format(epoch[0]))
            print('Style Loss: {:.4f} Content Loss: {:.4f}'.format(style_score.item(), content_score.item()))
            print()
        return loss

    shown = 0
    while epoch[0] < num_epoches:
        optimizer.step(closure)
        input_param.data.clamp_(0, 1)
        # The counter advances several times per step, so it rarely lands exactly
        # on a multiple of report_every here; show the image whenever a new
        # multiple has been passed instead.
        if epoch[0] // report_every > shown:
            shown = epoch[0] // report_every
            _show_tensor_image(input_param.data)
    return input_param.data
# Run the optimisation for about 1000 loss evaluations, starting from the content image.
output = run_style_transfer(content_img=content_img, style_img=style_img, num_epoches=1000)
下一篇: 图像风格迁移