cs231n_2017_Style_Transfer
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision
import torchvision.transforms as T
import PIL
import numpy as np
from scipy.misc import imread
from collections import namedtuple
import matplotlib.pyplot as plt
from cs231n.image_utils import SQUEEZENET_MEAN, SQUEEZENET_STD
%matplotlib inline
#image_utils.py:
from __future__ import print_function
from builtins import range
import urllib.request, urllib.error, urllib.parse, os, tempfile
import numpy as np
from scipy.misc import imread, imresize
"""
Utility functions used for viewing and processing images.
"""
def blur_image(X):
"""
A very gentle image blurring operation, to be used as a regularizer for
image generation.
Inputs:
- X: Image data of shape (N, 3, H, W)
Returns:
- X_blur: Blurred version of X, of shape (N, 3, H, W)
"""
from cs231n.fast_layers import conv_forward_fast
w_blur = np.zeros((3, 3, 3, 3))
b_blur = np.zeros(3)
blur_param = {'stride': 1, 'pad': 1}
for i in range(3):
w_blur[i, i] = np.asarray([[1, 2, 1], [2, 188, 2], [1, 2, 1]],
dtype=np.float32)
w_blur /= 200.0
return conv_forward_fast(X, w_blur, b_blur, blur_param)[0]
SQUEEZENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
SQUEEZENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
def preprocess_image(img):
"""Preprocess an image for squeezenet.
Subtracts the pixel mean and divides by the standard deviation.
"""
return (img.astype(np.float32)/255.0 - SQUEEZENET_MEAN) / SQUEEZENET_STD
def deprocess_image(img, rescale=False):
"""Undo preprocessing on an image and convert back to uint8."""
img = (img * SQUEEZENET_STD + SQUEEZENET_MEAN)
if rescale:
vmin, vmax = img.min(), img.max()
img = (img - vmin) / (vmax - vmin)
return np.clip(255 * img, 0.0, 255.0).astype(np.uint8)
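# Round-trip check (a sketch added for illustration, not part of the original
# utilities): deprocess_image should invert preprocess_image up to the rounding
# error introduced by the final uint8 cast.
_x_check = np.random.randint(0, 256, size=(16, 16, 3)).astype(np.uint8)
_x_round_trip = deprocess_image(preprocess_image(_x_check))
assert np.max(np.abs(_x_round_trip.astype(np.float32) - _x_check.astype(np.float32))) <= 1.0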
def image_from_url(url):
"""
Read an image from a URL. Returns a numpy array with the pixel data.
We write the image to a temporary file then read it back. Kinda gross.
"""
try:
f = urllib.request.urlopen(url)
_, fname = tempfile.mkstemp()
with open(fname, 'wb') as ff:
ff.write(f.read())
img = imread(fname)
os.remove(fname)
return img
except urllib.error.URLError as e:
print('URL Error: ', e.reason, url)
except urllib.error.HTTPError as e:
print('HTTP Error: ', e.code, url)
def load_image(filename, size=None):
"""Load and resize an image from disk.
Inputs:
- filename: path to file
- size: size of shortest dimension after rescaling
"""
img = imread(filename)
if size is not None:
orig_shape = np.array(img.shape[:2])
min_idx = np.argmin(orig_shape)
scale_factor = float(size) / orig_shape[min_idx]
new_shape = (orig_shape * scale_factor).astype(int)
img = imresize(img, scale_factor)
return img
def preprocess(img, size=512):
transform = T.Compose([
        T.Scale(size),  # T.Scale was renamed to T.Resize in newer torchvision
T.ToTensor(),
T.Normalize(mean=SQUEEZENET_MEAN.tolist(),
std=SQUEEZENET_STD.tolist()),
T.Lambda(lambda x: x[None]),
])
return transform(img)
def deprocess(img):
transform = T.Compose([
T.Lambda(lambda x: x[0]),
T.Normalize(mean=[0, 0, 0], std=[1.0 / s for s in SQUEEZENET_STD.tolist()]),
T.Normalize(mean=[-m for m in SQUEEZENET_MEAN.tolist()], std=[1, 1, 1]),
T.Lambda(rescale),
T.ToPILImage(),
])
return transform(img)
def rescale(x):
low, high = x.min(), x.max()
x_rescaled = (x - low) / (high - low)
return x_rescaled
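# Sanity check (a sketch, not in the original notebook): rescale should map any
# non-constant tensor exactly onto [0, 1].
_r_check = rescale(torch.randn(3, 8, 8))
assert _r_check.min() == 0.0 and _r_check.max() == 1.0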
def rel_error(x, y):
return np.max(np.abs(x - y) / (np.maximum(1e-8, np.abs(x) + np.abs(y))))
def features_from_img(imgpath, imgsize):
img = preprocess(PIL.Image.open(imgpath), size=imgsize)
img_var = Variable(img.type(dtype))
return extract_features(img_var, cnn), img_var
# Older versions of scipy.misc.imresize yield different results
# from newer versions, so we check to make sure scipy is up to date.
def check_scipy():
    import scipy
    # Note: this parses only the minor version number, so it misfires on
    # SciPy >= 1.0 (where '1.0.0' yields 0); presumably that is why the call
    # below is left commented out.
    vnum = int(scipy.__version__.split('.')[1])
    assert vnum >= 16, "You must install SciPy >= 0.16.0 to complete this notebook."
#check_scipy()
answers = np.load('style-transfer-checks.npz')
dtype = torch.FloatTensor
# Uncomment the following line if you're on a machine with a GPU set up for PyTorch!
# dtype = torch.cuda.FloatTensor
# Load the pre-trained SqueezeNet model.
cnn = torchvision.models.squeezenet1_1(pretrained=True).features
cnn.type(dtype)
# We don't want to train the model any further, so we don't want PyTorch to waste computation
# computing gradients on parameters we're never going to update.
for param in cnn.parameters():
param.requires_grad = False
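# Sanity check (a one-line sketch): every SqueezeNet parameter should now be frozen.
assert all(not p.requires_grad for p in cnn.parameters())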
# We provide this helper code which takes an image, a model (cnn), and returns a list of
# feature maps, one per layer.
def extract_features(x, cnn):
"""
Use the CNN to extract features from the input image x.
Inputs:
- x: A PyTorch Variable of shape (N, C, H, W) holding a minibatch of images that
will be fed to the CNN.
- cnn: A PyTorch model that we will use to extract features.
Returns:
    - features: A list of features for the input image x extracted using the cnn model.
features[i] is a PyTorch Variable of shape (N, C_i, H_i, W_i); recall that features
from different layers of the network may have different numbers of channels (C_i) and
spatial dimensions (H_i, W_i).
"""
features = []
prev_feat = x
for i, module in enumerate(cnn._modules.values()):
next_feat = module(prev_feat)
features.append(next_feat)
prev_feat = next_feat
return features
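# Usage sketch (names with a leading underscore are introduced here purely for
# illustration): we should get back exactly one feature map per module of the
# SqueezeNet feature extractor.
_x_rand = Variable(torch.randn(1, 3, 64, 64).type(dtype))
_feats_rand = extract_features(_x_rand, cnn)
assert len(_feats_rand) == len(cnn._modules)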
def content_loss(content_weight, content_current, content_original):
"""
Compute the content loss for style transfer.
Inputs:
- content_weight: Scalar giving the weighting for the content loss.
- content_current: features of the current image; this is a PyTorch Tensor of shape
(1, C_l, H_l, W_l).
    - content_original: features of the content image, Tensor with shape (1, C_l, H_l, W_l).
Returns:
- scalar content loss
"""
N, C_l, H_l, W_l = content_current.size()
F_l = content_current.view(C_l, H_l*W_l)
P_l = content_original.view(C_l, H_l*W_l)
loss = content_weight * (torch.sum((F_l - P_l)**2))
return loss
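# Worked numeric check (a sketch, not from the original notebook): for a single
# 1-channel 1x2 feature map, the loss is w * ((1-0)^2 + (3-1)^2) = 5w.
_cur = Variable(torch.Tensor([[[[1.0, 3.0]]]]))
_orig = Variable(torch.Tensor([[[[0.0, 1.0]]]]))
assert abs(content_loss(1.0, _cur, _orig).data[0] - 5.0) < 1e-6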
def gram_matrix(features, normalize=True):
"""
Compute the Gram matrix from features.
Inputs:
- features: PyTorch Variable of shape (N, C, H, W) giving features for
a batch of N images.
- normalize: optional, whether to normalize the Gram matrix
If True, divide the Gram matrix by the number of neurons (H * W * C)
Returns:
- gram: PyTorch Variable of shape (N, C, C) giving the
(optionally normalized) Gram matrices for the N input images.
"""
    N, C, H, W = features.size()
    # Reshape each feature map to (C, H*W); the Gram matrix is then a batched
    # matrix multiply, G = F F^T for each image. (Using N here instead of a
    # hard-coded 1 keeps the batch dimension and handles N > 1 correctly.)
    F_l = features.view(N, C, H * W)
    gram = torch.bmm(F_l, F_l.transpose(1, 2))
    if normalize:
        gram /= (H * W * C)
    return gram
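# Sanity check (a sketch): the Gram matrix is F F^T per image, so the result
# should have shape (N, C, C) and be symmetric.
_f_check = Variable(torch.randn(2, 5, 4, 4))
_g_check = gram_matrix(_f_check)
assert _g_check.size() == (2, 5, 5)
assert torch.max(torch.abs(_g_check - _g_check.transpose(1, 2))).data[0] < 1e-5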
# Now put it together in the style_loss function...
def style_loss(feats, style_layers, style_targets, style_weights):
"""
Computes the style loss at a set of layers.
Inputs:
- feats: list of the features at every layer of the current image, as produced by
the extract_features function.
- style_layers: List of layer indices into feats giving the layers to include in the
style loss.
- style_targets: List of the same length as style_layers, where style_targets[i] is
      a PyTorch Variable giving the Gram matrix of the source style image computed at
layer style_layers[i].
- style_weights: List of the same length as style_layers, where style_weights[i]
is a scalar giving the weight for the style loss at layer style_layers[i].
Returns:
- style_loss: A PyTorch Variable holding a scalar giving the style loss.
"""
# Hint: you can do this with one for loop over the style layers, and should
# not be very much code (~5 lines). You will need to use your gram_matrix function.
    # Compute the style loss at each desired layer and sum the weighted terms.
    # (Named loss rather than style_loss to avoid shadowing this function.)
    loss = Variable(torch.zeros(1)).type(dtype)
    for i, layer in enumerate(style_layers):
        current_im_gram = gram_matrix(feats[layer])
        loss += style_weights[i] * torch.sum((current_im_gram - style_targets[i]) ** 2)
    return loss
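# Sanity check (a sketch, not from the original notebook): if the generated
# image's Gram matrices already equal the targets, the style loss must vanish.
_feats_check = [Variable(torch.randn(1, 3, 4, 4).type(dtype)) for _ in range(2)]
_targets_check = [gram_matrix(f.clone()) for f in _feats_check]
assert style_loss(_feats_check, [0, 1], _targets_check, [1.0, 1.0]).data[0] < 1e-8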
def tv_loss(img, tv_weight):
"""
Compute total variation loss.
Inputs:
- img: PyTorch Variable of shape (1, 3, H, W) holding an input image.
- tv_weight: Scalar giving the weight w_t to use for the TV loss.
Returns:
- loss: PyTorch Variable holding a scalar giving the total variation loss
for img weighted by tv_weight.
"""
# Your implementation should be vectorized and not require any loops!
w_variance = torch.sum((img[:,:,:,1:] - img[:,:,:,:-1])**2)
h_variance = torch.sum((img[:,:,1:,:] - img[:,:,:-1,:])**2)
loss = tv_weight * (w_variance + h_variance)
return loss
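# Worked numeric check (a sketch): for the 2x2 single-channel image
# [[0, 1], [2, 3]], the horizontal terms give (1-0)^2 + (3-2)^2 = 2 and the
# vertical terms give (2-0)^2 + (3-1)^2 = 8, so tv_loss with weight 1 is 10.
_img_check = Variable(torch.Tensor([[[[0.0, 1.0], [2.0, 3.0]]]]))
assert abs(tv_loss(_img_check, 1.0).data[0] - 10.0) < 1e-6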
def style_transfer(content_image, style_image, image_size, style_size, content_layer, content_weight,
                   style_layers, style_weights, tv_weight, init_random=False):
"""
Run style transfer!
Inputs:
- content_image: filename of content image
- style_image: filename of style image
- image_size: size of smallest image dimension (used for content loss and generated image)
- style_size: size of smallest style image dimension
- content_layer: layer to use for content loss
- content_weight: weighting on content loss
- style_layers: list of layers to use for style loss
- style_weights: list of weights to use for each layer in style_layers
- tv_weight: weight of total variation regularization term
- init_random: initialize the starting image to uniform random noise
"""
# Extract features for the content image
content_img = preprocess(PIL.Image.open(content_image), size=image_size)
content_img_var = Variable(content_img.type(dtype))
feats = extract_features(content_img_var, cnn)
content_target = feats[content_layer].clone()
# Extract features for the style image
style_img = preprocess(PIL.Image.open(style_image), size=style_size)
style_img_var = Variable(style_img.type(dtype))
feats = extract_features(style_img_var, cnn)
style_targets = []
for idx in style_layers:
style_targets.append(gram_matrix(feats[idx].clone()))
    # Initialize the output image to the content image or to random noise
    if init_random:
        img = torch.Tensor(content_img.size()).uniform_(0, 1).type(dtype)
else:
img = content_img.clone().type(dtype)
# We do want the gradient computed on our image!
img_var = Variable(img, requires_grad=True)
# Set up optimization hyperparameters
initial_lr = 3.0
decayed_lr = 0.1
decay_lr_at = 180
# Note that we are optimizing the pixel values of the image by passing
# in the img_var Torch variable, whose requires_grad flag is set to True
optimizer = torch.optim.Adam([img_var], lr=initial_lr)
f, axarr = plt.subplots(1,2)
axarr[0].axis('off')
axarr[1].axis('off')
axarr[0].set_title('Content Source Img.')
axarr[1].set_title('Style Source Img.')
axarr[0].imshow(deprocess(content_img.cpu()))
axarr[1].imshow(deprocess(style_img.cpu()))
plt.show()
plt.figure()
    for t in range(200):
        # As a crude regularizer, clamp the (preprocessed) pixel values for all
        # but the last few iterations
        if t < 190:
            img.clamp_(-1.5, 1.5)
optimizer.zero_grad()
feats = extract_features(img_var, cnn)
# Compute loss
c_loss = content_loss(content_weight, feats[content_layer], content_target)
s_loss = style_loss(feats, style_layers, style_targets, style_weights)
t_loss = tv_loss(img_var, tv_weight)
loss = c_loss + s_loss + t_loss
loss.backward()
        # Take a gradient descent step on the image pixels, dropping the
        # learning rate at iteration decay_lr_at by rebuilding the optimizer
        if t == decay_lr_at:
            optimizer = torch.optim.Adam([img_var], lr=decayed_lr)
optimizer.step()
if t % 100 == 0:
print('Iteration {}'.format(t))
plt.axis('off')
plt.imshow(deprocess(img.cpu()))
plt.show()
print('Iteration {}'.format(t))
plt.axis('off')
img_processed = deprocess(img.cpu())
plt.imshow(img_processed)
plt.savefig('1-6.jpg')
plt.show()
# Composition VII + Tubingen
params1 = {
'content_image' : 'styles/tubingen.jpg',
'style_image' : 'styles/composition_vii.jpg',
'image_size' : 192,
'style_size' : 512,
'content_layer' : 3,
'content_weight' : 5e-2,
'style_layers' : (1, 4, 6, 7),
'style_weights' : (20000, 500, 12, 1),
'tv_weight' : 5e-2
}
style_transfer(**params1)
# Feature Inversion -- Starry Night + Tubingen
params_inv = {
'content_image' : 'styles/tubingen.jpg',
'style_image' : 'styles/starry_night.jpg',
'image_size' : 192,
'style_size' : 192,
'content_layer' : 3,
'content_weight' : 6e-2,
'style_layers' : [1, 4, 6, 7],
'style_weights' : [0, 0, 0, 0], # we discard any contributions from style to the loss
'tv_weight' : 2e-2,
'init_random': True # we want to initialize our image to be random
}
style_transfer(**params_inv)