
Blurry faces: Training, Optimizing and Deploying a segmentation model on Amazon SageMaker with NVIDIA TensorRT and NVIDIA Triton


Note: here is the link to the folder with the code and model artifacts, and here is the notebook to follow along with the post. Also, special thanks to Jiahong from NVIDIA, who patiently guided me through some hiccups along the road.

Context

A couple of months ago I started working on an application to blur human faces in photos. Responsible AI is something I care about a lot. Unfortunately, as far as I can tell, it’s not a top concern for everybody in the ML world, and I thought that putting some effort into a privacy-first computer vision model might be a good contribution in the right direction.

Everything started with this tweet from HuggingFace’s CEO Clement Delangue.

I had wanted to work on this one for quite some time. I just needed a little push and I guess the tweet did it.

Sure enough, I opted for HuggingFace Spaces to make my experiments public. I started with an off-the-shelf solution from the face_recognition library, which allowed me to run a CNN model on the input image to locate faces and then blur them. Very easy, but slow and not visually appealing, given the result consisted of a bunch of blurred rectangles (the bounding boxes drawn around faces).

I first tried to address the slowness by switching to Kornia. Using their YuNet architecture, I managed to decrease latency by 20x: from an average of 4.1s/image to 0.2s/image. Impressive.

The visuals were still not great though. A sharp rectangle around a face is far from ideal.

A potential solution to this problem is to switch from face detection to face segmentation, i.e. a network capable of classifying individual pixels as belonging to a human face or not.

Training and deploying such a segmentation model will be the objective of this post.

Let’s dive in.

What we’ll learn

  1. Train an image segmentation model (UNET) using IceVision and a sample of the amazing FaceSynthetics dataset from Microsoft.
  2. Convert #1 to TorchScript.
  3. Deploy #2 to HuggingFace Spaces.
  4. Deploy #2 to an Amazon SageMaker real-time endpoint. This will be the model we’ll benchmark latency against.
  5. Convert #1 to ONNX and then to TensorRT (TRT).
  6. Deploy the TRT model to SageMaker on top of NVIDIA’s Triton inference server.
  7. Check the performance improvements of #6 compared to #4.

Dev environment: where to execute the code and how

I run all my experiments on an Amazon EC2 g4dn.xlarge instance, powered by an NVIDIA T4 Tensor Core GPU. I SSH into it via VSCode and enjoy the full IDE experience while running code on AWS (here is a previous post of mine on how to do just that).

The cloud alone is not sufficient to solve all our pains though. We’ll use a lot of Python libraries, most of which are very picky when it comes to interacting with specific versions of each other, especially where CUDA is involved. To avoid any headaches, we’ll run our entire pipeline, from training to deployment, inside a Docker container, starting off from an official NVIDIA image. If that sounds scary to you, that’s normal. It’s always scary until it works. It turns out that VSCode makes this whole “write and run code inside a container” workflow extremely easy. I wrote about it here and I cannot recommend this strategy enough if you can afford it.

Here are the Dockerfile and the docker-compose YAML files. Docker Compose is not strictly necessary to get things to work. It adds a layer on top of plain Docker, but I prefer its concise syntax and how easy it makes accessing the NVIDIA GPU from inside the running container.

The only thing you have to do is:

>>> cd into_folder_with_dockerfiles
>>> docker-compose build
>>> docker-compose -f docker-compose.yaml -f docker-compose.gpu.yaml run --rm -d icetrt

Attach a VSCode session to the running container and use VSCode normally, enjoying the benefit of running in a completely isolated environment on top of a GPU-powered machine.

Now that you have a safe place to execute code, feel free to follow along in this notebook as well. It covers the same content as the post, just mostly in code.

The dataset

Every ML project starts with a dataset, so the first step was to go hunting for a face segmentation one. Easier said than done: there is almost nothing usable out there. Until I stumbled upon this true gem from Microsoft: a completely synthetic dataset of 1M human faces, picturing people from all sorts of angles, ages, genders, and ethnicities, dressed up in all sorts of ways, and fully annotated for object and keypoint detection and semantic segmentation. A goldmine.

The GitHub repo offers a 1k-image sample as well, which is more than sufficient for our purposes.

The FaceSynthetics dataset ships with segmentation masks covering 30 classes in total: HAIR, SKIN, GLASSES, CLOTHING, …, BACKGROUND. We are actually only interested in face vs everything else, so I relabelled the masks, aggregating BACKGROUND + NECK + HAIR + CLOTHING + HEADWEAR + IGNORE into background (pixel value = 0) and tagging the remaining classes as face (pixel value = 1).
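For reference, here is a minimal sketch of how that relabelling could look. The class ids in NON_FACE_CLASSES are placeholders, not the actual FaceSynthetics ids; check the dataset documentation (or the attached notebook) for the real mapping.

import numpy as np
from PIL import Image

# Placeholder ids for BACKGROUND, NECK, HAIR, CLOTHING, HEADWEAR, IGNORE:
# look up the real values in the FaceSynthetics docs before reusing this.
NON_FACE_CLASSES = [0, 14, 15, 16, 17, 255]

def binarize_mask(mask_path):
    seg = np.array(Image.open(mask_path))       # (H, W) array of integer class ids
    face = ~np.isin(seg, NON_FACE_CLASSES)      # True wherever the pixel belongs to a face class
    return face.astype(np.uint8)                # 0 = background, 1 = face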

Training the model in IceVision

Once the data is collected and the labels fixed, we need to train a model on it. IceVision makes this step super easy. You can check the IceVision section of the notebook for more details, but in a nutshell, it all boils down to:

  1. Writing a custom parser to read the data and turn it into a format a NN can process. These are IceVision records. They look like this when visualized with masks👇
  2. Splitting the records into training and validation sets.
  3. Creating datasets and augmentations (more on the latter below).
  4. Picking a model (we’ll use a UNET with a ResNet34 backbone) and defining dataloaders and a fastai learner.
  5. Training the model.
  6. Checking the results👇 (Not bad!)

A note on augmentations: if you have ever worked on a CV project, you know that, when done right, they are a real superpower. In my case, they were critical to the success of the project. Here is the problem I had. The FaceSynthetics dataset consists of 512x512 RGB shots, each showing a single person pictured in the center. Pretty great, but what if I wanted to run inference on crowds, i.e. multiple faces of different sizes? I tried training a basic NN and it was a disaster. What I noticed was:

  • the model was able to pick out the faces in the foreground, the larger ones. There was no chance it would identify the smaller ones in the background.
  • the model was very sensitive to the size of the inference image. It had been trained on 512x512 pics. It worked great on images of the same size with a single big face. It also worked great on 1024x512 images with 2 faces in the foreground, or 1536x512 images with 3 faces in the foreground. You are probably seeing where this is going: the network had learned the aspect ratio between a face and its surroundings. This was a clear indication I was not using augmentations wisely. I had gone for the default augmentations suggested by IceVision (first screenshot below) but they were not sufficient. What I needed was a way to teach the model that an image might contain multiple faces of different sizes. I opted for something more aggressive, e.g. tfms.A.ShiftScaleRotate(scale_limit=[-1.0, 1.0], rotate_limit=10, shift_limit=0.0625, p=1, border_mode=3), which gives the results in the second screenshot below (a sketch of the full transform pipeline follows the captions). Would it be enough for the network to “see” those examples during training to generalize better? I tried, and this time, BOOM. It worked. This is how I got the inference results in the third screenshot.
Basic augmentations. With these, the model was not able to learn that multiple faces might be present in the picture.
More advanced augmentations. These allowed the model to generalize better and produce the👇 inference results.
Inference results obtained by training the network with more advanced augmentation techniques.
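To make the augmentation discussion concrete, here is a minimal sketch of how such a pipeline could be wired up in IceVision. Apart from the ShiftScaleRotate call quoted above, the transform list is illustrative rather than the exact configuration used for the results above.

from icevision.all import *

# Aggressive scaling so the network sees faces of very different sizes during training
train_tfms = tfms.A.Adapter([
    tfms.A.ShiftScaleRotate(scale_limit=[-1.0, 1.0], rotate_limit=10,
                            shift_limit=0.0625, p=1, border_mode=3),
    *tfms.A.resize_and_pad(512),
    tfms.A.Normalize(),
])

# Validation: just resize/pad and normalize
valid_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(512), tfms.A.Normalize()])

train_ds = Dataset(train_records, train_tfms)
valid_ds = Dataset(valid_records, valid_tfms)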

Yet another note: If you are scratching your head 🤔 wondering how a network trained on 512x512 images can run on anything different than that without throwing a massive amount of size-mismatch errors, I have been there, and you need to see👇

Convert the model to TorchScript

Ok, we have a trained model. Now what? How can we use it outside of a notebook?

A completely reasonable answer would be to Dockerise our environment (basically what we have already done) and run inference with the fastai learner wherever we want. The thing is, do we actually need all the libraries we used for training, namely IceVision and fastai? No, we don’t. If our NN architecture allows it (and UNET does), we can convert the pure PyTorch model to TorchScript. What is that?

As per the documentation (and another useful resource): “TorchScript is an intermediate representation of a PyTorch model (subclass of nn.Module) that can then be run in a high-performance environment like C++. It’s a high-performance subset of Python that is meant to be consumed by the PyTorch JIT Compiler, which performs run-time optimization on your model’s computation.”

I’m not sure I understood all of it 🤷‍♂️, but what matters is that it sounds cool to me. More than that, it is fast, and it allows us to trim our dependency list down to torch only. Literally, that would be it.

At inference time, we can simply load the TorchScript artifact into a PyTorch object and run (preprocessed) images through it. By “preprocessed” I mean resized to an appropriate shape, normalized using the ImageNet stats, and converted to a tensor.

What’s an “appropriate” shape here?

That depends on how we convert to TorchScript.

What happens is that we send some (dummy) data through the model, and torch traces the tensors inside the network, inferring what the graph looks like.

dummy_inp = torch.randn([1, 3, 512, 512]) 
torch.jit.save(torch.jit.trace(model, dummy_inp), 'model.pt')

For simplicity, I’ll always use square images in all my inference pipelines here. Technically it’s not always needed, as TorchScript can accept some rectangular shapes, but I prefer to avoid the headache in the first place and make sure to pad input images accordingly. You can find the logic to do so (and to remove the padding as an inference post-processing step) inside the remove_padding function in the notebook.
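The exact helpers live in the notebook; what follows is just a minimal sketch of the idea (pad to a square, normalize with ImageNet stats, add a batch dimension), with function names of my own choosing rather than the ones used in the notebook.

import torch
import torch.nn.functional as F

IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

def pad_to_square(img):
    """img: (3, H, W) float tensor in [0, 1]. Pads bottom/right so the image becomes square."""
    _, h, w = img.shape
    side = max(h, w)
    padded = F.pad(img, (0, side - w, 0, side - h))   # (left, right, top, bottom)
    return padded, (side - h, side - w)               # keep the padding so we can undo it later

def preprocess(img):
    img, padding = pad_to_square(img)
    img = (img - IMAGENET_MEAN) / IMAGENET_STD
    return img.unsqueeze(0), padding                  # (1, 3, S, S), ready for the model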

Once you take care of your input, executing the model is as simple as

tensor_img = torch.randn([1, 3, 512, 512])

model_jit =  torch.jit.load("model.pt")
with torch.inference_mode():
    x_from_jit = model_jit(tensor_img)

Also, make sure to always check that the outputs of the original fastai learner and the TorchScript model are the same. They must be. A quick way to do that check, continuing the snippet above (the tolerance is picked somewhat arbitrarily):
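with torch.inference_mode():
    out_original = learn.model.eval()(tensor_img)
    out_jit = model_jit(tensor_img)

# the traced model should reproduce the original outputs up to floating-point noise
assert torch.allclose(out_original, out_jit, atol=1e-4)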

Deploy the TorchScript model to HuggingFace Spaces

Given all the trouble I went through to make face-blurring results more visually appealing, at least compared to the rectangular bounding boxes I started from, I thought it’d definitely make sense to update my HuggingFace Gradio Space with the latest TorchScript segmentation model. You can play with it here. The serving logic is simple (thanks, Gradio), and a rough sketch follows the list below.

  • Preprocess the input image, adding padding if necessary to make it square.
  • Run it through the model.
  • Apply the face masks to the original image (removing padding if it was there).
  • Blur faces.
  • Return.
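Here is roughly what that Gradio app can look like. The helpers predict_mask and blur_with_mask are placeholders for the logic described above (and preprocess is the sketch from earlier), not the actual functions used in the Space.

import gradio as gr
import torch

model = torch.jit.load("model.pt").eval()

def blur_faces(image):
    # image: HxWx3 numpy array coming from the Gradio input component
    tensor, padding = preprocess(image)              # pad to square, normalize, add batch dim
    with torch.inference_mode():
        mask = predict_mask(model, tensor, padding)  # run the model, drop the padding
    return blur_with_mask(image, mask)               # blur only the pixels classified as face

demo = gr.Interface(fn=blur_faces, inputs=gr.Image(), outputs=gr.Image())
demo.launch()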

Convert the model to ONNX

Note: From this moment on, I have re-implemented (shamelessly copy-pasted with edits) this great notebook from the AWS team.

NVIDIA TensorRT is an SDK for high-performance deep learning inference. It includes a deep learning inference optimizer and runtime that delivers low latency and high throughput for inference applications.

In order to convert our model to TensorRT, the first thing we need is its ONNX version. Exporting to ONNX is actually useful as a standalone step, as we could use the artifact as-is to run inference. That kind of inference would be framework-agnostic (no torch installation needed) and generally faster than the previously generated TorchScript version.

PyTorch makes exporting to ONNX fairly easy. Very similarly to what we did for TorchScript, we have to provide an input tensor for the framework to run through the model and trace its inner operations. The code is the following.

img = torch.rand([1, 3, 512, 512])
eval_model = learn.model.eval()

torch.onnx.export(
    eval_model,
    img,
    "model.onnx",
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size", 2:'height', 3:'width'}, 
                  "output": {0: "batch_size", 2:'height', 3:'width'}},
)

As we did for TorchScript, it is useful to run inference on the ONNX artifact and check that the results are in line with what we got before (they are; see previous screenshot).
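A minimal way to do that with onnxruntime (the provider list assumes onnxruntime-gpu is installed; drop CUDAExecutionProvider otherwise):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(
    "model.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)

dummy = np.random.rand(1, 3, 512, 512).astype(np.float32)
onnx_out = sess.run(["output"], {"input": dummy})[0]   # same "input"/"output" names used at export time
print(onnx_out.shape)                                  # expected: (1, 2, 512, 512)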

As for the accepted shape of the input, it’s the same story as TorchScript. I have had a hard time feeding rectangular images (it works sometimes, to be honest, but not all the time, despite the dynamic axes), so I have resigned myself to padding whenever I need to preserve the aspect ratio. It works quite neatly.

Convert the model to TensorRT

Now that we have exported to ONNX, we are ready for the final TensorRT conversion.

This is when you’ll thank yourself for running inside a Docker container. Creating the TensorRT engine is quite tricky and the top advice you’ll find on the NVIDIA forums to people trying to install/build TensorRT dependencies on their local machines is to stop and just use NVIDIA (NGC) containers. I was one of those folks asking for help. At some point, I just accepted the advice and never looked back.

A couple of decisions must be made at this point. 

Among those are the shapes of the inputs and outputs (pretty much constrained by the choices we made during the ONNX step) and the network-level precision (FP16, INT8, etc). I went for a static 1x3x512x512 input and for quite aggressive INT8 precision. The latter needs some thought first, as it generally impacts accuracy. That was not the case for me: TensorRT inference results were completely in line with the original model, with no loss in accuracy at all. We’ll use the trtexec command-line wrapper to create the serialised TensorRT engine from ONNX. You can find here the list of arguments you can pass to trtexec. I didn’t experiment with sparsity but I plan to, as it promises further latency speed-ups.

Finally, here is the command I ran for the TensorRT conversion:

# note: minShapes = optShapes = maxShapes, so the engine will accept only 1x3x512x512 inputs
trtexec --onnx=model.onnx \
        --saveEngine=model.plan \
        --explicitBatch \
        --minShapes=input:1x3x512x512 \
        --optShapes=input:1x3x512x512 \
        --maxShapes=input:1x3x512x512 \
        --int8 --verbose | tee conversion.txt

The result is:

  • a conversion.txt file with the (verbose) logs of the operations performed and
  • a model.plan with the TensorRT engine

We have now all the pieces to move to the next step: deployment to SageMaker.

Baseline endpoint: deploy the TorchScript model to SageMaker

TensorRT is fast. Blazing fast. TensorRT + Triton is even faster. I like fast, but a number alone doesn’t mean much if not put in context. How much faster is it compared to a reasonable baseline?

Let’s find a reasonable baseline then. Ours is the TorchScript model deployed to an NVIDIA GPU-powered SageMaker endpoint. The serving code is crazy simple: load the TorchScript model, run the input through it, return the mask. Period. Here it is.

Given that I have literally no dependency except torch, I can rely entirely on an off-the-shelf DLC (Deep Learning Container) provided by SageMaker. We have simplified our lives enough to make deployment a breeze. The code is as follows. Note that we have to upload the model artifacts to S3 beforehand.

import boto3
import sagemaker
from sagemaker.pytorch import PyTorchModel

sagemaker_session = sagemaker.Session(boto_session=boto3.Session())
role = "ARN of an IAM role with SageMaker and S3 permissions"

model = PyTorchModel(
              role=role,
              name="ice-mask-torchscript",
              sagemaker_session=sagemaker_session,
              model_data="s3://your S3 bucket/model.tar.gz",
              framework_version='1.10',
              py_version='py38',
              entry_point="serve.py",
              source_dir="ice_mask_torchscript",
              )

predictor = model.deploy(
    initial_instance_count=1,
    instance_type="ml.g4dn.xlarge",
    endpoint_name="ice-mask-torchscript",
)

Once the endpoint is green (In Service) we can check whether it can successfully receive and process requests. We use a PyTorchPredictor, which neatly handles tensors without any hassle on our side.
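Something along these lines, reusing the predictor returned by model.deploy and assuming the serve.py entry point accepts the default numpy payload (the random data here is just to exercise the endpoint):

import numpy as np

# PyTorchPredictor serializes/deserializes numpy arrays for us
payload = np.random.rand(1, 3, 512, 512).astype(np.float32)
mask = predictor.predict(payload)
print(mask.shape)   # expected: (1, 2, 512, 512)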

How fast is it?

Hold on a sec for this one. We’ll answer it at the very end.

Shifting gears: deploy the TensorRT model to SageMaker on NVIDIA Triton

Note: this is a great post from the AWS and NVIDIA team on what Triton can achieve on SageMaker. Highly recommended.

First things first. What is Triton? It’s an inference server developed by NVIDIA.

From the documentation: “Triton Inference Server streamlines AI inference by enabling teams to deploy, run and scale trained AI models from any framework on any GPU- or CPU-based infrastructure. [… It] supports all major frameworks, such as TensorFlow, NVIDIA® TensorRT™, PyTorch, MXNet, Python, ONNX [… It] supports all NVIDIA GPU-, x86-, and ARM® CPU-based inferencing. It offers features like dynamic batching, concurrent execution, optimal model configuration, model ensemble, and streaming inputs to maximize throughput and utilization.”

Looks promising. Especially when it comes fully integrated with SageMaker, making it even easier to use. Let’s see how it works in practice.

We already have the TensorRT model.plan we created earlier. We are just missing the instructions Triton needs to deploy the serialized engine. Those are included in a config.pbtxt file, and boil down to input and output shapes plus additional info we might want to provide to Triton for further optimization, both in terms of latency and throughput. Stuff like model warmup, if and how to set concurrency across GPUs, dynamic batching of requests, etc. The documentation contains everything you need to know about model configuration. Check it out. It’s very comprehensive. Here is how my config.pbtxt looks:

name: "unet"
platform: "tensorrt_plan"
max_batch_size: 1
input {
  name: "input"
  data_type: TYPE_FP32
  dims: 3
  dims: 512
  dims: 512
}
output {
  name: "output"
  data_type: TYPE_FP32
  dims: 2
  dims: 512
  dims: 512
}
instance_group {
  count: 1
  kind: KIND_GPU
}
model_warmup {
    name: "Warmup"
    batch_size: 1
    inputs: {
        key: "input"
        value: {
            data_type: TYPE_FP32
            dims: 3
            dims: 512
            dims: 512
            zero_data: false
        }
    }
}

Are we good?

Not yet. We need to create a directory with the structure and content shown below, compress it, and upload it to S3. Those are the artifacts SageMaker will unpack during deployment and use together with Triton.
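For reference, a sketch of the expected layout and packaging. The directory names are illustrative, except that the model folder name has to match the name field in config.pbtxt.

# triton-serve-trt/
# └── unet/
#     ├── config.pbtxt
#     └── 1/
#         └── model.plan
import tarfile

import sagemaker

with tarfile.open("model.tar.gz", "w:gz") as tar:
    tar.add("triton-serve-trt/unet", arcname="unet")

model_uri = sagemaker.Session().upload_data("model.tar.gz", key_prefix="triton-trt-unet")
print(model_uri)   # s3://<your bucket>/triton-trt-unet/model.tar.gz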

Now we are ready to create a model, an endpoint configuration and an endpoint. You can refer to the notebook attached to this post for the details.
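If you just want the gist, the boto3 calls look roughly like the following sketch. The Triton image URI is region- and account-specific (see the AWS/NVIDIA example notebooks for the exact mapping), so the one below is a placeholder, and role/model_uri are the values defined earlier.

import boto3

sm = boto3.client("sagemaker")

# placeholder: look up the sagemaker-tritonserver image for your region/account
triton_image = "<account_id>.dkr.ecr.<region>.amazonaws.com/sagemaker-tritonserver:<tag>"

sm.create_model(
    ModelName="ice-mask-trt-triton",
    ExecutionRoleArn=role,
    PrimaryContainer={
        "Image": triton_image,
        "ModelDataUrl": model_uri,  # the model.tar.gz uploaded above
        "Environment": {"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "unet"},
    },
)

sm.create_endpoint_config(
    EndpointConfigName="ice-mask-trt-triton",
    ProductionVariants=[{
        "VariantName": "AllTraffic",
        "ModelName": "ice-mask-trt-triton",
        "InstanceType": "ml.g4dn.xlarge",
        "InitialInstanceCount": 1,
    }],
)

sm.create_endpoint(
    EndpointName="ice-mask-trt-triton",
    EndpointConfigName="ice-mask-trt-triton",
)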

Once the endpoint turns green (In Service) we can test whether it is able to accept and process requests. Keep in mind that in this case it is critical to use the right format to package the payload in order to squeeze the best performance out of the endpoint. Specifically, JSON isn’t really a good format for large data transfers, especially for computer vision workloads. I saw much better performance by simply switching the payload from JSON to binary.
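Here is roughly what the binary-payload invocation looks like, using the tritonclient helpers to build the request body. This is a sketch in the spirit of the AWS example notebook, with the endpoint name and shapes assumed as above.

import boto3
import numpy as np
import tritonclient.http as httpclient

input_data = np.random.rand(1, 3, 512, 512).astype(np.float32)

inputs = [httpclient.InferInput("input", list(input_data.shape), "FP32")]
inputs[0].set_data_from_numpy(input_data, binary_data=True)
outputs = [httpclient.InferRequestedOutput("output", binary_data=True)]

request_body, header_length = httpclient.InferenceServerClient.generate_request_body(
    inputs, outputs=outputs
)

runtime = boto3.client("sagemaker-runtime")
response = runtime.invoke_endpoint(
    EndpointName="ice-mask-trt-triton",
    ContentType=f"application/vnd.sagemaker-triton.binary+json;json-header-size={header_length}",
    Body=request_body,
)

# the response content type tells us where the JSON part ends and the binary part begins
header_size = int(response["ContentType"].split("json-header-size=")[1])
result = httpclient.InferenceServerClient.parse_response_body(
    response["Body"].read(), header_length=header_size
)
mask = result.as_numpy("output")   # (1, 2, 512, 512)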

Show us speed! 🏃🏻🚄 💨

We’ll look at model latency in two contexts.

  1. End2end latency measured from my notebook running on an EC2 instance
  2. Latency measured by AWS CloudWatch, i.e. without the networking/payload-exchange overhead

The comparison is between:

  1. The SageMaker endpoint with the TorchScript model (GPU)
  2. The SageMaker endpoint with the TensorRT model deployed on Triton (GPU)

1. End2end: Triton wins by 30% 🏃🏻

By hitting the endpoints repeatedly from my EC2 notebook I got ~200ms for TorchScript and ~140ms (as low as 80ms 🤯 in some cases!) for Triton. That’s a 30% gain. Not bad at all. This includes networking overhead though. What happens on SageMaker alone, inside the endpoint?
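For the curious, here is a minimal sketch of the kind of client-side timing loop that produces numbers like these (not the exact code used; percentiles are generally more informative than a single average):

import time

import numpy as np

def measure_latency(predict_fn, payload, warmup=10, runs=100):
    for _ in range(warmup):                       # let the endpoint warm up first
        predict_fn(payload)
    latencies_ms = []
    for _ in range(runs):
        start = time.perf_counter()
        predict_fn(payload)
        latencies_ms.append((time.perf_counter() - start) * 1000)
    return np.percentile(latencies_ms, [50, 90, 99])

# e.g. measure_latency(predictor.predict, payload) for the TorchScript endpoint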

2. CloudWatch: Triton wins by 8x 🚄💨🤯

I understand you might not believe me so I am posting screenshots from CloudWatch. What we see is TorchScript at 130ms against Triton at 15.6ms. That’s an 8x gain 😮

Pretty impressive.

Thanks for reading thus far! Happy hacking!
