
Benchmarking TorchVision ResNet18 on EC2 NVIDIA GPU with TensorRT and Amazon SageMaker Neo


Disclaimers

  1. You can find the code used in this post here.
  2. In this post, we work with ResNet18, a very “small” network. That alone makes it fast on GPU even without any optimization, and I suspect this is also why Neo doesn’t provide any significant latency benefit. It would have made sense to benchmark a much bigger model, but the purpose of this post was to follow up on my previous one, so I stuck with my original choice.

TLDR

All scenarios running on NVIDIA GPU:

  1. ResNet18 ➡️ TorchScript: 1.85 ms
  2. ResNet18 ➡️ TorchScript ➡️ TensorRT @FP16: 1.1 ms
  3. ResNet18 ➡️ TorchScript ➡️ SageMaker Neo (TensorRT @FP32): 1.82 ms

Context

This post is a continuation of the experiments I conducted and summarized here on deploying and benchmarking ResNet18 on SageMaker endpoints backed by GPU and AWS Inferentia. One of the comments I received on that work was: “It seems the combo SageMaker Neo + Inferentia wins hands down. You are not performing any GPU-related model optimization, though. What happens if you convert the graph to TensorRT? Does the picture change?”

The context of the question is the following. What makes Neo + Inferentia hard to beat is the 10 ms latency at a 0.297 USD/h price point. TorchScript on GPU is pretty fast: 14 ms. But the price of a g4 (0.736 USD/h) is ~2.5x higher than Inferentia’s, which makes the latter ~3.5x more cost-effective per prediction overall. If we drastically reduced latency on GPU, though, things might change. NVIDIA TensorRT (TRT) is (at times) almost miraculous. Say we could bring latency down from 14 to 4 ms: at the same g4 price point, this scenario would come out on top in terms of cost per prediction. So the question is: can TRT close the gap?
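The cost argument above is just latency multiplied by hourly price. Here is a minimal sketch of that arithmetic. It assumes requests are served one at a time, with no batching or concurrency, which is a simplification. The helper name is my own.

```python
def cost_per_million(latency_ms: float, usd_per_hour: float) -> float:
    """USD to serve one million sequential predictions at the given latency."""
    predictions_per_hour = 3_600_000 / latency_ms  # 3.6 million ms in one hour
    return usd_per_hour / predictions_per_hour * 1_000_000


inferentia = cost_per_million(10, 0.297)      # Neo + Inferentia
g4_torchscript = cost_per_million(14, 0.736)  # TorchScript on g4
g4_hypothetical = cost_per_million(4, 0.736)  # hypothetical 4 ms with TRT on g4

print(f"Inferentia:           {inferentia:.3f} USD / 1M predictions")
print(f"g4 TorchScript:       {g4_torchscript:.3f} USD / 1M predictions")
print(f"g4 at 4 ms (what-if): {g4_hypothetical:.3f} USD / 1M predictions")
print(f"g4 TorchScript vs Inferentia: {g4_torchscript / inferentia:.1f}x")
```

This reproduces the ~3.5x gap (about 2.86 vs 0.83 USD per million predictions). It also shows that 4 ms on a g4 (about 0.82 USD) would only just edge out Inferentia.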

I limited myself to converting the model to TensorRT at FP16, which sped up SageMaker execution from 14 to 10 ms (on par with Inferentia). Not enough to make a concrete difference, though. I tried INT8 quantization, but it didn’t work out of the box and I didn’t have the patience to investigate the problem further. Shall we give up on TensorRT then?

Not really. The catch is that CloudWatch (CW) model latencies are not fully comparable in our case. CW tracks the time the SageMaker container takes to process an incoming request. That time includes payload preprocessing (decoding, resizing and normalizing the image, etc.), and the preprocessing differs between endpoints. In the TRT case, the input is a base64-encoded image, which takes more work to turn into something the network can process than the bytes object I sent to the Inferentia endpoint. We could run the preprocessing outside the endpoint, but that wouldn’t fully solve the problem. Different endpoints take inputs in different formats, so some custom payload processing would be needed anyway, and that makes the comparison somewhat arbitrary.
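Here is a minimal sketch of timing preprocessing and prediction separately inside a serving handler. The handler and its preprocess/predict callables are hypothetical stand-ins, not SageMaker’s API. The toy usage also shows that a base64 payload is 4/3 the size of the raw bytes and has to be decoded before the image can even be opened. On a GPU, predict should only return once results are ready (for example, after copying them to the CPU). Otherwise the split is skewed.

```python
import base64
import time
from typing import Any, Callable, Dict, Tuple


def timed_handler(payload: bytes,
                  preprocess: Callable[[bytes], Any],
                  predict: Callable[[Any], Any]) -> Tuple[Any, Dict[str, float]]:
    """Run preprocess + predict, returning the output and per-stage timings in ms."""
    t0 = time.perf_counter()
    batch = preprocess(payload)  # e.g. base64 decode, resize, normalize
    t1 = time.perf_counter()
    output = predict(batch)      # the part we actually want to benchmark
    t2 = time.perf_counter()
    timings = {"preprocess_ms": (t1 - t0) * 1e3, "predict_ms": (t2 - t1) * 1e3}
    return output, timings


# Toy usage: fake image bytes, base64-encoded as a JSON-friendly payload would be.
raw = b"\x89PNG fake image bytes" * 1000
payload = base64.b64encode(raw)
out, timings = timed_handler(payload, base64.b64decode, len)  # len stands in for the model
print(len(payload) / len(raw), out, timings)
```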

Plan

Then I told myself: what if I get rid of SageMaker altogether and run my models offline on a GPU? We’d have the same hardware with different optimizations. This is not necessarily a better way of benchmarking, given that in real life we’d care about SageMaker’s latency rather than offline performance, but it is arguably fairer, at least in this situation. Sure, we would break the Neo + Inferentia combo, i.e. we’d have to test Neo on GPU. That’s something I wanted to experiment with anyway, so it’s an additional reason to go down this route.

All right, here is what we are going to do. We’ll fire up a g4dn.xlarge EC2 instance and time the execution of the following three versions of the network on the same NVIDIA T4 GPU:

  1. Save an ImageNet-pretrained ResNet18 to TorchScript
  2. Convert #1 to NVIDIA TensorRT FP16
  3. Compile #1 with SageMaker Neo and ml.g4dn.xlarge as target hardware. Download the compiled model and execute inference locally

1️⃣ Bare TorchScript on GPU

The first is pretty straightforward. It goes like this👇

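# check the linked jupyter notebook for the definition of `input_fn`
import torch
import torchvision.models as models

# move the model to the GPU before tracing, so the traced module runs on CUDA
resnet18 = models.resnet18(pretrained=True).float().eval().to("cuda")
input_shape = [1, 3, 224, 224]
trace = torch.jit.trace(resnet18, torch.zeros(input_shape, device="cuda"))
input_batch = input_fn("cat.jpg", dtype="fp32")  # expected to return a CUDA tensor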

%%timeit -r 10
with torch.inference_mode():
    _ = trace(input_batch)

> 1.85 ms ± 18.6 µs per loop (mean ± std. dev. of 10 runs, 100 loops each)

Not bad for pure TorchScript on GPU.
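%%timeit is convenient in a notebook. Outside one, a small harness works too, with one caveat: CUDA kernels are launched asynchronously, so on GPU you should synchronize (in PyTorch, torch.cuda.synchronize()) before reading the clock. Here is a minimal stdlib sketch that mirrors timeit’s runs × loops layout. The function name, defaults, and warm-up count are my own choices, and you pass in the sync callable yourself.

```python
import statistics
import time
from typing import Callable, Optional, Tuple


def benchmark(fn: Callable[[], object],
              sync: Optional[Callable[[], None]] = None,
              warmup: int = 10, runs: int = 10, loops: int = 100) -> Tuple[float, float]:
    """Return (mean, stdev) of per-call latency in ms over `runs` x `loops` calls.

    `runs` must be at least 2 for the standard deviation to be defined.
    """
    sync = sync or (lambda: None)
    for _ in range(warmup):  # warm-up: lazy initialization, caches, autotuning
        fn()
    sync()
    per_call_ms = []
    for _ in range(runs):
        start = time.perf_counter()
        for _ in range(loops):
            fn()
        sync()  # wait for queued GPU work before stopping the clock
        per_call_ms.append((time.perf_counter() - start) / loops * 1e3)
    return statistics.mean(per_call_ms), statistics.stdev(per_call_ms)


# CPU toy usage:
mean_ms, std_ms = benchmark(lambda: sum(range(1_000)))
print(f"{mean_ms:.4f} ms ± {std_ms:.4f} ms per call")

# GPU usage with the traced model above (not run here):
# with torch.inference_mode():
#     mean_ms, std_ms = benchmark(lambda: trace(input_batch), sync=torch.cuda.synchronize)
```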

2️⃣ NVIDIA TensorRT @FP16 on GPU

How does TensorRT compare to that? To answer this question, we need to convert the model to TRT, which we can do via the torch_tensorrt library. This is where Docker takes the stage: you really don’t want to bang your head against finding the exact combo of Python packages that gets TRT to run. The Dockerfile is straightforward. We use the NVIDIA PyTorch image as is, adding boto3 and sagemaker (just because we also want to deploy the TRT model to a SageMaker endpoint).

FROM nvcr.io/nvidia/pytorch:22.10-py3

ENV DEBIAN_FRONTEND=noninteractive 
RUN apt-get update
RUN pip install boto3
RUN pip install sagemaker

A few important notes before moving forward:

  1. You can run model #1 in this same container. The dependencies to execute #2 are the same as #1.
  2. As I have illustrated in previous posts, I like to orchestrate Docker services with docker-compose. It’s arguably overkill, considering there is not much to orchestrate here: literally a single container. I have gotten into the habit of doing so regardless, as I really like the docker-compose way of handling things.
  3. Even if the application is dockerized, the GPU is still a piece of hardware outside the container. That covers the hardware and, most importantly, its CUDA drivers. The specific 👆 22.10-py3 image ships with CUDA 11.6, which means you need the 11.6 drivers installed and working on your host machine. I hear you thinking: “does this matter?”. Unfortunately yes, because to execute #3, aka the Neo-compiled model, we are gonna need CUDA 10.2. That means you’ll also have to install the 10.2 drivers on the host (EC2 here). In my case, installing both was sufficient. They didn’t conflict with each other, and, without any tinkering, the two separate containers were able to access their respective drivers as intended. For reference, this is a nice post on how to install several coexisting CUDA Toolkits on the same machine, if you run into trouble. Also for reference, this is the output I get when checking my CUDA installations on EC2 👇
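If you’d rather check from a script, here is a minimal sketch that lists the CUDA toolkits installed side by side. It assumes the default /usr/local/cuda-X.Y install location used by NVIDIA’s installers, so pass a different root if yours differs. The function name is my own. Note that nvidia-smi reports the driver and the highest CUDA version that driver supports, which is not the same thing as the list of installed toolkits.

```python
import re
from pathlib import Path
from typing import List


def list_cuda_toolkits(root: str = "/usr/local") -> List[str]:
    """Return CUDA toolkit versions found as <root>/cuda-X.Y directories, sorted numerically."""
    base = Path(root)
    if not base.is_dir():
        return []
    versions = []
    for path in base.glob("cuda-*"):
        match = re.fullmatch(r"cuda-(\d+)\.(\d+)", path.name)
        if match and path.is_dir():
            versions.append((int(match.group(1)), int(match.group(2))))
    # sort as (major, minor) integers so that 9.x < 10.x < 11.x
    return [f"{major}.{minor}" for major, minor in sorted(versions)]


print(list_cuda_toolkits())
```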

To set up the environment, we just build the image, spin up a container, attach a VSCode session to it, and we are ready to go.

> cd docker_torch_trt
> docker-compose build
> docker-compose -f docker-compose.yaml -f docker-compose.gpu.yaml run --rm -d torch_trt

The juicy part is in these few lines of code, where we pull a ResNet18 model from torchvision and compile it to TensorRT, enabling FP16 precision.

import torch
import torch_tensorrt
import torchvision.models as models

resnet18 = models.resnet18(pretrained=True).eval().to("cuda")

trt_model = torch_tensorrt.compile(resnet18, 
    inputs = [torch_tensorrt.Input((1, 3, 224, 224), dtype=torch.half)],
    enabled_precisions = {torch.half}, # Run with FP16
)

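Once done, inference works just like in the previous case, except that the input has to be half precision to match the torch.half input spec we compiled with:

input_batch = input_fn("cat.jpg", dtype="fp32").half()  # cast to FP16 to match the compiled input spec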

%%timeit -r 10
with torch.inference_mode():
    _ = trt_model(input_batch)

> 1.06 ms ± 3.99 µs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)

Holy moly! 1.06 ms 🤯. That’s ~1.75x faster than plain TorchScript on GPU. It kind of makes sense, considering the massive optimization happening under the hood, but it is still mind-blowing to see a CNN run at ~1 ms latency. We could push this even further with INT8 quantization, but I ran into roadblocks there and stuck to FP16.
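For some intuition on what FP16 gives up in exchange for speed: half precision has 10 explicit mantissa bits, so around 1.0 adjacent representable values are 2^-10 (about 0.001) apart, versus 2^-23 for FP32. Python’s struct module can round-trip a float through IEEE 754 half precision with the "e" format, which makes this easy to see. This only illustrates the number format, not what TensorRT does internally (which may, for instance, keep some layers in FP32).

```python
import struct


def to_fp16(x: float) -> float:
    """Round-trip a Python float through IEEE 754 half precision."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


print(to_fp16(1.0 + 2**-10))  # exactly representable: 1.0009765625
print(to_fp16(1.0 + 2**-11))  # halfway between neighbours, rounds to even: 1.0
print(to_fp16(0.1))           # nearest half-precision value: 0.0999755859375
```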

3️⃣ Neo-compiled model on GPU

The first thing to do here is compile torchvision’s ResNet18 with Amazon SageMaker Neo. We already did this in a previous post, but let me paste the relevant code with comments below. 👇 is literally everything we need.

import torch
import torchvision.models as models
import tarfile
import time  # used for the unique compilation job name below
import sagemaker, boto3, os
from sagemaker.pytorch.model import PyTorchModel

# PULL MODEL FROM TORCHVISION, TRACE AND SAVE IT
resnet18 = models.resnet18(pretrained=True)
input_shape = [1, 3, 224, 224]
trace = torch.jit.trace(resnet18.float().eval(), torch.zeros(input_shape).float())
trace.save("model.pth")

# DEFINE AWS RELATED VARS
region = "eu-west-1"
os.environ["AWS_DEFAULT_REGION"] = region
role = "arn:aws:iam::257446244580:role/sagemaker-icevision"
sess = sagemaker.Session(boto_session=boto3.Session(region_name=region))
sm_runtime = boto3.Session().client("sagemaker-runtime", region_name=region)

# TAR MODEL ARTIFACTS AND UPLOAD TO S3
with tarfile.open("model.tar.gz", "w:gz") as f:
    f.add("model.pth")
    # THE SERVE SCRIPT IS TECHNICALLY NOT NEEDED GIVEN WE ARE
    # JUST COMPILING THE MODEL, NOT DEPLOYING IT TO SAGEMAKER
    f.add("serve_neo.py")
model_uri = sess.upload_data(path="./model.tar.gz", key_prefix="neo_pytorch")

# DEFINE PYTORCHMODEL WITH THE SAGEMAKER SDK
pytorch_model = PyTorchModel(
    model_data=model_uri,
    role=role,
    # SAME COMMENT AS BEFORE. `entry_point` IS A POSITIONAL 
    # ARGUMENT THOUGH, SO IT NEEDS TO BE PROVIDED ANYWAY.
    # YOU COULD PASS AN EMPTY STRING IN THEORY. HAVEN'T CHECKED
    entry_point="serve_neo.py",
    framework_version="1.12",
    py_version="py3",
)

# COMPILE THE MODEL WITH SAGEMAKER NEO.
# NOTICE THE `target_instance_family="ml_g4dn"`.
# WE WANT TO COMPILE FOR THE `g4` FAMILY OF 
# INSTANCES, AS THIS IS THE HARDWARE WE ARE RUNNING ON.
pytorch_model = pytorch_model.compile(
    target_instance_family="ml_g4dn",
    input_shape={"input0": [1, 3, 224, 224]},
    output_path=model_uri,
    framework="pytorch",
    framework_version="1.12",
    role=role,
    job_name=f"neo-pytorch-{int(time.time())}",
)

pytorch_model.compile triggers a Neo compilation job which, once done, should look like the following on the console👇

The actual output of the compilation is a tar.gz file in S3…

… which we download to EC2 and extract
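You can also script the console check and the S3 download. Here is a minimal sketch using boto3’s SageMaker describe_compilation_job and S3 download_file calls. The helper names, polling interval, timeout, and local paths are my own assumptions (./dlr_model matches the directory loaded further down). The clients are passed in, so you can exercise the functions without AWS credentials.

```python
import tarfile
import time
from urllib.parse import urlparse


def wait_for_compilation(sm_client, job_name: str,
                         poll_seconds: float = 30,
                         timeout_seconds: float = 3600) -> str:
    """Poll a Neo compilation job and return the S3 URI of its compiled artifacts."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        desc = sm_client.describe_compilation_job(CompilationJobName=job_name)
        status = desc["CompilationJobStatus"]
        if status == "COMPLETED":
            return desc["ModelArtifacts"]["S3ModelArtifacts"]
        if status in ("FAILED", "STOPPED"):
            reason = desc.get("FailureReason", "n/a")
            raise RuntimeError(f"{job_name} ended as {status}: {reason}")
        if time.monotonic() > deadline:
            raise TimeoutError(f"{job_name} still {status} after {timeout_seconds} s")
        time.sleep(poll_seconds)


def parse_s3_uri(uri: str):
    """Split 's3://bucket/some/key' into ('bucket', 'some/key')."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3":
        raise ValueError(f"not an S3 URI: {uri}")
    return parsed.netloc, parsed.path.lstrip("/")


def download_and_extract(s3_client, s3_uri: str,
                         local_tar: str = "compiled_model.tar.gz",
                         extract_dir: str = "./dlr_model") -> str:
    """Download the compiled tarball from S3 and unpack it into `extract_dir`."""
    bucket, key = parse_s3_uri(s3_uri)
    s3_client.download_file(bucket, key, local_tar)
    with tarfile.open(local_tar, "r:gz") as tar:
        tar.extractall(extract_dir)  # only extract archives you trust
    return extract_dir


# Usage (needs AWS credentials; not run here):
# import boto3
# sm = boto3.client("sagemaker", region_name="eu-west-1")
# s3 = boto3.client("s3", region_name="eu-west-1")
# artifacts = wait_for_compilation(sm, "<your-neo-job-name>")
# download_and_extract(s3, artifacts)
```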

Before going ahead, there is something interesting to highlight about Neo. First, keep in mind that Neo is not a compilation engine per se. It is a service that delegates compilation to the appropriate engine, based on the model and the target hardware. Given our target is a g4 instance featuring an NVIDIA GPU, it is probably no surprise that Neo hands the compilation over to TensorRT. The proof is inside the 1435857361_0_Neo.json file. This is what it (partly) contains👇. Notice that:

  • line 21 confirms that the compiler is TRT
  • lines 45-46 hint at the fact that the model was compiled under full floating point precision (FP32).
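If you want to check this without scrolling through the JSON by hand, a generic recursive search does the job. Here is a minimal sketch. It assumes nothing about the file’s schema beyond it being valid JSON. The function name and search terms are only examples.

```python
from typing import Any, List, Tuple


def find_in_json(obj: Any, needle: str, path: str = "$") -> List[Tuple[str, Any]]:
    """Return (path, value) pairs where a key or a string value contains `needle`.

    Matching is case-insensitive. When a key matches, its whole value is reported
    and not searched further.
    """
    needle = needle.lower()
    hits: List[Tuple[str, Any]] = []
    if isinstance(obj, dict):
        for key, value in obj.items():
            child = f"{path}.{key}"
            if needle in str(key).lower():
                hits.append((child, value))
            else:
                hits.extend(find_in_json(value, needle, child))
    elif isinstance(obj, list):
        for i, value in enumerate(obj):
            hits.extend(find_in_json(value, needle, f"{path}[{i}]"))
    elif isinstance(obj, str) and needle in obj.lower():
        hits.append((path, obj))
    return hits


# Usage on the Neo metadata file (not run here):
# import json
# with open("1435857361_0_Neo.json") as f:
#     meta = json.load(f)
# for term in ("tensorrt", "precision", "fp16", "fp32"):
#     print(term, find_in_json(meta, term))
```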

Does the fact that Neo used TRT under the hood mean we’ll get the same performance as model #2, obtained via torch_tensorrt? Not necessarily. In scenario #2 we compiled with FP16, while Neo compiled at FP32, so Neo should be a little slower. We’ll check in a minute.

First, we have to figure out how to load and run the Neo model. The answer is the Python DLR library. Here is how to install it on cloud instances, and here is how to run inference with it. Once again, as with torch_tensorrt before, the best way to tackle the task is Docker. I am pretty much copy-pasting the installation instructions into the Dockerfile, making sure to add the torch stack too. The result is the following.

FROM nvcr.io/nvidia/tensorrt:19.12-py3

ENV DEBIAN_FRONTEND=noninteractive 
RUN apt-get update
RUN pip install --upgrade pip
RUN pip install boto3
RUN pip install sagemaker
RUN pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu102

RUN git clone --recursive https://github.com/neo-ai/neo-ai-dlr && \
    cd neo-ai-dlr && \
    mkdir build && \
    cd build && \
    cmake .. -DUSE_CUDA=ON -DUSE_CUDNN=ON -DUSE_TENSORRT=/opt/tensorrt && \
    make -j4 && \
    cd ../python && \
    python3 setup.py install --user

Then, to set up the environment, we just build the image, spin up a container, attach a VSCode session to it, and we are ready to go.

> cd docker_neo_dlr
> docker-compose build
> docker-compose -f docker-compose.yaml -f docker-compose.gpu.yaml run --rm -d neo_dlr

To run inference, we import the library, load the model, prepare the input image (notice it is a numpy array, not a torch tensor), and call neo_model.run on it, like so:

from PIL import Image
import numpy as np
import dlr

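def input_fn_neo(path):
    image = Image.open(path).convert("RGB")
    image = image.resize((224, 224))
    x = np.asarray(image, dtype=np.float32) / 255.0  # HWC, values in [0, 1]

    # ImageNet normalization; float32 constants keep the result float32
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    x = (x - mean) / std

    # HWC -> CHW (a transpose, not an axis swap), then add the batch dim
    return np.ascontiguousarray(x.transpose(2, 0, 1)[None])  # shape = (1, 3, 224, 224)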

neo_model = dlr.DLRModel('./dlr_model','gpu')
input_batch = input_fn_neo("cat.jpg")

%%timeit -r 10
_ = neo_model.run(input_batch)

> 1.82 ms ± 58.9 µs per loop (mean ± std. dev. of 10 runs, 100 loops each)
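A note on the preprocessing: converting an HWC image to the CHW layout the network expects is a transpose. np.swapaxes(x, 0, -1) also moves channels first, but it swaps height and width too. With a square 224x224 input the shapes match, so the mistake is silent. Separately, subtracting plain Python lists from a float32 array promotes it to float64. The toy arrays below are made up just to show both effects.

```python
import numpy as np

# Toy 2x3 "image" in HWC layout: height=2, width=3, 3 channels.
hwc = np.arange(2 * 3 * 3, dtype=np.float32).reshape(2, 3, 3)

chw = hwc.transpose(2, 0, 1)   # (3, 2, 3): channels first, height and width kept
cwh = np.swapaxes(hwc, 0, -1)  # (3, 3, 2): channels first, but height and width swapped
print(chw.shape, cwh.shape)

# Channel 0 of the swapaxes result is the transposed image plane.
assert np.array_equal(chw[0], hwc[:, :, 0])
assert np.array_equal(cwh[0], hwc[:, :, 0].T)

# Subtracting plain Python lists promotes a float32 array to float64.
normalized = (hwc - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
print(normalized.dtype)  # float64
```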

🏃‍♂️💨 Results

Let’s put things together. All three scenarios are running on NVIDIA GPU.

  1. ResNet18 ➡️ TorchScript: 1.85 ms
  2. ResNet18 ➡️ TorchScript ➡️ TensorRT @FP16: 1.1 ms
  3. ResNet18 ➡️ TorchScript ➡️ SageMaker Neo (TensorRT @FP32): 1.82 ms

A couple of comments:

  • TRT @FP16 (#2) is ~1.75x as fast as its non-optimized counterpart (#1)
  • Neo’s results (#3) are weird. Several sources suggest that for small models (such as ResNet18) the performance improvement over a baseline (#1) is negligible. That is plausible in general, but counterintuitive in our case, given that Neo still invokes TRT under the hood. It is an FP32-precision compilation, which I had expected to land between #1 and #2 (FP16). Instead, it matches #1, as if TRT had no effect. I also tried recompiling #2 with TRT @FP32 (with torch_tensorrt) to check whether full precision was indeed the root cause. That benchmark returned 1.25 ms, which is the number I thought Neo would produce: slower than FP16, but still faster than no optimization whatsoever. For the record, note that #2 and #3 run on two different versions of CUDA (11.6 vs 10.2, respectively). I am not sure whether that matters, but it is still a difference. Maybe Neo fails to fully delegate to TRT and adds its own twist to the compilation, undermining the final results. Not sure. A potential test would be to rerun this analysis on a much bigger network (ResNet152?) and check whether the same trends hold.
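Here are the ratios behind these comments, computed from the mean latencies reported above (including the 1.25 ms torch_tensorrt @FP32 rerun). Only means are used, so this ignores run-to-run spread.

```python
# Mean latencies (ms) reported in this post.
latencies_ms = {
    "TorchScript (#1)": 1.85,
    "torch_tensorrt @FP16 (#2)": 1.06,
    "torch_tensorrt @FP32 (rerun of #2)": 1.25,
    "Neo, TRT @FP32 (#3)": 1.82,
}

baseline = latencies_ms["TorchScript (#1)"]
for name, ms in latencies_ms.items():
    print(f"{name:<36} {ms:.2f} ms   speedup vs #1: {baseline / ms:.2f}x")
```

That works out to roughly 1.75x for #2, 1.48x for TRT @FP32 and 1.02x for Neo. The Neo difference (30 µs) is smaller than the ±58.9 µs spread of its own measurement.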
