Fast Neural Style Transfer: deploying PyTorch models to Amazon SageMaker

Context

A while back, Gabriele Lanaro and I started working on a web application to perform Neural Style Transfer on images and GIFs. We implemented the famous technique developed by Gatys et al., and visualneurons.com was born. Results were quite decent. Still, what kept bothering me was the ridiculous slowness of the optimization pipeline. The loss minimization loop was painfully long, and spinning up the EC2 instance in the backend plus setting up a websocket to establish communication with the frontend was an excruciatingly time-consuming process.

Not long after our first experiments, I stumbled upon this paper on a novel idea for Fast Neural Style Transfer. The authors proposed something smart: what if, instead of running a separate optimization every time a style had to be transferred, we trained a Neural Network to apply that style in a single forward pass? That would reduce the whole process to simple inference on a model. The technique had already been implemented and documented in several online resources, the official PyTorch examples repo being one of those. I had to give it a shot myself.

This first post is a deep dive into the production section of the project. I will soon work on a writeup for the Deep Learning part too.

As for model serving, once again, Amazon SageMaker (SM) was my platform of choice. The difference with last time, though, was that in this case I had not trained the network on SM. I had taken care of the training part outside of AWS, and found myself with the model's weights and the need to run inference in a web application. It turns out that, even if the network has not been trained in SM directly, it is totally possible to use the AWS service just for deployment purposes. There are a few things to take care of, though. The goal of this post is to walk through all the details of this journey, from the moment the model is saved to disk to the moment we click a button in the frontend and everything just works. Here is a demo video of how the service looks on visualneurons.com.

AWS infrastructure

Before diving into SM details, let's briefly cover the AWS infrastructure. I have integrated the Fast Neural Style Transfer functionality into the pre-existing AWS-based visualneurons.com website. The application, which was developed in collaboration with Gabriele Lanaro, already featured the traditional Gatys-et-al version of Style Transfer, both on images and GIFs. Adding the fast version was quite easy. I had done something almost identical already in my is-this-movie-a-thriller.com experiment, except that this time an entire image is returned by SM instead of a single class prediction. Below you can see the (pretty standard) diagram of the AWS workflow.

The customer uploads a picture, selects a style and clicks the `Fast Style Transfer!` button. This triggers a POST ajax call to API Gateway, which in turn triggers a Lambda function. On its end, Lambda picks the right SM endpoint to invoke, according to the user-selected style, and sends the image over to it. SM returns the stylized picture, which is passed back through API Gateway and rendered in the browser.
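
For illustration, here is roughly what that POST boils down to, sketched in Python rather than the actual ajax call. The API Gateway URL is a placeholder and `content.jpg` stands for any local test image; the `style` value mirrors the options exposed by the frontend (Lambda strips the `.png` suffix, as we will see later).

import base64, json, requests

# Placeholder URL: the real API Gateway endpoint is different
API_URL = 'https://<api-id>.execute-api.eu-west-1.amazonaws.com/prod/fast-nst'

with open('content.jpg', 'rb') as f:
    payload = {'data': base64.b64encode(f.read()).decode(),  # same base64 string the canvas produces
               'style': 'VanGogh.png'}

resp = requests.post(API_URL, data=json.dumps(payload))
stylized_b64 = resp.json()  # base64-encoded stylized image, ready to be rendered in a canvas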

SageMaker deployment

Saving the model to disk and uploading to S3

Back to SM now. As mentioned before, SM doesn't care whether the model you are deploying has actually been trained inside the platform. At the end of the day, deploying to an endpoint involves spinning up an EC2 instance, creating a Docker image with the model invocation scripts, and making all of this accessible via an API. Nothing more than that. There is no dependency on the training loop having been executed within AWS. What SM needs is the model, i.e. its weights, and a script with all the functions to run inference on it. Those include four main functions which must be provided for the actual deployment to work (we'll see them in a minute), plus any other self-contained helper code, in case custom transformations are required. As for how to handle the direct inputs/outputs to/of the endpoint, e.g. image to byte array encoding/decoding, this is taken care of either in Lambda or within the JavaScript frontend, or both (we'll cover this bit too). The first step is to wrap the weights in a compressed archive (`tar.gz`) and upload it to S3, for SM to read them from the endpoint. Conventionally, PyTorch NNs are saved to `pth` files. I stored mine in `model.pth` (one per style), making sure to differentiate models trained on different styles by uploading them to different S3 prefixes.

S3 “folders” each containing a `model.tar.gz` compressed file

Once I was happy with a model, I created a folder named after the artist I was working with (e.g. picasso or vangogh) and saved my results to `model.pth`. I then ran the following Python snippet:

import tarfile
import sagemaker

style = "vangogh"  # or whatever other style

# Wrap the weights in a model.tar.gz archive, as SM expects
with tarfile.open(f'./{style}/model.tar.gz', 'w:gz') as f:
    t = tarfile.TarInfo('models')
    t.type = tarfile.DIRTYPE
    f.addfile(t)
    f.add(f'{style}/model.pth', arcname='model.pth')

# Upload the archive to S3, one prefix per style
sagemaker_session = sagemaker.Session()
bucket = "visualneurons.com-fast-nst"
prefix = style
model_artefact = sagemaker_session.upload_data(path=f'./{style}/model.tar.gz',
                                               bucket=bucket, key_prefix=prefix)

This (in the `vangogh` case) uploads the compressed weights to `s3://visualneurons.com-fast-nst/vangogh/model.tar.gz`. The next step is to put together the four functions I mentioned before. Together with any additional helper routines, those constitute what SM refers to as the entry point script, which defines how SM pre-processes inputs, invokes the model and parses outputs. It is important to understand that, once productionized, each time we invoke the endpoint we implicitly call those functions. We never do it directly; SM does it for us. The way we control their behavior is by defining a SM `PyTorchModel` object, to which we pass the relevant arguments, such as the location of the entry point script and of the model's weights.

The inference entry point script

Here is my custom implementation of the SM inference functions. For reference, these two resources were the ones I found the most useful in the process. Here you can find the SM entry point script in its entirety. The functions are the following (see the sketch right after the list for how SM chains them):

  • `model_fn`
  • `input_fn`
  • `predict_fn`
  • `output_fn`
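
This is a conceptual sketch, not SM's actual source code, of how the PyTorch serving container wires the four functions together for every request hitting the endpoint (it assumes the functions defined in the next sections are in scope):

# Rough sketch of the SM serving flow
def handle_request(request_body, content_type, accept, model_dir):
    model = model_fn(model_dir)                  # in practice, loaded once at container start-up
    data = input_fn(request_body, content_type)  # byte array -> pre-processed tensor
    prediction = predict_fn(data, model)         # forward pass through the network
    return output_fn(prediction, accept)         # tensor -> JSON with the base64-encoded image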

The `model_fn` function

`model_fn` is in charge of defining the network architecture and loading the S3 weights into it. It takes `model_dir` as its only parameter, i.e. the local directory into which SM extracts the contents of the `tar.gz` archive we uploaded to S3 (`model_artefact`). Given that we don't call `model_fn` directly, we point SM at the weights by passing the archive's S3 location to the SM `PyTorchModel` constructor, under the `model_data` parameter.

def model_fn(model_dir):
    # TransformerNet (the style network architecture) is defined in the entry point script,
    # which also imports os and torch at the top
    device = torch.device("cpu")
    model = TransformerNet()
    with open(os.path.join(model_dir, 'model.pth'), 'rb') as f:
        model.load_state_dict(torch.load(f, map_location=device))
    return model.to(device)

The `input_fn` function

`input_fn` is instead the one taking care of pre-processing the input. It converts the byte array it is fed into a `PIL.Image`, then applies several transformations to the resulting picture, most notably reshaping and ImageNet normalization. Those are required to prepare the input to be ingested by the neural network. The function returns a dictionary with the pre-processed picture and the original image size. We will need the latter to resize the output back to the initial shape.

def input_fn(request_body, content_type=JPEG_CONTENT_TYPE):
    # Decode the raw byte array into a PIL image
    img = PIL.Image.open(io.BytesIO(request_body))
    item = {'input': img}
    # MakeRGB, ResizeFixed, ToByteTensor, ToFloatTensor, Normalize and compose are
    # helper transformations defined in the entry point script; size, padding and
    # imagenet_stats are constants defined there too
    rgb = MakeRGB()
    resized = ResizeFixed(size)
    tobyte = ToByteTensor()
    tofloat = ToFloatTensor()
    norm = Normalize(imagenet_stats, padding)
    tmfs = [rgb, resized, tobyte, tofloat, norm]
    item = compose(item, tmfs)
    # Return both the pre-processed tensor and the original size, needed later
    # to resize the output back to the initial shape
    return {'img': item['input'], 'size': img.size}

`request_body` represents what we send over via JSON to the SM endpoint. `content_type` is also part of the request. It should be used inside `input_fn` to check whether the endpoint got invoked with the right content, and throw an exception (or simply handle it) if this is not the case. In my scenario, I control the whole pipeline end-to-end, so I am not really worried about anything else other than `JPEG_CONTENT_TYPE = 'image/jpeg'` flowing in. We still have to make sure `request_body` is sent as a byte array, though. This is taken care of first in the browser, via JavaScript, then in Lambda.

// Read the canvas content and strip the "data:image/png;base64," prefix
function base64FromCanvasId(canvas_id) {
    return document.getElementById(canvas_id).toDataURL().split(',')[1];
}
var inputData = {"data": base64FromCanvasId("content_img"),
                 "style": document.getElementById("style_choice").value};
var data = JSON.stringify(inputData);

As you can see from the above couple of lines of code, the frontend reads the `content_img` from its canvas, turns it into `base64` format (i.e. a string) and packages the whole thing in a JSON. This JSON is sent via an API Gateway POST method to a Lambda function. Lambda decodes the `base64` string into a byte array and sends it to SM for prediction, as shown below.

import base64, boto3, json

def format_response(message, status_code):
    return {
        'statusCode': str(status_code),
        'body': json.dumps(message),
        'headers': {
            'Content-Type': 'application/json',
            'Access-Control-Allow-Origin': '*'
            }
        }

def lambda_handler(event, context):
    body = json.loads(event['body'])
    style = body["style"].replace(".png", "")
    # Map the frontend style identifiers to the corresponding SM endpoint names
    styles_map = {"Kand2": "kandinsky", "Picasso": "picasso", "VanGogh": "vangogh"}

    if style not in styles_map:
        return format_response(f"Sorry, the {style} is not supported yet!", 400)

    # Decode the base64 string into a byte array and invoke the style-specific endpoint
    image = base64.b64decode(body['data'])
    runtime = boto3.Session().client(service_name='sagemaker-runtime', region_name='eu-west-1')
    response = runtime.invoke_endpoint(EndpointName=styles_map[style], ContentType='image/jpeg', Body=image)
    r = json.loads(response['Body'].read().decode())

    return format_response(r['prediction'], 200)

The `predict_fn` function

`predict_fn` runs the output of `input_fn` through the trained model. Note that we add a fourth dimension, the batch axis, to the front of the image tensor (`img[None]`), on top of channels, height and width. This is required as the network expects 4-dimensional tensors.

def predict_fn(input_object, model):
    img = input_object['img']
    device = torch.device("cpu")
    # Add the batch axis and run the forward pass
    out = model(img[None].to(device))
    # Drop the batch axis and detach the result from the computational graph
    input_object['img'] = out[0].detach()
    return input_object

The `output_fn` function

We are not done yet. The output of `predict_fn` needs to go through a couple of additional transformations before being sent back to the frontend. Resizing to the original image shape and de-normalizing are examples of that. This is what `output_fn` is for. On top of this, as a very last step, the image is converted back to `base64` format, so that JavaScript can read it seamlessly and render it in a canvas.

def output_fn(prediction, content_type=JSON_CONTENT_TYPE):
    p = prediction['img']
    original_size = prediction['size']
    # De-normalize the tensor and resize it back to the original image size
    denorm = DeProcess(imagenet_stats, size, padding, original_size)
    pred = denorm(p)
    if content_type == JSON_CONTENT_TYPE:
        return json.dumps({'prediction': image_to_base64(pred).decode()})
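
The `image_to_base64` helper lives in the full entry point script. A minimal sketch of what it could look like, assuming `pred` ends up as a `PIL.Image`, is the following:

import io, base64

def image_to_base64(img):
    # Encode the PIL image as JPEG in memory, then base64-encode the bytes
    # so they can travel inside the JSON response
    buffer = io.BytesIO()
    img.save(buffer, format='JPEG')
    return base64.b64encode(buffer.getvalue())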

Deploying the endpoint

Now that we have a complete entry point script, we can put everything together and deploy the endpoint. Just one last bit. We wrote custom inference functions taking a byte array as input (`input_fn`) and spitting out a JSON as output (`output_fn`). We need to make SM aware of this choice. The way to achieve this is to define an `ImagePredictor` class inheriting from `RealTimePredictor` (as shown below), specifying a deserializer and a content type. We can then initialize a `PyTorchModel` object, passing all the relevant arguments, and call the `.deploy()` method on it.

import sagemaker
from sagemaker.pytorch import PyTorchModel
from sagemaker.predictor import RealTimePredictor, json_deserializer

role = sagemaker.get_execution_role()

# Tell SM the endpoint accepts raw JPEG bytes and returns JSON
class ImagePredictor(RealTimePredictor):
    def __init__(self, endpoint_name, sagemaker_session):
        super().__init__(endpoint_name, sagemaker_session=sagemaker_session, serializer=None,
                         deserializer=json_deserializer, content_type='image/jpeg')

pytorch_model = PyTorchModel(model_data=model_artefact, role=role, framework_version='1.0.0',
                             name=style, entry_point='sagemaker_inference.py',
                             predictor_cls=ImagePredictor)
predictor = pytorch_model.deploy(initial_instance_count=1, instance_type='ml.t2.medium')
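
Once `.deploy()` returns, a quick smoke test can be run directly against the endpoint; `content.jpg` is just a placeholder for any local test image:

# Send raw JPEG bytes; ImagePredictor's json_deserializer parses the JSON response for us
with open('content.jpg', 'rb') as f:
    payload = f.read()

result = predictor.predict(payload)   # the JSON produced by output_fn
stylized_b64 = result['prediction']   # base64-encoded stylized image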

And here are the three endpoints I have deployed so far, covering Van Gogh, Picasso, and Kandinsky.

Now you can check out some blazingly fast Neural Style Transfer on visualneurons.com!
