Fast Neural Style Transfer: training the model

Reading Time: 11 minutes

Note: code available on Github here. If you are not familiar with the Gatys et al implementation of Neural Style Transfer you can read this post of mine.

In my previous post, I described the SageMaker deployment process of PyTorch models to perform Real-Time Style Transfer. That post focused entirely on the infrastructural side of the project, and I promised I would shortly publish a write-up on the Deep Learning magic happening behind the scenes. This post delivers on that promise, providing a deep dive into the technical implementation of Fast Neural Style Transfer, following the Johnson et al implementation.

Context

The big issue with the standard Style Transfer algorithm proposed by Gatys et al is that it involves running a separate optimization loop each time we want to stylize a new picture. The process consists of picking a content and a style image and then creating a brand new one, iteratively minimizing a loss function on the spot. Even on a GPU, this is quite time-consuming. On top of that, it is wasteful. The style of Van Gogh’s Starry Night clearly doesn’t change when applied to two different content images. So, if you think about it, we are basically performing the same calculations over and over again. It would probably be wiser to learn the style of Starry Night only once, and then stylize new pictures by just running them through the model. This strategy is what Johnson et al proposed in their 2016 paper Perceptual Losses for Real-Time Style Transfer and Super-Resolution.

So, to recap, the idea is to train a neural network (NN) to learn styles.

If we look at this from the Gatys et al perspective, the difference with the traditional approach is that, instead of initializing the stylized image with random noise, we use the output of a model. This means that, whereas in the standard approach the optimization loop tweaks the pixel values of the input, with this new strategy we task SGD with tuning the NN’s weights.

The visualization below shows this. On the right side of the vertical dashed line, there is traditional style transfer, i.e. grab a content and a style image, extract their features, compare them with the same features from the input picture and minimize a distance between them. On the left side of the dashed line, instead, there is the network we want to train. It takes the content image as input (yes! the same content image we are optimizing against; more clarity on this point later on). An output image comes out of the model and gets fed into the loss function on the right. The NN’s weights are tuned to minimize the loss.

Another helpful visualization (with original caption) comes directly from the Johnson et al paper. As you can see below, the authors show the pipeline very clearly. An input image gets ingested into a NN (Transform Net) and the output is compared to style and content in the standard-Gatys-et-al loss function.

Let’s deep dive and figure out what is going on.

The data

Every deep learning problem can be broken into three parts:

  • the data
  • the architecture
  • the loss function

This one is no exception.

In a standard supervised ML challenge, we always have {X, y} pairs. Each X data point is accompanied by a y label, whose nature changes according to the problem at hand (binary, multi-class, multi-label/multi-class classification, regression etc). A typical ML exercise consists in building a model which gets X as input and generates y_hat as output, where the quality of `y_hat` is assessed by comparison with the ground truth y.

Our data structure is not as trivial as that, though. First of all, we don’t have a clear definition of ground truth. How are we going to establish if the output of the NN is correct? The reality is that we can’t, as there is no single correct answer. An image could be rendered in the style of Picasso’s Guernica in many possible ways, all of them legitimate. What we want is for the loss to decrease and, maybe most importantly, for the results to be visually appealing. So, what is the best way to structure our dataset? Technically, each data point shipped to the network needs to be composed of three elements:

  1. content image: this is where we will get the content from, e.g. a shot of the New York skyline. The content loss will measure how closely the stylized image resembles the content image in terms of what the picture shows. In terms of which pictures I actually used to train my models, I opted for the COCO dataset, specifically the 2015 Test images [81K/12GB].
  2. style image: this is where we will get the style from, e.g. Van Gogh’s Starry Night or Picasso’s Guernica. The style loss will measure how closely the stylized image manages to reproduce the artistic connotations of the painting or, in other words, the style of the painter.
  3. to-be-stylized-image (aka input): this is the image which will get stylized, combining content and style from #1 and #2. It could be initialized to random noise or, as I did, to a copy of the content image itself. The latter ensures much faster convergence than the former. Of course, we then need to either dampen the content loss or increase the weight of the style loss. If we didn’t, the content component would overwhelm the style one, given that we start with an image which, by definition, already contains the content we are optimizing against.

To handle this scenario I wrote a custom PyTorch Dataset, inspired by the fastai Data Block API. Each item in the dataset is a tuple with 3 tensors. To easily inspect their contents, I have overridden the `__repr__` method of the Dataset class, so that when I `print` the object, something useful gets displayed. Like this:

size = 300
padding = 40
rgb = MakeRGB()
resized = ResizeFixed(size)
tobyte = ToByteTensor()
tofloat = ToFloatTensor()
norm = Normalize(imagenet_stats, padding)
tmfs = [rgb, resized, tobyte, tofloat, norm]
train_ds = StyleTransferDataset(dataset_path, train_test='train', transform=tmfs, sample=0.02)
valid_ds = StyleTransferDataset(dataset_path, train_test='valid', transform=tmfs, sample=0.5)
print(valid_ds)
print(train_ds)
# output
Valid dataset: 43 items
Item: class 'tuple' of 3 class 'torch.Tensor'
Item example: 'input':torch.Size([3, 380, 380]),'content':torch.Size([3, 380, 380]),'style':torch.Size([3, 380, 380])
Train dataset: 1626 items
Item: class 'tuple' of 3 class 'torch.Tensor'
Item example: 'input':torch.Size([3, 380, 380]),'content':torch.Size([3, 380, 380]),'style':torch.Size([3, 380, 380])

There are a few interesting things happening here. The `tmfs` list holds the transformations we apply to the 3 images in each tuple in the pre-processing phase. In the following exact order, a `PIL.Image` goes through:

  • `MakeRGB()` converts the input into RGB format (sometimes PNGs have 4 channels and this needs to be fixed).
  • `ResizeFixed(size)` resizes the input to 300 x 300.
  • `ToByteTensor()` converts the `PIL.Image` object into a `torch.ByteTensor` (dtype: `torch.uint8`) and moves the channel axis from the tail to the head of the tensor (from `300 x 300 x 3` to `3 x 300 x 300`), the format PyTorch accepts.
  • `ToFloatTensor()` converts the tensor to float and divides by 255 to squeeze the pixel values into the range 0-1.
  • `Normalize(imagenet_stats, padding)` applies the standard ImageNet normalization (subtract mean + divide by std) and pads the image with 40 pixels on each side. This returns the final `3 x 380 x 380` shaped tensor (300 + 2 x 40 = 380). I add padding because, during training, some ugly artifacts appear close to the borders of the image. I still haven’t figured out why this is the case, so the easiest fix was to artificially increase the size of the image and then remove the padding in the post-processing phase, to get back a `3 x 300 x 300` shaped tensor.
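To make the arithmetic behind the last two transformations concrete, here is a minimal, dependency-free sketch. It is not the code of my transform classes: `normalize_pixel` and `padded_shape` are hypothetical helper names, and the mean/std values are the commonly used ImageNet statistics, which I assume is what `imagenet_stats` holds.

```python
# Commonly used ImageNet channel statistics (assumed content of `imagenet_stats`).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb_bytes):
    """Scale one (R, G, B) byte triple to 0-1, then subtract mean and divide by std."""
    return tuple(
        (value / 255.0 - mean) / std
        for value, mean, std in zip(rgb_bytes, IMAGENET_MEAN, IMAGENET_STD)
    )


def padded_shape(channels, size, padding):
    """Shape of a square image after padding every side by `padding` pixels."""
    return (channels, size + 2 * padding, size + 2 * padding)


print(normalize_pixel((255, 255, 255)))  # a pure white pixel after normalization
print(padded_shape(3, 300, 40))          # the shape fed to the network
```

With `size = 300` and `padding = 40`, `padded_shape` gives the `3 x 380 x 380` shape printed by the dataset above.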

To be fed into a NN, both `valid_ds` and `train_ds` need to be wrapped into PyTorch DataLoaders. A DataLoader segments the dataset into batches; each batch is then moved to the GPU inside the training loop.

bs = 4
dataloaders = {'train': DataLoader(train_ds, batch_size=bs, shuffle=True),
              'valid': DataLoader(valid_ds, batch_size=bs)}
inputs, contents, styles = next(iter(dataloaders['train']))
print(f'(input, content, style) = {inputs.shape}, {contents.shape}, {styles.shape}')
# output
(input, content, style) = torch.Size([4, 3, 380, 380]), torch.Size([4, 3, 380, 380]), torch.Size([4, 3, 380, 380])

The architecture

We got the data part covered. Now we have to figure out an appropriate neural network architecture for our purpose. Once again, the idea is to feed the Gatys et al loss function with the output of a model (not with a random image) and to get the model to learn the style. For this to work, the NN has to ingest a `3 x N x N` tensor and produce a tensor of the same shape. Image in, image out.

Which architectures are able to do that?  U-shaped ones.

They consist of adding, to a standard convolutional contracting path, an expanding upward path which ends up generating an output with the same size as the input. The contraction of the downward path is achieved via pooling layers and convolutions with increasing channels, which squeeze spatial information and expand the feature set. The expansion of the upward path is achieved via upsampling layers and convolutions with decreasing channels, which increase the resolution of the outputs and shrink their channel depth. The first and most famous architecture among this family of models is U-net, introduced by the 2015 Ronneberger et al paper U-Net: Convolutional Networks for Biomedical Image Segmentation. Compared to the above general architecture, this network concatenates the intermediate outputs of the expansive path with the outputs at the same level of the contracting path. These skip connections are close in spirit to those of ResNet, which was published at the end of the same year, although ResNet adds the skipped activations instead of concatenating them. Below is the U-net architecture from the original paper.

U-net from the 2015 paper by Ronneberger et al.

I tried on my own to deviate from the Johnson et al strategy, building a U-net with a pre-trained ResNet18  downward backbone. Quality of results and training speed were not even close to the paper’s Transformer Net (TN), so I eventually discarded my approach completely and opted for TN. Let’s see what’s in it.

This is the network’s implementation, which I shamelessly stole from the official PyTorch examples repo (below for reference, pasted without helper classes `ResidualBlock`,  `UpsampleConvLayer` and `ConvLayer`). Note that the paper mentions the usage of BatchNorm layers, whereas this implementation makes use of Instance Normalization, as suggested by Johnson himself in the official paper repo. Quoted from the README:

“This repository also includes an implementation of instance normalization as described in the paper Instance Normalization: The Missing Ingredient for Fast Stylization by Dmitry Ulyanov, Andrea Vedaldi, and Victor Lempitsky. This simple trick significantly improves the quality of feedforward style transfer models.”

The difference between Batch Norm (BN) and Instance Norm (IN) is the following. The former (BN) calculates 1 mean and 1 std per channel across the whole batch and normalizes all the images in the batch with those values. The latter (IN) calculates 1 mean and 1 std per channel per image. So each channel in each image is normalized independently of the others. I found this resource extremely helpful to better understand the difference.
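A tiny sketch may help to see which values get pooled together in the two cases. It uses plain Python lists instead of tensors (a batch is a list of images, an image is a list of channels, a channel is a flat list of pixels), and the two function names are made up for this illustration. Only the means are shown; the stds follow exactly the same grouping.

```python
from statistics import fmean


def batch_norm_means(batch):
    """One mean per channel, pooled over every image in the batch."""
    n_channels = len(batch[0])
    return [
        fmean([pixel for image in batch for pixel in image[c]])
        for c in range(n_channels)
    ]


def instance_norm_means(batch):
    """One mean per channel per image: images never share statistics."""
    return [[fmean(channel) for channel in image] for image in batch]


# Toy batch: a dark and a bright image, 2 channels of 2 pixels each.
dark = [[0.0, 0.2], [0.1, 0.1]]
bright = [[0.8, 1.0], [0.9, 0.9]]
batch = [dark, bright]

print(batch_norm_means(batch))     # one value per channel, shared by both images
print(instance_norm_means(batch))  # one value per channel for each image
```

In the toy batch, BN pools the dark and the bright image into the same per-channel statistics, whereas IN keeps them apart, so the contrast of each image is normalized on its own.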

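import torch

# ConvLayer, ResidualBlock and UpsampleConvLayer are the helper classes
# from the PyTorch examples repo (not pasted here).
class TransformerNet(torch.nn.Module):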
    def __init__(self):
        super().__init__()
        # Initial convolution layers
        self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
        self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
        self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
        self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
        # Residual layers
        self.res1 = ResidualBlock(128)
        self.res2 = ResidualBlock(128)
        self.res3 = ResidualBlock(128)
        self.res4 = ResidualBlock(128)
        self.res5 = ResidualBlock(128)
        # Upsampling Layers
        self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
        self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
        self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
        self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
        self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
        # Non-linearities
        self.relu = torch.nn.ReLU()
    def forward(self, X):
        y = self.relu(self.in1(self.conv1(X)))
        y = self.relu(self.in2(self.conv2(y)))
        y = self.relu(self.in3(self.conv3(y)))
        y = self.res1(y)
        y = self.res2(y)
        y = self.res3(y)
        y = self.res4(y)
        y = self.res5(y)
        y = self.relu(self.in4(self.deconv1(y)))
        y = self.relu(self.in5(self.deconv2(y)))
        y = self.deconv3(y)
        return y

The above Transformer Net architecture is graphically illustrated below. The visualization shows the 12 series of activations of the NN, with the first 6 constituting the downward path and the last 6 the upward one. As you can see, a `3 x 380 x 380` image gets ingested and a tensor of the same shape is produced. Each block of activations is shown with its size (`channels x height x width`) on top, together with the NN layers generating it (I kept the layer names used in the `TransformerNet` Python class to make it easy to link image and code).
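The spatial sizes can also be traced by hand. The sketch below is a back-of-the-envelope calculation, not part of the network code. It assumes that `ConvLayer` pads each side with `kernel_size // 2` pixels before the convolution and that `UpsampleConvLayer` first scales the input by the `upsample` factor, which is how I read the helper classes in the PyTorch examples repo. `conv_out`, `LAYERS` and `trace_shapes` are names invented for this example.

```python
def conv_out(size, kernel_size, stride, upsample=1):
    """Output side length of one (optionally upsampled) padded convolution."""
    size *= upsample                         # assumed: upsampling happens first
    padded = size + 2 * (kernel_size // 2)   # assumed: kernel_size // 2 padding per side
    return (padded - kernel_size) // stride + 1


# (name, out_channels, kernel_size, stride, upsample), mirroring TransformerNet.
LAYERS = [
    ("conv1", 32, 9, 1, 1),
    ("conv2", 64, 3, 2, 1),
    ("conv3", 128, 3, 2, 1),
    ("res1-res5", 128, 3, 1, 1),  # residual blocks keep the spatial size
    ("deconv1", 64, 3, 1, 2),
    ("deconv2", 32, 3, 1, 2),
    ("deconv3", 3, 9, 1, 1),
]


def trace_shapes(size=380):
    """Return (layer name, channels, side length) after every stage."""
    shapes = [("input", 3, size)]
    for name, channels, kernel_size, stride, upsample in LAYERS:
        size = conv_out(size, kernel_size, stride, upsample)
        shapes.append((name, channels, size))
    return shapes


for name, channels, side in trace_shapes(380):
    print(f"{name:>10}: {channels} x {side} x {side}")
```

Under these assumptions the side length goes 380, 190, 95 on the way down and back to 190, 380 on the way up. The same calculation suggests that an input side not divisible by 4 would not come back at exactly the same size, which is one more reason to stick to sizes like 300 or 380.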

We have data and architecture. Let’s move to the last bit.

The loss function and the training phase

As already mentioned multiple times, this is the same loss proposed by Gatys et al. I won’t go over it again in detail as I did already in this previous post. In short, it is composed of a content part and a style part, with each one measuring how closely the output image captures content and style from the respective inputs.

The process is the following:

  1. Stick the style image through a pre-trained CNN (I used VGG19).
  2. Calculate the Gram Matrix for each one of the convolutional layers’ activations.
  3. Do #1 and #2 for the Transformer Net’s output image.
  4. Calculate the MSE between each pair of Gram Matrices from #2 and #3.
  5. Sum #4. We just obtained the style loss.
  6. Stick the content image through VGG19.
  7. Extract the activations from, say, the 3rd convolutional layer from the top.
  8. Do #7 for the Transformer Net’s output image.
  9. Calculate the MSE between #7 and #8. We just obtained the content loss.
  10. Sum style and content loss (optionally adding the total variation loss too), each with its relative weight, to obtain the total loss.
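The steps above can be sketched in a few lines. This is a simplified, dependency-free illustration and not my actual loss function: activations are plain nested lists (one flat list of `H*W` values per channel) instead of VGG19 tensors, the Gram matrix is divided by `channels * pixels` (a common normalization which I am assuming here), and all function names are hypothetical.

```python
def gram_matrix(features):
    """Channel-by-channel correlations of one layer's activations.

    `features` holds C channels, each a flat list of H*W activations.
    """
    n_channels, n_pixels = len(features), len(features[0])
    scale = n_channels * n_pixels  # assumed normalization
    return [
        [sum(a * b for a, b in zip(fi, fj)) / scale for fj in features]
        for fi in features
    ]


def mse(a, b):
    """Mean squared error between two equally shaped nested lists."""
    flat_a = [v for row in a for v in row]
    flat_b = [v for row in b for v in row]
    return sum((x - y) ** 2 for x, y in zip(flat_a, flat_b)) / len(flat_a)


def style_loss(output_acts, style_acts):
    """Steps 2-5: sum of the Gram-matrix MSEs over all selected layers."""
    return sum(
        mse(gram_matrix(out), gram_matrix(sty))
        for out, sty in zip(output_acts, style_acts)
    )


def content_loss(output_act, content_act):
    """Steps 7-9: MSE between the raw activations of a single layer."""
    return mse(output_act, content_act)


def total_loss(content, style, content_weight=1.0, style_weight=1.0):
    """Step 10, without the optional total variation term."""
    return content_weight * content + style_weight * style


# Same activations, but with the two spatial positions swapped.
original = [[1.0, 2.0], [3.0, 4.0]]
swapped = [[2.0, 1.0], [4.0, 3.0]]
print(style_loss([original], [swapped]))  # Gram matrices ignore the positions
print(content_loss(original, swapped))    # raw activations do not
```

The toy example at the bottom shows why the two terms behave differently: the Gram matrix throws away where things are in the image, so swapping positions leaves the style term untouched while the content term reacts to it.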

You can better see this in action with some code. Here is the function responsible for computing the total loss. Below, instead, the simplified version of the training loop (original here). I have removed `print` statements and refactored a bit to highlight the core mechanics of the optimization. 

def train(self, num_epochs=1):
    for epoch in range(num_epochs):
        for phase in ['train', 'valid']:
            if phase == 'train': self.model.train()
            else: self.model.eval()  # switch the model to evaluation mode for validation
            
            for i, (inputs, contents, styles) in enumerate(self.dl[phase]):
                self.inputs = inputs.to(self.device)
                contents = contents.to(self.device)
                styles = styles.to(self.device)
        
                self.vgg(contents)
                self.content_act = [o.features.clone().detach_().to(self.device) for o in self.act]
                self.vgg(styles)
                self.style_act = [o.features.clone().detach_().to(self.device) for o in self.act]
                self.opt.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    self.outputs = self.model(self.inputs)
                    self.vgg(self.outputs)
                    self.input_act = [o.features.clone().to(self.device) for o in self.act]
            
                    self.loss, self.content_loss, self.style_loss, self.tv = self.combined_loss()
                    
                    if phase == 'train':
                        self.loss.backward()
                        self.opt.step()

The following chart visualizes the evolution of the loss function, with all its components, over 400 batches (200 data points in total as I average over 2 batches). Specifically, this is an example taken from my latest Picasso experiments.

A couple of things are worth mentioning:
  • `tv` stands for total variation loss. Its contribution is negligible, as I intentionally scaled it down to almost zero. It is not there in the original Johnson et al implementation and adding it does not bring any tangible value.
  • The three losses (`content`, `style`, `tv`) are of the same order of magnitude. This doesn’t happen by chance. By nature, they are actually on completely different scales. I take care of that by calculating the `content2style` and `content2tv` ratios, here. Those ratios are then used, during the training phase, to adjust the losses and make sure they are comparable. This allows me to better control the impact of each one of the three components on the total loss, as I can weigh them in the sum with human-sized multipliers (order of magnitude of the 10s).
  • The `content` loss is almost stationary. That is because I purposely down-weighted its contribution in the mix, as my input to Transformer Net is the content image itself. So, technically, the content is already there from the start, and its importance in the global optimization loop shrinks. For the above chart, I used `content_weight=1.5`.
  • The `style` loss is the one playing the biggest role. I increased its weight by a factor of 10 to better capture Picasso’s twist. Globally, the total loss goes down, but it is entirely driven by the style.
  • As you might have noticed, I only trained that network for 400 batches. Considering a batch size of 4, that amounts to 1.6k images, just 2% of the 80k-sized COCO dataset (12 minutes training time). Not even close to one epoch. I admit this seems crazy, but below you can check out the results, at inference time, on 3 pictures from the dataset itself. Except for the dark artifact on the left side (solved with padding), the Picasso-styled pics are not bad. Of course, they can be improved but overall I find them quite satisfying. I tried training for longer but things always got worse. Admittedly, I don’t know why this is the case, as Johnson goes for 2 full epochs on the same COCO dataset, getting a more visually appealing result, though not far from mine. I will have to investigate further.
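The rescaling trick mentioned in the list above (bringing the three losses onto a comparable scale before weighting them) can be sketched as follows. This is only an illustration of the idea, not my actual code: how and when the ratios are measured is an assumption, the function names are hypothetical, and the default weights are example values loosely based on the ones quoted above.

```python
def loss_ratios(content, style, tv):
    """Ratios that bring the style and tv losses onto the content-loss scale.

    Assumed to be computed once, from raw losses measured on a first batch.
    """
    return {"content2style": content / style, "content2tv": content / tv}


def combined_loss(content, style, tv, ratios,
                  content_weight=1.5, style_weight=10.0, tv_weight=0.0):
    """Weighted sum of the rescaled losses (weights are example values)."""
    scaled_style = style * ratios["content2style"]
    scaled_tv = tv * ratios["content2tv"]
    return (content_weight * content
            + style_weight * scaled_style
            + tv_weight * scaled_tv)


# Made-up raw losses living on very different scales.
raw = {"content": 2.0, "style": 4000.0, "tv": 0.5}
ratios = loss_ratios(**raw)
print(ratios)
print(combined_loss(**raw, ratios=ratios))
```

Once the three terms are comparable, a weight such as 1.5 or 10 means what it says, instead of being drowned out by differences of several orders of magnitude between the raw losses.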

This is the end of our Fast Style Transfer journey.

Once again, go grab a brush on visualneurons.com
