
CUDA for Python Programmers


Introduction

In this post, I go through the content covered by Jeremy Howard in his lecture CUDA for Python Programmers. This is a talk he gave as part of the Programming Massively Parallel Processors (PMPP) study group, currently running on the CUDA MODE Discord server (with videos uploaded to the YouTube channel of the same name). My initial ambition was to merge Jeremy’s first lecture (the 3rd in the reading group) with his second one (the 5th in the reading group), titled Going Further with CUDA for Python Programmers, but I ended up preferring to skip the latter altogether; I will probably dedicate an entirely separate future post to it.

You can follow along in the repo curated by the study group organizers. The folders for lectures 3 and 5 contain the notebooks Jeremy uses in the talks. As for myself, I executed the code on a g5.2xlarge EC2 machine on AWS (1 NVIDIA A10G GPU), within a dedicated conda environment:

conda create -n cuda python=3.10
conda activate cuda
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
pip install ipykernel
pip install matplotlib
pip install wurlitzer ninja
My GPU/NVIDIA setup on the g5.2xlarge EC2 machine on AWS

Lecture 3: Getting Started With CUDA for Python Programmers

What

The purpose of this lesson is to write two CUDA kernels.

1️⃣ is to convert an RGB image to B&W. 2️⃣ is to multiply two matrices, aka the building block of any deep learning architecture. The latter requires no explanation, just A = B x C. We’ll see what to do in a later section.

The former is worth clarifying. The formula to convert a color image to grayscale is very simple. We calculate the luminance of each pixel, aggregating its RGB values into one number, with the following equation: luminance = red * 0.2989 + green * 0.5870 + blue * 0.1140

This squashes the three color channels into one (grayscale).
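To make the formula concrete, here is a tiny stdlib-only sketch that applies the weights to single RGB pixels (the helper names `luminance` and `luminance_u8` are my own, not from the lecture). One detail worth noticing: the three weights sum to 0.9999, so pure white does not map exactly to 255; once the result is truncated to an 8-bit integer (which is what a C++ conversion from double to unsigned char does for in-range values), it becomes 254.

```python
# Weights used in the post for the RGB -> grayscale conversion.
WEIGHTS = (0.2989, 0.5870, 0.1140)

def luminance(red, green, blue):
    """Return the (float) grayscale value of one RGB pixel."""
    wr, wg, wb = WEIGHTS
    return wr * red + wg * green + wb * blue

def luminance_u8(red, green, blue):
    """Same, but truncated toward zero like a double -> unsigned char cast."""
    return int(luminance(red, green, blue))

print(round(sum(WEIGHTS), 4))       # 0.9999
print(luminance(255, 255, 255))     # ~254.97, not 255
print(luminance_u8(255, 255, 255))  # 254
print(luminance_u8(255, 0, 0))      # pure red -> 76
```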

How

No assumptions are made. You don’t need to be a C++ expert, a compiler ninja, or an accelerator Jedi. The process we’ll follow is the opposite of complicated:

  1. Implement a solution in pure Python. CPU only.
  2. Make the solution CUDA-compatible. Which is a very fancy way of saying: move the body of the Python for-loop into a separate function (a kernel) that processes a single index, so that each iteration can (in principle) run in parallel.
  3. Ask ChatGPT to convert #2 into a CUDA kernel. Fix GPT’s mistakes (only a few, here and there).
  4. Compile the kernel with PyTorch.
  5. Run the kernel on the GPU and enjoy massive speedups compared to the dummy CPU solution.

The basics

First things first. Let’s go through some Q&A.

What is CUDA and why does it matter?

CUDA (Compute Unified Device Architecture) is a proprietary and closed-source parallel computing platform and application programming interface (API) that allows software to use certain types of graphics processing units (GPUs) for general-purpose processing, an approach called general-purpose computing on GPUs (GPGPU). CUDA is a software layer that gives direct access to the GPU’s virtual instruction set and parallel computational elements for the execution of compute kernels.[1] […] CUDA was created by Nvidia.[4]

https://en.wikipedia.org/wiki/CUDA

The reason it matters (to us Data Scientists) is that without it there would be pretty much no Deep Learning. If you have ever tried training anything that is not a ResNet18 on a CPU, you know that a GPU is simply the only way to go. Whoever says GPU, says NVIDIA, and whoever says NVIDIA, says CUDA.

What is a CUDA kernel?

In computing, a compute kernel is a routine compiled for high throughput accelerators (such as graphics processing units (GPUs), […], separate from but used by a main program (typically running on a central processing unit). […] Compute kernels roughly correspond to inner loops when implementing algorithms in traditional languages

https://en.wikipedia.org/wiki/Compute_kernel

The RGB-2-grayscale (or matrix multiplication) kernel is nothing more than a piece of code executing the image conversion (or matrix multiplication) on a GPU at massive speed.

Why is CUDA fast?

A GPU is organized in a very different way compared to a CPU. The following are its main “components”:

1. Streaming Multiprocessors (SMs): In NVIDIA GPUs, SMs are the fundamental units of execution. Each SM can execute multiple threads concurrently.

2. Thread Blocks: A thread block is a group of threads that can cooperate among themselves through shared memory and synchronization. All threads in a block are executed on the same SM. This means they can share resources such as shared memory and can synchronize their execution with each other.

3. Shared Memory: Shared memory is a small memory space on the GPU that is shared among the threads in a block. It is much faster than global memory (the main GPU memory), but it is also limited in size. Threads in the same block can use shared memory to share data with each other efficiently.

https://github.com/cuda-mode/lectures/blob/main/lecture3/pmpp.ipynb

Jeremy goes on to highlight some quick specs of the RTX 3090 (a very good NVIDIA GPU):

– The RTX 3090, based on the Ampere architecture, has 82 SMs.

– Each SM in GA10x GPUs contain 128 CUDA Cores, four third-generation Tensor Cores, a 256 KB Register File, and 128 KB of L1/Shared Memory

– In CUDA, all threads in a block have the potential to run concurrently. However, the actual concurrency depends on the number of CUDA cores per SM and the resources required by the threads.

https://github.com/cuda-mode/lectures/blob/main/lecture3/pmpp.ipynb

This means the RTX 3090 has roughly 82*128 ~ 10k CUDA cores. Each one of them can run an independent thread, and threads can run in parallel. Imagine running a sequential piece of code in parallel on 10k threads at the same time. In this lecture, we are going to exploit this massive parallelization, whereas, in lecture 5, Jeremy throws (the much trickier) memory usage into the mix.

RGB 2 grayscale

Slow Python implementation

The dummy Python function to convert a color image to grayscale is simple👇. Below you can find the code and a slide showing the exact operations for a 4x4 toy image. If you are wondering why we are flattening the 3x4x4 tensor into a one-dimensional array, keep reading until the CUDA section.

import torch

def rgb2grey_py(x, verbose=False):
    c,h,w = x.shape
    if verbose: print(f"Image of shape {c}x{h}x{w}")
    n = h*w
    if verbose: print(f"Number of pixels: {n}")
    x = x.flatten()  # planar layout: all reds, then all greens, then all blues
    if verbose: print(f"Image flattened to array of shape {x.shape}")
    res = torch.empty(n, dtype=x.dtype, device=x.device)
    for i in range(n): 
        red = x[i]
        green = x[i+n]
        blue = x[i+2*n]
        res[i] = 0.2989*red + 0.5870*green + 0.1140*blue
        if verbose: print(f"Pixel {i} = 0.2989 * {red} + 0.5870 * {green} + 0.1140 * {blue} = {res[i]}")
    return res.view(h,w)
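Why do `x[i]`, `x[i+n]` and `x[i+2*n]` pick out the red, green and blue values of the same pixel? Because the image is stored channel-first (C×H×W), flattening it produces three consecutive planes of `n = h*w` values: all reds, then all greens, then all blues. The stdlib-only sketch below (toy values made up, helper name `flatten_chw` hypothetical) builds a 3×2×2 nested list, flattens it in row-major order and checks those offsets.

```python
def flatten_chw(img):
    """Row-major flatten of a nested [channel][row][col] list."""
    return [v for channel in img for row in channel for v in row]

# Hypothetical 3x2x2 image: channel 0 = red, 1 = green, 2 = blue.
img = [
    [[10, 11], [12, 13]],   # red plane
    [[20, 21], [22, 23]],   # green plane
    [[30, 31], [32, 33]],   # blue plane
]
c, h, w = len(img), len(img[0]), len(img[0][0])
n = h * w
x = flatten_chw(img)
print(x)  # [10, 11, 12, 13, 20, 21, 22, 23, 30, 31, 32, 33]

for i in range(n):
    row, col = divmod(i, w)      # pixel i sits at (row, col) in every plane
    red, green, blue = x[i], x[i + n], x[i + 2 * n]
    assert (red, green, blue) == (img[0][row][col], img[1][row][col], img[2][row][col])
```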

Running this dummy solution on the CPU on the original puppy image means executing the for-loop 1330 x 1990 ~ 2.6M times. I tried, and had to kill the Jupyter cell after 5 minutes 😱. Our only hope is to massively downsize: 32k pixels take 1.1 seconds. Let’s make this faster.

Still slow Python… but more CUDA-like

First version

We know there are a ton of CUDA cores in a GPU. How do we use them all? We have to structure our algorithm so that the same piece of code can run in parallel, in isolation from all the other executions. In other words, we need to isolate the unit of work that each thread will perform. A way of doing so is the following:

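import torch

def run_kernel(f, times, *args):
    # Simulates a kernel launch: call the kernel once per index, sequentially
    for i in range(times): f(i, *args)

def rgb2grey_k(i, x, out, n):
    # The "kernel": handles pixel i only and writes its result in place into out
    red, green, blue = x[i], x[i+n], x[i+2*n]
    out[i] = 0.2989*red + 0.5870*green + 0.1140*blue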

def rgb2grey_pyk(x):
    c,h,w = x.shape
    n = h*w
    x = x.flatten()
    res = torch.empty(n, dtype=x.dtype, device=x.device)
    run_kernel(rgb2grey_k, h*w, x, res, n)
    return res.view(h,w)

Notice how we refactored the previous implementation. We took the operation inside the for-loop (res[i] = 0.2989*red + 0.5870*green + 0.1140*blue) and we stuck it into a separate function (rgb2grey_k).

A kernel.

The result is the same, but now rgb2grey_k runs on each pixel independently: it does not need to know anything about any pixel other than the one it is working on. We invoke it once per pixel (n times). Note also that the function does not return anything: it writes its result in place into the output buffer (out[i]), the flattened output array to be precise.
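A quick way to convince yourself that rgb2grey_k really is independent of every other iteration: call it for the pixel indices in a shuffled order and check that the output does not change. The sketch below is stdlib-only (plain lists instead of tensors, and a hypothetical `run_kernel_shuffled` helper). On a GPU there is no guaranteed execution order at all, which is exactly why this property matters.

```python
import random

def rgb2grey_k(i, x, out, n):
    # Same per-pixel kernel as in the post, operating on plain lists.
    red, green, blue = x[i], x[i + n], x[i + 2 * n]
    out[i] = 0.2989 * red + 0.5870 * green + 0.1140 * blue

def run_kernel(f, times, *args):
    for i in range(times):
        f(i, *args)

def run_kernel_shuffled(f, times, *args, seed=0):
    # Hypothetical variant: visit the indices in a random order.
    order = list(range(times))
    random.Random(seed).shuffle(order)
    for i in order:
        f(i, *args)

n = 6
x = [float(v) for v in range(3 * n)]   # toy planar RGB data: 3 planes of n values
out_seq, out_shuf = [0.0] * n, [0.0] * n
run_kernel(rgb2grey_k, n, x, out_seq, n)
run_kernel_shuffled(rgb2grey_k, n, x, out_shuf, n)
assert out_seq == out_shuf  # same result regardless of execution order
```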

Refactoring with threads and blocks

🚨 Note 🚨: this paragraph is very dense and it references later sections of the post. Feel free to come back to it while reading. It will hopefully make sense eventually.

Now that we have this working solution, we can push it a step further by refactoring it once again as if it were running on a GPU. To do so we have to introduce (again) blocks and threads. Before that, though, a couple of important ideas.

A key concept to digest is the following:

  • 👉 What CUDA cares about is the index of the “operation” it is executing. What is that index pointing at in memory?
  • 👉 Memory is linear. However many dimensions our data has, the kernel ultimately reads and writes it through a pointer to a 1D buffer. That’s why we flattened the image out, and why we index the input and the output grayscale image with a single integer.

🚨I know what you are thinking: “Wait, what? We are processing an image. Wouldn’t it be much simpler (and logical) to use double indexing, e.g. image[row, col] = kernel_output? You said we’d be implementing matrix multiplication later on. Are you telling me we’ll have to go through this 1D madness even for matmul?”

Yes. But that’s the thing. Don’t be fooled (like I was) into thinking that CUDA supports 1D objects only. CUDA creates a list of indexes for you. That list is, by definition, one-dimensional (that’s how things are stored in memory), so you have to index it with a single integer (this is why we use out[i] in the RGB-2-grayscale example and why we’ll use out[row*w+col] instead of out[row, col] in the matmul example). CUDA is perfectly capable of handling 1D, 2D, and even 3D objects. It does so by exposing 3-element tuples (dim3) whose x, y, and z attributes let you address an element of a 2D grid or a 3D cube. Being able to access the different axes of a cube (assuming you have a 3D shape) is very handy: it means you can point at an element in an intuitive way (this is pure convenience; flattening a cube works perfectly too). Once you have x, y, and z, though, you still have to convert them into a single integer, because that is what you need to index the flattened version of your input/output.🚨

I asked GPT4 about this. Notice the subtlety: CUDA can process objects of all shapes, but it still “linearizes” them internally.
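Converting x, y and z into a single integer is plain row-major arithmetic: in a volume of depth d, height h and width w, element (z, y, x) lives at offset (z*h + y)*w + x (drop z and you get the row*w + col used later for matmul). A stdlib sketch with hypothetical helper names, checked against nested-list indexing:

```python
def linear_index(z, y, x, h, w):
    """Row-major offset of element (z, y, x) in a d x h x w volume."""
    return (z * h + y) * w + x

def unravel_index(i, h, w):
    """Inverse mapping: recover (z, y, x) from a flat offset."""
    z, rem = divmod(i, h * w)
    y, x = divmod(rem, w)
    return z, y, x

d, h, w = 2, 3, 4
cube = [[[100 * z + 10 * y + x for x in range(w)] for y in range(h)] for z in range(d)]
flat = [v for plane in cube for row in plane for v in row]

for z in range(d):
    for y in range(h):
        for x in range(w):
            i = linear_index(z, y, x, h, w)
            assert flat[i] == cube[z][y][x]
            assert unravel_index(i, h, w) == (z, y, x)

print(linear_index(1, 2, 3, h, w))  # 23, i.e. the last element
```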

Ok, how does this “index handling” work? CUDA splits the list of indexes we discussed above into blocks, and then, within each block, into threads. This means we are not looping through pixels (bear with the image example). We are looping through blocks and threads. Inside each thread, we know which index we are processing by multiplying blockIdx.x by blockDim.x and adding threadIdx.x.

Image from PMPP
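The formula blockIdx.x*blockDim.x + threadIdx.x simply numbers the threads consecutively, block after block. The sketch below (plain Python, hypothetical `global_indices` helper) enumerates every (block, thread) pair of a small launch and checks that each global index appears exactly once.

```python
def global_indices(num_blocks, block_dim):
    """Global 1D index of every thread: blockIdx.x*blockDim.x + threadIdx.x."""
    return [block_idx * block_dim + thread_idx
            for block_idx in range(num_blocks)
            for thread_idx in range(block_dim)]

idx = global_indices(num_blocks=3, block_dim=4)
print(idx)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
assert sorted(idx) == list(range(3 * 4))  # every index covered exactly once
```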

I hear you: “Where did y and z go? You said we’d have 3-element tuples?”. In our implementation of RGB-2-grayscale we decided to explicitly flatten the image and operate on a 1D vector directly, which is why we only use the x axis. The y and z dimensions of the grid and of each block default to 1, so blockIdx.y, threadIdx.y and their z counterparts are always 0. 👉 As a matter of fact, we can totally rewrite our 1D implementation using a more “standard” 2D-indexing strategy (see the Comparison with a 2D CUDA kernel section below). CUDA is perfectly happy with that. A little messier IMO. But I’ll let you judge.

How many threads per block should we use? 256 is a good first guess. How about blocks? Once the number of threads is fixed, the number of blocks is ceil(number of pixels / number of threads). Because the number of pixels might not be an exact multiple of 256, rounding up means the last block can extend past the end of the image, so while iterating we have to make sure we don’t go beyond the total number of pixels (that’s the if i<n condition in the script below).

A lot to unpack here. Let’s go over the same 4x4 toy image example we also used before to visualize this mess.

import math
import torch

def blk_kernel(f, blocks, threads, *args):
    # Simulates a 1D launch: every (block, thread) pair calls the kernel once
    for i in range(blocks):
        for j in range(threads):
            f(i, j, threads, *args)

def rgb2grey_bk(blockidx, threadidx, blockdim, x, out, n, verbose):
    i = blockidx*blockdim + threadidx  # global index of this thread
    if i<n:  # guard: the last block may run past the end of the image
        red, green, blue = x[i], x[i+n], x[i+2*n]
        out[i] = 0.2989*red + 0.5870*green + 0.1140*blue
        if verbose: print(f"Pixel i: {i}, blockidx: {blockidx}, blockdim: {blockdim}, threadidx: {threadidx}, n: {n}, 0.2989 * {red} + 0.5870 * {green} + 0.1140 * {blue} = {out[i]}")

def rgb2grey_pybk(x, threads, verbose=False):
    c,h,w = x.shape
    print(f"Image of shape {c}x{h}x{w}")
    n = h*w
    if verbose: print(f"Number of pixels: {n}")
    x = x.flatten()
    print(f"Image flattened to array of shape {x.shape}")
    res = torch.empty(n, dtype=x.dtype, device=x.device)
    blocks = int(math.ceil(h*w/threads))
    print(f"Threads: {threads}, Blocks: {blocks}")
    blk_kernel(rgb2grey_bk, blocks, threads, x, res, n, verbose)
    return res.view(h,w)
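Putting numbers on the ceiling division and the i<n guard: with a hypothetical image of 1,000 pixels and 256 threads per block, we get 4 blocks, i.e. 1,024 launched threads, of which the last 24 must do nothing. The integer formula (a + b - 1) // b is the same trick the C++ cdiv helper below relies on, and it avoids floating-point arithmetic altogether.

```python
import math

def cdiv(a, b):
    """Ceiling division for positive integers, without floats."""
    return (a + b - 1) // b

n, threads = 1000, 256
blocks = cdiv(n, threads)
launched = blocks * threads
# Threads whose global index fails the `i < n` guard and therefore do nothing.
idle = sum(1 for b in range(blocks) for t in range(threads) if b * threads + t >= n)

print(blocks, launched, idle)  # 4 1024 24
assert blocks == math.ceil(n / threads)
```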

The 1D CUDA kernel

Some basic setup first.

  • os.environ['CUDA_LAUNCH_BLOCKING']='1': this basically means that each CUDA kernel launch will block the CPU until the kernel finishes execution on the GPU. This synchronous behavior can be useful for debugging purposes, as it makes it easier to pinpoint the exact location of errors or unexpected behavior in CUDA applications. When a kernel fails or produces incorrect results, having a synchronous launch means that the error is reported exactly at the point in the source code where the kernel is launched, rather than at some later point, making it easier to debug. The price to pay is significant performance degradation though, so better not to use this in production.
  • We also need two additional modules. The first is ninja, a build tool that PyTorch requires to compile C++/CUDA extensions. The second is wurlitzer (which needs to be loaded as an extension in Jupyter): it captures output written by C/C++ code, so that print statements in our C++/CUDA code show up in the notebook.
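In a notebook, the environment variable has to be set before CUDA gets initialized, so the usual pattern is to put it in the very first cell. A minimal sketch (the IPython magic for wurlitzer is shown as a comment, since it is not plain Python):

```python
import os

# Safest to do this before importing torch; it must at least happen before
# the first CUDA call, otherwise it has no effect on the current process.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# In a Jupyter cell you would then typically run:
#   %load_ext wurlitzer    # show printf output from C/CUDA code in the notebook
#   import torch

print(os.environ['CUDA_LAUNCH_BLOCKING'])  # 1
```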

All right. We have got some CUDA code. How do we use it from Python? The answer is that PyTorch ships with a very handy function called load_inline (from torch.utils.cpp_extension): you pass it the CUDA and C++ source code as strings, plus the names of the C++ functions you want to make available to Python, and it compiles all of that into a Python module that is available right away 🤯.

Let’s delve into the CUDA code we are going to use.

#include <torch/extension.h>
#include <stdio.h>
#include <c10/cuda/CUDAException.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b;}

We first start with a ☝️ reusable starter pack that we’ll make sure to prepend to any kernel. This snippet contains:

  • headers giving us access to torch, I/O (printf) and CUDA exception checking
  • 3 macros: the first is CHECK_CUDA, which makes sure the tensor is on the GPU. The second is CHECK_CONTIGUOUS, which makes sure the tensor is contiguous, i.e. its elements sit in memory in plain row-major order, without gaps or reordering (as can happen, for example, with a transposed view). The third is CHECK_INPUT, which runs the first two one after the other.
  • the definition of cdiv: ceiling division, i.e. a way to figure out the number of blocks given the number of pixels and threads.
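What does “contiguous” mean concretely? A tensor is a flat buffer plus a shape and strides (how many elements to skip to move one step along each axis). It is contiguous when its strides are the plain row-major ones, so that data_ptr plus a row*w + col offset really lands on the intended element. The stdlib sketch below uses hypothetical helper names and is simplified (PyTorch’s own check also special-cases size-1 dimensions): transposing a 2×3 tensor only swaps shape and strides, and the result is no longer contiguous, which is exactly what CHECK_CONTIGUOUS guards against (in PyTorch, .contiguous() returns a row-major copy).

```python
def row_major_strides(shape):
    """Strides (in elements) of a C-contiguous tensor with the given shape."""
    strides, step = [], 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))

def is_contiguous(shape, strides):
    return tuple(strides) == row_major_strides(shape)

shape, strides = (2, 3), row_major_strides((2, 3))
print(strides)                                    # (3, 1)
t_shape, t_strides = shape[::-1], strides[::-1]   # a transpose: no data moves
print(t_shape, t_strides)                         # (3, 2) (1, 3)
print(is_contiguous(shape, strides), is_contiguous(t_shape, t_strides))  # True False
```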

Now we actually have to write the CUDA kernel. How do we do that exactly? We ask GPT4 to convert the Python code into equivalent CUDA C++ code.

This👇 is the final working version of the CUDA kernel. I am keeping the naming and the structure from Jeremy’s notebook (so without the red, green and blue variables).

__global__ void rgb_to_grayscale_kernel(unsigned char* x, unsigned char* out, int n) {
    int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i<n) out[i] = 0.2989*x[i] + 0.5870*x[i+n] + 0.1140*x[i+2*n];
}

Copilot got it almost 100% right. The only notable fix is using unsigned char* instead of float*, i.e. we want an array of uint8: that’s the 1D flattened image we are looping through. void means we are not returning anything, and __global__ is a special qualifier telling CUDA that the function is a kernel, which can be called from the CPU (Host) and the GPU (Device) and is executed on the GPU (Device) only.

Image from PMPP

How do we call the kernel? Like so 👇

__global__ void rgb_to_grayscale_kernel(unsigned char* x, unsigned char* out, int n) {
    int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i<n) out[i] = 0.2989*x[i] + 0.5870*x[i+n] + 0.1140*x[i+2*n];
}

torch::Tensor rgb_to_grayscale(torch::Tensor input) {
    CHECK_INPUT(input);
    int h = input.size(1);
    int w = input.size(2);
    printf("h*w: %d*%d\n", h, w);
    auto output = torch::empty({h,w}, input.options());
    int threads = 256;
    rgb_to_grayscale_kernel<<<cdiv(w*h,threads), threads>>>(
        input.data_ptr<unsigned char>(), output.data_ptr<unsigned char>(), w*h);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}

In Python we used to invoke the blk_kernel function, passing the rgb2grey_bk kernel, plus threads, blocks, and all the other arguments. In CUDA there is no special function. There is special syntax. We write the name of the kernel to call (rgb_to_grayscale_kernel), followed by triple angle brackets containing the launch configuration: at minimum, the number of blocks (ceiling division between the number of pixels and threads) and the number of threads per block (<<<cdiv(w*h,threads), threads>>>). Then, as arguments to the kernel, we pass the input and output tensors and the number of pixels (w*h). The kernel expects raw uint8 pointers rather than tensors, which is why we call the data_ptr<unsigned char>() method on them.

After we call the kernel, if there is an error we won’t necessarily be informed, so we have to add the C10_CUDA_KERNEL_LAUNCH_CHECK() macro provided by PyTorch to make sure errors are caught and handled properly.

We are ready to call the load_inline PyTorch function and compile the CUDA kernel. To do so, we are just missing cpp_src, i.e. the C++ declaration of the function we use to invoke the kernel (its signature, with typed input and output), and the list of functions we want to make available to the outside world, i.e. rgb_to_grayscale.
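Here is a sketch of that call. The helper name load_cuda, the module name and the cuda_begin variable are assumptions of mine; load_inline and its name, cpp_sources, cuda_sources, functions, extra_cuda_cflags and verbose parameters come from torch.utils.cpp_extension. The import is done lazily inside the helper, so the snippet can be defined without a GPU; actually calling it requires PyTorch, a CUDA toolkit and ninja.

```python
# C++ declaration (signature only) of the host function defined in the CUDA source.
cpp_src = "torch::Tensor rgb_to_grayscale(torch::Tensor input);"

def load_cuda(cuda_src, cpp_src, funcs, opt=False, verbose=False, name="inline_ext"):
    """Compile CUDA + C++ sources into an importable Python module (hypothetical helper)."""
    from torch.utils.cpp_extension import load_inline  # needs PyTorch + CUDA toolkit + ninja
    return load_inline(
        name=name,                     # name of the generated extension module
        cpp_sources=[cpp_src],         # declarations; Python bindings are generated for `funcs`
        cuda_sources=[cuda_src],       # kernel + host function (starter pack prepended)
        functions=funcs,               # which C++ functions to expose to Python
        extra_cuda_cflags=["-O2"] if opt else [],
        verbose=verbose,
    )

# Usage on a GPU machine, assuming a hypothetical `cuda_begin` string holding the
# starter pack, `cuda_src` holding the kernel plus rgb_to_grayscale shown above,
# and `img_tensor` being a 3xHxW uint8 tensor:
#   module = load_cuda(cuda_begin + cuda_src, cpp_src, ["rgb_to_grayscale"])
#   grey = module.rgb_to_grayscale(img_tensor.contiguous().cuda()).cpu()
```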

With this info, PyTorch compiles the kernel and makes it available as a Python module. Which means we can invoke it on our entire (non-downsized) input image, and process 1330 x 1990 ~ 2.6M pixels in 3.2 ms 🏃🏻💨. Notice how we are not timing the operation of moving the tensor to the GPU and making it contiguous. We are timing the step of moving the output back to the CPU though. That’s because the copy back to the CPU forces us to wait until the kernel has run to completion, so the timing reflects the actual computation rather than just the kernel launch.

We have successfully created our first fully functioning CUDA kernel! 🎉

Comparison with a 2D CUDA kernel

You can see here what I was referring to in the Refactoring with threads and blocks section (go back and read that part of the post again; it’s good timing). Below is the 2D version of the RGB-2-grayscale CUDA kernel. We’ll briefly go through this 2D grid logic for matrix multiplication later. Regardless, you can already see how the indexing differs between the two versions: in the 2D context we access blockIdx.x and blockIdx.y (the same applies to blockDim and threadIdx), whereas in the 1D case we only access blockIdx.x. The logic is very similar though.

2D (aka using 2 indexes and iterating over the image, while inputs/outputs are still flattened)…
__global__ void rgb_to_grayscale_kernel(unsigned char* x, unsigned char* out, int w, int h) {
    int c = blockIdx.x*blockDim.x + threadIdx.x;
    int r = blockIdx.y*blockDim.y + threadIdx.y;

    if (c<w && r<h) {
        int i = r*w + c;
        int n = h*w;
        out[i] = 0.2989*x[i] + 0.5870*x[i+n] + 0.1140*x[i+2*n];
    }
}

torch::Tensor rgb_to_grayscale(torch::Tensor input) {
    CHECK_INPUT(input);
    int h = input.size(1);
    int w = input.size(2);
    torch::Tensor output = torch::empty({h,w}, input.options());
    dim3 tpb(16,16); // 16x16 = 256 threads per block, the same as in our 1D example
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    rgb_to_grayscale_kernel<<<blocks, tpb>>>(
        input.data_ptr<unsigned char>(), output.data_ptr<unsigned char>(), w, h);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
… vs 1D (what we just saw, flattening the image into a 1D array)
__global__ void rgb_to_grayscale_kernel(unsigned char* x, unsigned char* out, int n) {
    int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i<n) out[i] = 0.2989*x[i] + 0.5870*x[i+n] + 0.1140*x[i+2*n];
}

torch::Tensor rgb_to_grayscale(torch::Tensor input) {
    CHECK_INPUT(input);
    int h = input.size(1);
    int w = input.size(2);
    printf("h*w: %d*%d\n", h, w);
    auto output = torch::empty({h,w}, input.options());
    int threads = 256;
    rgb_to_grayscale_kernel<<<cdiv(w*h,threads), threads>>>(
        input.data_ptr<unsigned char>(), output.data_ptr<unsigned char>(), w*h);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
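To check that the two launch strategies really compute the same thing, here is a stdlib-only simulation of both kernels (hypothetical helper names, plain lists instead of tensors, and small block sizes so it runs instantly). The toy image is 3×5, deliberately not a multiple of either block shape, so the guards get exercised.

```python
def cdiv(a, b):
    return (a + b - 1) // b

def grey(x, i, n):
    # Same arithmetic as the kernel body, truncated like a double -> unsigned char cast.
    return int(0.2989 * x[i] + 0.5870 * x[i + n] + 0.1140 * x[i + 2 * n])

def run_1d(x, h, w, threads=4):
    n, out = h * w, [None] * (h * w)
    for block in range(cdiv(n, threads)):
        for t in range(threads):
            i = block * threads + t
            if i < n:
                out[i] = grey(x, i, n)
    return out

def run_2d(x, h, w, tpb=(2, 2)):
    n, out = h * w, [None] * (h * w)
    for by in range(cdiv(h, tpb[1])):
        for bx in range(cdiv(w, tpb[0])):
            for ty in range(tpb[1]):
                for tx in range(tpb[0]):
                    c = bx * tpb[0] + tx      # column, from the x axis
                    r = by * tpb[1] + ty      # row, from the y axis
                    if c < w and r < h:
                        out[r * w + c] = grey(x, r * w + c, n)
    return out

h, w = 3, 5
x = [(7 * v) % 256 for v in range(3 * h * w)]   # arbitrary planar RGB toy data
out1d, out2d = run_1d(x, h, w), run_2d(x, h, w)
assert None not in out1d and out1d == out2d
```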

Matrix multiplication

We go into a lot less detail here compared to the previous example. The CUDA logic is the same, just extended to 2 dimensions. Instead of an integer index (x), we have a tuple index (x, y), and both blocks and threads are identified by such a pair.

One thing at a time though. First an illustration from the PMPP as a reminder of how matrix multiplication works👇

(Image from PMPP) Multiplication of 2 matrices M and N. Each element of the output matrix P is the dot product of the corresponding row of M with the corresponding column of N.
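In formulas, using the same names as the matmul code in this section (m is h×k, n is k×w, the output is h×w), each output element is a dot product over k terms, and the flattened (row-major) indexing used by the kernels is just the same sum written with 1D offsets:

```latex
P_{rc} = \sum_{i=0}^{k-1} M_{ri}\, N_{ic},
\qquad 0 \le r < h,\; 0 \le c < w

\texttt{out}[r\,w + c] = \sum_{i=0}^{k-1} \texttt{m}[r\,k + i]\;\texttt{n}[i\,w + c]
```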

We follow the same procedure as before:

  1. Write a dummy slow Python version
  2. Refactor the for-loop into a separate Python kernel and add a dedicated function calling that Python kernel
  3. Ask GPT4 to convert the Python code into CUDA-compatible code
  4. Compile in PyTorch and enjoy the speedup

👇 is the Python implementation. It’s interesting to see the transition from the dummy (but quite readable) matmul function at step 1 to the way more complex CUDA-like version at step 2. The fact we operate (because we choose to do so for convenience) on a 2D grid means blocks and threads are tuples and not integers. So a block is identified as (blockidx.x, blockidx.y) and a thread as (threadidx.x, threadidx.y). That’s why we define threads-per-block as tpb = ns(x=16,y=16). This also means that we have 4 nested for-loops in practice, as we are iterating over all blocks and threads (the horrible mess happening inside blk_kernel2d). Notice also how, due to the ceiling division, we now have a guardrail on both rows and columns (if (row>=h or col>=w)).

Let’s rephrase the above to make it clearer. Our unit of work (the one that gets parallelized and executed by the kernel) is the dot product between one row of the first matrix and one column of the second. Being a dot product, it contains a for-loop of its own, but that doesn’t matter. CUDA gives us an index. We decide what to do with it. We decided to assign an entire dot product to each thread.

Another note: see the dim3 object (3-integer tuple) in action here too. In CUDA, dim3 tpb(16,16) is the same thing as dim3 tpb(16,16,1); our tpb = ns(x=16,y=16) mimics that, with the z dimension left implicit. Likewise, the grid has a z dimension of 1, so blockIdx.z (and threadIdx.z) is always 0.

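import math
import torch
from types import SimpleNamespace as ns  # ns(x=..., y=...) mimics CUDA's dim3

# STEP 1: DUMMY PYTHON FUNCTION EXECUTING MATMUL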
def matmul(a,b):
    (ar,ac),(br,bc) = a.shape,b.shape
    c = torch.zeros(ar, bc)
    for i in range(ar):
        for j in range(bc):
            for k in range(ac): c[i,j] += a[i,k] * b[k,j]
    return c

# STEP 2: REFACTOR THE FOR-LOOP TO CREATE...
# ... A PYTHON CUDA-LIKE VERSION OF THE DUMMY FUNCTION USING BLOCKS AND THREADS...
def matmul_bk(blockidx, threadidx, blockdim, m, n, out, h, w, k):
    row = blockidx.y*blockdim.y + threadidx.y
    col = blockidx.x*blockdim.x + threadidx.x
    
    if (row>=h or col>=w): return
    o = 0.
    for i in range(k): 
        o += m[row*k+i] * n[i*w+col]
    out[row*w+col] = o

#... AND DEDICATED FUNCTION(S) EXECUTING THE KERNEL...
def blk_kernel2d(f, blocks, threads, *args):
    for i0 in range(blocks.y):
        for i1 in range(blocks.x):
            for j0 in range(threads.y):
                for j1 in range(threads.x): 
                    f(ns(x=i1,y=i0), ns(x=j1,y=j0), threads, *args)

#...THIS ONE TOO ACTUALLY
def matmul_2d(m, n):
    h,k  = m.shape
    k2,w = n.shape
    assert k==k2, "Size mismatch!"
    output = torch.zeros(h, w, dtype=m.dtype)
    tpb = ns(x=16,y=16) #16x16 = 256, so that is the same threads as in our 1D example
    blocks = ns(x=math.ceil(w/tpb.x), y=math.ceil(h/tpb.y))
    blk_kernel2d(matmul_bk, blocks, tpb,
                 m.flatten(), n.flatten(), output.flatten(), h, w, k)
    return output

Before diving into the mechanics of the code, let’s remind ourselves how multiplication works on 2 toy matrices m1 (shape 2×3) and m2 (shape 3×4). Below is the animated multiplication between them.

Visual matrix multiplication of the 2 dummy matrices m1 and m2 we are using as examples for illustration purposes

With this in mind, here is what happens when we invoke matmul_2d on m1 and m2. The Python code was completely mind-bending to me, and I didn’t fully grasp it until I created this slide. Notice (once again) how we index the 2D output with x and y but, in practice, the logic still executes on flattened versions of the 2 matrices (and of the output). The visualization (and the Python logs) show it clearly. I have arbitrarily set the threads per block to a 2×2 grid (tpb = ns(x=2,y=2), i.e. 4 threads in total). This means we only need 2 blocks, given we have 8 elements to compute for the output matrix (res, of shape 2x4).

Step-by-step walkthrough of the operations needed to obtain the first 2 elements of the output matrix res, i.e. the dot products of the first row of m1 with the first 2 columns of m2.
Logs for the entire matrix multiplication for reference. The previous slide shows just the first 2 elements.
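The same walk can be reproduced without PyTorch. The sketch below uses made-up values for m1 (2×3) and m2 (3×4), since the actual numbers only appear in the slide, sets the threads per block to 2×2 as in the walkthrough, and records which output element each (block, thread) pair computes; it then checks the result against the naive triple loop.

```python
from types import SimpleNamespace as ns

def cdiv(a, b):
    return (a + b - 1) // b

# Hypothetical toy values (the post's actual m1/m2 live in the slide).
m1 = [[1, 2, 3],
      [4, 5, 6]]                      # 2x3
m2 = [[1, 0, 2, 1],
      [0, 1, 1, 2],
      [3, 1, 0, 1]]                   # 3x4
h, k, w = len(m1), len(m2), len(m2[0])
m = [v for row in m1 for v in row]    # flattened inputs, as in matmul_2d
n = [v for row in m2 for v in row]
out = [0] * (h * w)

tpb = ns(x=2, y=2)
blocks = ns(x=cdiv(w, tpb.x), y=cdiv(h, tpb.y))   # x=2, y=1: 2 blocks
trace = []
for by in range(blocks.y):
    for bx in range(blocks.x):
        for ty in range(tpb.y):
            for tx in range(tpb.x):
                row, col = by * tpb.y + ty, bx * tpb.x + tx
                if row >= h or col >= w:
                    continue                     # guard: thread has no work
                out[row * w + col] = sum(m[row * k + i] * n[i * w + col] for i in range(k))
                trace.append(((bx, by), (tx, ty), (row, col)))

naive = [sum(m1[r][i] * m2[i][c] for i in range(k)) for r in range(h) for c in range(w)]
assert out == naive
for entry in trace[:2]:
    print(entry)   # (block (x, y), thread (x, y), output (row, col))
```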

We then ask GPT4 to convert the Python code to CUDA. Same as we did earlier.👇 is the result (95% right at a first Copilot attempt). Once again, the real effort is understanding and digesting the Python implementation. Once done, reading the below is basically reading Python.

__global__ void matmul_k(float* m, float* n, float* out, int h, int w, int k) {
    int r = blockIdx.y*blockDim.y + threadIdx.y;
    int c = blockIdx.x*blockDim.x + threadIdx.x;

    if (r>=h || c>=w) return;
    float o = 0;
    for (int i = 0; i<k; ++i) o += m[r*k+i] * n[i*w+c];
    out[r*w+c] = o;
}

torch::Tensor matmul(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h = m.size(0);
    int w = n.size(1);
    int k = m.size(1);
    TORCH_CHECK(k==n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());

    dim3 tpb(16,16);
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmul_k<<<blocks, tpb>>>(
        m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}

Performance

Let’s skip the PyTorch load_inline compiling thing (we know how it works) and briefly discuss performance

What                    | Where | How many ops | How fast
Python (3 nested loops) | CPU   | 39k          | 1s 🐢
Python (broadcasting)   | CPU   | 392M         | 1s 🐢 💨
Our CUDA                | GPU   | 392M         | 6ms 🐇
PyTorch                 | GPU   | 392M         | 2ms 🐇 💨
Compared performance of matmul
  • We executed the original Python version of matmul on downsized matrices (39k ops in total). It would have been too slow on the full-size matrices Jeremy uses from the MNIST dataset (392M ops in total).
  • The naive triple-loop Python version on CPU (no broadcasting) is a no-go. Simply too 🐢.
  • Our CUDA implementation is already 🐇
  • An interesting question is: why is PyTorch on GPU 3x faster than our CUDA version? Largely because its matmul kernels take advantage of shared memory (among other optimizations), whereas ours reads everything from global memory. Jeremy discusses memory optimization in his second lecture (5th in the reading group), titled Going Further with CUDA for Python Programmers.

Shared memory is a small memory space on the GPU that is shared among the threads in a block. It is much faster than global memory (the main GPU memory), but it is also limited in size. Threads in the same block can use shared memory to share data with each other efficiently.

https://github.com/cuda-mode/lectures/blob/main/lecture3/pmpp.ipynb

Conclusion

Overall I found Jeremy’s videos as eye-opening as usual. The fact that we could write CUDA kernels from scratch was 🤯 to me, especially since we are doing it in a Jupyter notebook and I don’t know C++ at all. This result alone was way more than I could have hoped for.

Having said that, I am not fully sure I’ll ever venture into writing a CUDA kernel myself. The Going Further with CUDA for Python Programmers lecture, in which Jeremy goes into memory optimization (on top of simply parallelizing code across blocks and threads), was daunting to me. Of course you don’t necessarily have to go that far, but it is still a reminder that things can get very complex very fast. Also, the 1D/2D indexing framework across blocks and threads is mind-bending. I found it hard to wrap my head around and I am not convinced I’d be able to keep my code under control for a use case more complex than a “simple” matmul. Anyway, great learning experience overall.

Happy hacking!

Jeremy’s feedback

Something absolutely 🤯 to me happened a couple of hours after publishing the first version of this post. Jeremy himself reviewed it. Two thoughts here:

  1. This might look like nothing to you, but imagine for a second someone of Jeremy’s caliber getting out of bed in the morning, checking his X feed, finding (I would assume) several dozen (hundreds of?) @-mentions, and deciding to sit down, read and provide thorough feedback on a piece of content written by a complete stranger. How cool is this? We need more heroes like him. He is the reason I am here writing about CUDA after all, as I started my Deep Learning journey years ago with his very first fastai course. Keep rocking Jeremy!
  2. When in doubt, write about it. And post it. I was not quite sure I had fully understood all the concepts presented in the lecture (spoiler: I had not). But at some point, I just felt the best thing to do was to summarise what I had digested so far and get it out there. Hoping someone would provide feedback and maybe constructive criticism to patch up my inconsistencies. Turns out that “someone” was Jeremy himself. TLDR: just ship it.
