This post walks through the intuition behind neural style transfer and its implementation in PyTorch.
It assumes that you have basic experience working with PyTorch. If you are new to PyTorch, I would highly encourage you to go through Deep Learning with PyTorch: A 60 Minute Blitz from the official PyTorch tutorials. It’s a great place for beginners to get their hands dirty.
Neural Style Transfer is an algorithm developed by Leon A. Gatys, Alexander S. Ecker, and Matthias Bethge that blends the content of one image with the style of another image using Deep Neural Networks to create artistic images of high perceptual quality.
Convolutional neural networks extract visual information hierarchically, which makes them well suited to this task. The lower layers respond mostly to detailed pixel values, whereas the higher layers capture the actual content of the image (objects and their parts, such as eyes and a nose).
In the above figure, the output image is a mix of the two inputs, since we use the activations of the neural network at specific layers as filters to extract intermediate style and content representations of the inputs.
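The principle underlying neural style transfer is simple. We define two loss functions: a content loss $L_{content}(\bar{p}, \bar{x})$, which measures how much the content of the generated image differs from that of the content image, and a style loss $L_{style}(\bar{a}, \bar{x})$, which measures how much its style differs from that of the style image. We then minimize a total loss, which is a weighted sum of the two:
\[L_{total}(\bar{p}, \bar{a}, \bar{x}) = \alpha L_{content}(\bar{p}, \bar{x}) + \beta L_{style}(\bar{a}, \bar{x})\]
where $\bar{p}$ is the content image, $\bar{a}$ is the style image, $\bar{x}$ is the generated image, and $\alpha$ and $\beta$ are the weights of the content and style losses.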
There is a trade-off between the actual content and the artistic style, which is determined by $\alpha$ and $\beta$. If the content is more important, increase $\alpha$. If the style is more important, increase $\beta$.
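For reference, the individual loss terms can be written out as follows. This is a sketch based on the definitions in the paper by Gatys et al., not on anything specific to this post. Here $F^{l}$ and $P^{l}$ denote the layer-$l$ feature maps of the generated image $\bar{x}$ and the content image $\bar{p}$, $G^{l}$ and $A^{l}$ are the Gram matrices of the generated image and the style image $\bar{a}$, $N_{l}$ is the number of feature maps in layer $l$, $M_{l}$ is their size (height times width), and $w_{l}$ are per-layer weights. The implementation in this post uses mean squared errors and a slightly different normalization, so the constants differ.

```latex
\[L_{content}(\bar{p}, \bar{x}, l) = \frac{1}{2} \sum_{i,j} \left(F^{l}_{ij} - P^{l}_{ij}\right)^{2}\]
\[G^{l}_{ij} = \sum_{k} F^{l}_{ik} F^{l}_{jk}\]
\[E_{l} = \frac{1}{4 N_{l}^{2} M_{l}^{2}} \sum_{i,j} \left(G^{l}_{ij} - A^{l}_{ij}\right)^{2}\]
\[L_{style}(\bar{a}, \bar{x}) = \sum_{l} w_{l} E_{l}\]
```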
Now, let’s implement the algorithm above, starting by importing the packages below.
%matplotlib inline
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
Now, we select the device on which our network will run. The neural style transfer algorithm runs much faster on a GPU, so we check whether one is available using torch.cuda.is_available().
# device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
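As a quick sanity check, the sketch below creates a tensor and moves it to the selected device. The tensor name x is only an illustration and is not used in the rest of the post.

```python
import torch

# same device selection as above: GPU if available, otherwise CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# tensors are created on the CPU by default and must be moved explicitly
x = torch.zeros(1, 3, 8, 8).to(device)
print(x.device.type == device.type)
```

The images and the model must live on the same device, which is why both are moved with .to(device) in the rest of the post.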
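Now, let’s download and load the pre-trained VGG19 model. VGG19 is trained for image classification on ImageNet. We only need its convolutional feature extractor (the features module), and we freeze all of its parameters because we optimize the pixels of the target image, not the weights of the network.
# load the pretrained VGG19 and keep only the convolutional feature extractor
# (recent torchvision releases prefer weights=models.VGG19_Weights.DEFAULT)
vgg = models.vgg19(pretrained=True).features
# eval mode alone does not freeze anything, so disable gradients explicitly
for param in vgg.parameters():
    param.requires_grad_(False)
# move the model to the GPU if available and switch it to eval mode
vgg.to(device).eval()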
Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace=True)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace=True)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace=True)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace=True)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace=True)
(16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(17): ReLU(inplace=True)
(18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace=True)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace=True)
(23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(24): ReLU(inplace=True)
(25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(26): ReLU(inplace=True)
(27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace=True)
(30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(31): ReLU(inplace=True)
(32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(33): ReLU(inplace=True)
(34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(35): ReLU(inplace=True)
(36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
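Note that eval() only switches layers such as dropout and batch normalization to inference behavior; it does not stop gradient tracking for the parameters. A minimal sketch of the difference, using a small stand-in network (toy_net is a hypothetical name, not part of this post) so that nothing has to be downloaded:

```python
import torch
import torch.nn as nn

# a tiny stand-in for the VGG feature extractor (hypothetical, for illustration)
toy_net = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
)

# eval mode changes layer behavior but leaves requires_grad untouched
toy_net.eval()
still_trainable = all(p.requires_grad for p in toy_net.parameters())

# freezing means turning off gradient tracking for every parameter
for p in toy_net.parameters():
    p.requires_grad_(False)
frozen = not any(p.requires_grad for p in toy_net.parameters())

print(still_trainable, frozen)
```

Freezing the weights is not strictly required for correctness here, since the optimizer only ever receives the target image, but it avoids computing gradients that are never used.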
Now, let’s load the style and content images. The loaded PIL images have values between 0 and 255, but when they are transformed into torch tensors, their values are scaled to between 0 and 1. We apply a few transformations to each image: Resize(), ToTensor(), and Normalize().
# desired size of the output image
imsize = 512 if torch.cuda.is_available() else 128 # use small size if no gpu
# ImageNet mean and std for each channel, as expected by the pretrained VGG19
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406])
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225])
loader = transforms.Compose([
    # scale the shorter side of the imported image to imsize
    transforms.Resize(imsize),
    # transform it into a torch tensor
    transforms.ToTensor(),
    # normalize the tensor as per VGG network
    transforms.Normalize(mean=cnn_normalization_mean,
                         std=cnn_normalization_std)
])
def load_image(image_path, transform=None):
    """
    Load an image and convert it to a torch tensor.
    """
    # force 3 channels so that grayscale or RGBA files also work
    image = Image.open(image_path).convert('RGB')
    if transform:
        # transform the image
        image = transform(image)
        # add a fake batch dimension to fit network's input dimension
        image = image.unsqueeze(0)
    return image
# load the style and content image
style_img = load_image('./images/style.jpg', transform=loader).to(device)
content_img = load_image('./images/content.jpg', transform=loader).to(device)
Now, let’s create a function to denormalize the image tensors, which will be helpful later for displaying them.
def denorm_image(img_tensor):
    """
    Denormalize the image for visualization
    """
    # clone the image tensor and detach from tracking
    image = img_tensor.to('cpu').clone().detach()
    # remove the fake batch dimension
    image = image.numpy().squeeze()
    # reshape (n_C, n_H, n_W) -> (n_H, n_W, n_C)
    image = image.transpose(1, 2, 0)
    # denormalize the image
    image = image * cnn_normalization_std.numpy() + cnn_normalization_mean.numpy()
    # restrict the value between 0 and 1 by clipping the outliers
    image = image.clip(0, 1)
    return image
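The denormalization is just the inverse of Normalize(). A rough sketch of the round trip on a random tensor, written with plain tensor arithmetic (the variable names are illustrative only):

```python
import torch

torch.manual_seed(0)
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# a random "image" in (C, H, W) layout with values in [0, 1)
original = torch.rand(3, 4, 4)

# forward: what Normalize() computes per channel
normalized = (original - mean) / std

# inverse: multiply by std and add the mean back
restored = normalized * std + mean

print(torch.allclose(restored, original, atol=1e-6))
```

In the function above the same arithmetic is applied after moving the channels to the last axis, so the per-channel mean and std broadcast without any reshaping.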
Now, let’s display the content and style images.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(16, 8))
ax1.imshow(denorm_image(content_img))
ax1.set_title('Content', fontsize=20)
ax2.imshow(denorm_image(style_img))
ax2.set_title('Style', fontsize=20)
plt.show()
Now, let’s select the convolutional layers from VGG19 to extract the feature maps.
def get_feature_maps(image, model, layers=None):
    """
    Extract the convolutional feature maps from conv1_1 ~ conv5_1
    """
    if layers is None:
        # layer number for conv1_1 ~ conv5_1
        layers = ['0', '5', '10', '19', '28']
    # conv feature map
    features = []
    x = image
    # iterate through the model layers
    for name, layer in model.named_children():
        x = layer(x)
        # checks for the layer match
        if name in layers:
            # note: the ReLU that follows runs in place, so the stored
            # feature map ends up holding the rectified activations
            features.append(x)
    return features
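The layer names are simply the string indices of the modules inside the Sequential container. A small sketch on a toy network (toy_net and collect are hypothetical names) shows the mechanism, and also a subtlety: because the ReLU layers in VGG19 run in place, a feature map stored right after a convolution is overwritten by the following ReLU, so what ends up stored is the rectified activation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# toy stand-in for vgg: conv -> in-place ReLU -> conv
toy_net = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(4, 8, kernel_size=3, padding=1),
)

def collect(image, model, layers):
    """Run the image through the model and keep the outputs of the named layers."""
    features = []
    x = image
    for name, layer in model.named_children():
        x = layer(x)
        if name in layers:
            features.append(x)
    return features

image = torch.randn(1, 3, 16, 16)
features = collect(image, toy_net, layers=['0', '2'])

print([tuple(f.shape) for f in features])
# the map stored at layer '0' was rectified in place by layer '1'
print(bool((features[0] >= 0).all()))
```

If the raw convolution outputs were wanted instead, one option would be to store x.clone() rather than x, at the cost of some extra memory.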
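Now, let’s define the Gram matrix used in the style loss. It captures the correlations between the feature maps of a layer, and the style loss that we minimize during backpropagation compares the Gram matrices of the target and style images.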
def gram_matrix(input):
    """
    Calculates the gram matrix for input
    """
    # a = 1 (batch_size), b = n_C (number of feature maps)
    # (c, d) = dimension of feature map
    a, b, c, d = input.size()
    # reshape the convolutional feature maps
    features = input.view(a * b, c * d)
    # compute the gram matrix
    G = torch.mm(features, features.t())
    # Normalize the values of gram matrix
    G = G.div(a * b * c * d)
    return G
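A tiny worked example makes the computation concrete. For a single image with two 1x2 feature maps, [1, 2] and [3, 4], the unnormalized Gram matrix is [[5, 11], [11, 25]], and dividing by the number of elements (4) gives [[1.25, 2.75], [2.75, 6.25]]. The sketch below repeats the function from above so that it runs on its own:

```python
import torch

def gram_matrix(input):
    """Normalized Gram matrix of a (batch, channels, height, width) tensor."""
    a, b, c, d = input.size()
    features = input.view(a * b, c * d)
    G = torch.mm(features, features.t())
    return G.div(a * b * c * d)

# one image, two feature maps of size 1x2
x = torch.tensor([[[[1.0, 2.0]], [[3.0, 4.0]]]])
G = gram_matrix(x)
print(G)
```

The result has one row and one column per feature map, regardless of the spatial size of the input, which is why the style and target images do not need to have the same dimensions for the style loss.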
Now, let’s create a clone of the content image to use as the starting point for the target image, which we then optimize so that it keeps the content but takes on the artistic style.
target_img = content_img.clone().to(device)
# Alternative way: you can start with white noise to get an image with
# content attributes of content image and style attributes of style image
# target_img = torch.randn(content_img.data.size()).to(device)
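Now, let’s run the optimization and minimize the loss using backpropagation. The gradients update the pixels of the target image, not the weights of the network. Note that this implementation sums the content loss over all five selected layers.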
def run_neural_style_transfer(model, content_img, style_img, target_img,
                              num_steps=2000, sample_steps=400, learning_rate=0.02,
                              style_weight=1e4, content_weight=1e-2):
    """
    Run the neural style transfer
    """
    # optimizer that updates the pixels of the target image
    optimizer = torch.optim.Adam([target_img.requires_grad_()], lr=learning_rate,
                                 betas=(0.99, 0.999))
    # the content and style images never change, so extract their feature
    # maps (and the style gram matrices) once, without tracking gradients
    with torch.no_grad():
        content_features = get_feature_maps(content_img, model)
        style_grams = [gram_matrix(f) for f in get_feature_maps(style_img, model)]
    for step in range(num_steps):
        # extract the conv feature maps for the target image
        target_features = get_feature_maps(target_img, model)
        # initialize the style and content loss
        style_loss = 0
        content_loss = 0
        # calculate the style and content loss for each specific layer
        for f1, f2, g3 in zip(target_features, content_features, style_grams):
            # compute content loss with target and content images
            content_loss += torch.mean((f1 - f2) ** 2)
            # compute the gram matrix for the target feature maps
            g1 = gram_matrix(f1)
            # compute the style loss with target and style images
            style_loss += torch.mean((g1 - g3) ** 2)
        # compute total loss, backprop and optimize
        loss = content_loss * content_weight + style_loss * style_weight
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (step+1) % sample_steps == 0:
            # print the model stats
            print("run {}:".format(step+1))
            print('Style Loss : {:4f} Content Loss: {:4f}'.format(
                style_loss.item(), content_loss.item()))
            print()
    return target_img
# run the neural style transfer
output_img = run_neural_style_transfer(vgg, content_img, style_img, target_img)
# display the style transferred image
output_img = denorm_image(output_img)
plt.figure()
plt.imshow(output_img)
plt.title('Output Image')
plt.show()
run 400:
Style Loss : 0.000006 Content Loss: 21.730204
run 800:
Style Loss : 0.000004 Content Loss: 19.803923
run 1200:
Style Loss : 0.000004 Content Loss: 19.264122
run 1600:
Style Loss : 0.000004 Content Loss: 19.008064
run 2000:
Style Loss : 0.000004 Content Loss: 18.843626
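The key idea in the loop above is that the optimizer updates the pixels of the image itself rather than any network weights. A stripped-down sketch of that mechanism, with a plain mean squared error standing in for the style and content losses (goal and image are made-up tensors for illustration):

```python
import torch

# the "image" being optimized starts at zero and tracks gradients
image = torch.zeros(1, 3, 4, 4, requires_grad=True)
# a fixed tensor the image should move toward
goal = torch.ones(1, 3, 4, 4)

# the optimizer is handed the image tensor, not model parameters
optimizer = torch.optim.Adam([image], lr=0.1)

losses = []
for step in range(100):
    loss = torch.mean((image - goal) ** 2)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(losses[0], losses[-1])
```

In the full algorithm the only difference is that the loss is computed on VGG19 feature maps and Gram matrices instead of directly on the pixels.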
Great! Now you have become an artist who can generate artworks from a content image and a style image.