
Fine-tune a DeepFake video classifier: PyTorchVideo, Lightning, W&B, and Amazon SageMaker in action


Note: link to code on GitHub.

Context

What originally sparked this post was the desire to explore FAIR’s PyTorchVideo library, newly released back in May. Since then, the scope of the exploration has grown significantly, so here is what the following write-up is about:

  • Fine-tune a classifier to detect deepfake videos
  • using the Amazon SageMaker SDK
  • and the PyTorchVideo framework
  • in conjunction with PyTorch Lightning
  • with logging taken care of by Weights & Biases

Data

As always, Kaggle came to the rescue. The Deepfake Detection Challenge ran 2 years ago and came with a massive dataset: 470 GB in total. I definitely didn’t need all of that for my experiment, so I downloaded only a chunk (0.5.zip, worth 9.4 GB) and further subsampled it, creating a balanced dataset with all the REAL videos (360) and an equal number of randomly selected FAKEs.
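For reference, the subsampling step could look something like the sketch below. It assumes the chunk ships a metadata.json that maps each MP4 filename to a dict with a "label" key set to "REAL" or "FAKE" (the layout of the DFDC training chunks); the function names and the seed are hypothetical, not taken from my notebook.

```python
import json
import random


def load_metadata(path):
    """Read the chunk's metadata.json: {filename: {"label": "REAL" | "FAKE", ...}}."""
    with open(path) as f:
        return json.load(f)


def balanced_subset(metadata, seed=42):
    """Keep every REAL video plus an equal number of randomly drawn FAKE ones."""
    real = sorted(name for name, info in metadata.items() if info["label"] == "REAL")
    fake = sorted(name for name, info in metadata.items() if info["label"] == "FAKE")
    rng = random.Random(seed)  # fixed seed so the subset is reproducible
    fake_sample = rng.sample(fake, k=min(len(real), len(fake)))
    return {name: metadata[name]["label"] for name in real + fake_sample}
```

With the 360 REAL videos of the chunk, this yields 720 clips in total, half of them FAKE.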

Spoiler alert: the original full HD videos are too big to be fed to a neural network as they are. Even disregarding the time component (i.e. the fact that we have to input multiple frames in sequence), a batch of 1920 x 1080 frames would quickly exhaust GPU memory. We have to resize each frame. This can be done as part of the standard transforms pipeline images go through before getting into a model (ImageNet normalization, cropping, various augmentations, etc.). Nevertheless, my experiments with pytorchvideo and torchvision transforms suggest that this step is excruciatingly slow. Whenever I used the original dataset, data preparation was a major bottleneck in the training pipeline, so I checked the Kaggle forums, where an interesting approach came up.
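To put rough numbers on this, the sketch below computes the memory taken by the input tensor alone for a float32 batch of shape (batch, channels, frames, height, width). Batch size 16 is the value that ranked best in my runs; 16 frames per clip is an assumption for illustration. Activations inside the network take far more memory than the input itself, so these figures are only a lower bound.

```python
BYTES_PER_FLOAT32 = 4
MiB = 1024 ** 2


def video_batch_bytes(batch, channels, frames, height, width):
    """Memory taken by a float32 video batch of shape (B, C, T, H, W)."""
    return batch * channels * frames * height * width * BYTES_PER_FLOAT32


full_hd = video_batch_bytes(16, 3, 16, 1080, 1920)
face_crop = video_batch_bytes(16, 3, 16, 256, 256)
print(f"full HD frames:  {full_hd / MiB:.0f} MiB")     # 6075 MiB
print(f"256 x 256 crops: {face_crop / MiB:.0f} MiB")   # 192 MiB
print(f"pixel ratio:     {full_hd / face_crop:.1f}x")  # 31.6x
```

That is roughly 6 GiB just to hold the raw input of one batch, before the first convolution even runs.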

If you take a look at the provided FAKE clips, it is easy to see that the deepfake manipulation occurs in the facial region of the person on screen. Given that the original MP4s are 1920 x 1080 pixels and the face occupies only a small region of the frame, we would be feeding the model a lot of useless information. On top of that, going back to the transforms pipeline, we have to resize the original frames anyway as they are too big to fit on GPU, so we lose useful signal (the face, already small, gets downsized even further). Therefore, I shamelessly copied the approach showcased in this Kaggle kernel. The idea is to use MTCNN (shipped with the facenet_pytorch package) to detect faces every N frames of each clip and stack the crops together into a new MP4 file. The result is a video focused on our region of interest only (the face) and obviously much smaller and lighter: 10 vs 30 original FPS and 256 x 256 vs 1920 x 1080 resolution.
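The core of the cropping logic boils down to two small pieces of arithmetic: which frames to keep (every 3rd frame turns 30 FPS into 10 FPS), and how to turn a detected face box into a square crop that stays inside the frame before it gets resized to 256 x 256. The sketch below shows only that arithmetic, under my own assumptions: the face box is given as (x1, y1, x2, y2) pixel coordinates, the format returned by facenet_pytorch’s MTCNN.detect, while the 20% margin and the helper names are hypothetical and not necessarily what the Kaggle kernel does. Video decoding, face detection and resizing are left out.

```python
def frames_to_keep(num_frames, src_fps=30, target_fps=10):
    """Indices of the frames to process so the output clip plays at target_fps."""
    step = max(1, round(src_fps / target_fps))  # 30 / 10 -> every 3rd frame
    return list(range(0, num_frames, step))


def square_crop(box, frame_w, frame_h, margin=0.2):
    """Grow a face box (x1, y1, x2, y2) into a square with some margin, kept inside the frame."""
    x1, y1, x2, y2 = box
    cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
    side = max(x2 - x1, y2 - y1) * (1 + margin)
    side = min(side, frame_w, frame_h)  # the square cannot be larger than the frame
    # Shift the square back inside the frame if it spills over an edge
    left = min(max(cx - side / 2, 0), frame_w - side)
    top = min(max(cy - side / 2, 0), frame_h - side)
    s = int(side)
    return int(left), int(top), int(left) + s, int(top) + s


# A 10 s clip at 30 FPS has 300 frames; keeping every 3rd one leaves 100
print(len(frames_to_keep(300)))  # 100
```

Making the crop square before resizing avoids distorting the face when it is squeezed into the 256 x 256 output.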

Below you can see what this looks like in the original vs face-cropped clips, and here is the link to the notebook implementing the data preparation logic.

Original FAKE video: 30 FPS and 1920 x 1080
Same video as above with the face cropped only: 10 FPS and 256 x 256. These are the clips I trained the model on.

PyTorchVideo

Video modeling is notoriously complicated. In most cases, what you end up doing is splitting the clip into its frames and applying standard single-image computer vision algorithms. It is a naive strategy, but it works surprisingly well and delivers a very solid baseline. There are of course more advanced models, built with the temporal component in mind, that can handle chunks of video as a whole. They require quite a bit of custom work in terms of data preparation though. Most importantly, it isn’t easy to find pre-trained models to fine-tune, which is what the large majority of DL engineers need.

Given this premise, you can imagine my excitement when PyTorchVideo (PTV) was announced.

PyTorchVideo is a deep learning library for research and applications in video understanding. It provides easy-to-use, efficient, and reproducible implementations of state-of-the-art video models, data sets, transforms, and tools in PyTorch.

https://ai.facebook.com/blog/pytorchvideo-a-deep-learning-library-for-video-understanding/

I decided to start small and test the classification options coming with the framework. PTV offers object detection too but that was out of scope at this stage.

Overall, the tutorials do a pretty good job at getting you started. They cover training a PTV model from scratch and running inference with a pretrained PTV model from Torch Hub, but neither covers my use case: fine-tuning a pretrained model on a new dataset. It is quite easy to hack a solution together though.

The pre-trained model

The first step is to choose the model you are interested in from PTV’s model zoo. I picked the X3D family, given it is the smallest in terms of parameter count, allowing a (hopefully) faster prototyping cycle. Once the model is selected, the two following points must be addressed:

  1. The preprocessing pipeline your video data must undergo before being passed into the model.
  2. How to change the last layer of the network to reflect the number of classes you are dealing with.

This is crucial as we are using a model pretrained on a dataset different from ours (the whole point of transfer learning, actually). The X3D networks were trained on the Kinetics-400 dataset. I need to adopt the same data transforms the authors used and edit the model’s head, given I have 2 classes (deepfake vs real) compared to the 400 in Kinetics.

As for #1, you can find all the relevant info on Torch Hub. Here is the transforms pipeline for X3D. As another example, here is the page for SlowFast, another common video classifier.
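To make #1 concrete, these are the per-model settings listed in the X3D Torch Hub example (worth double-checking against the page, since they may change), together with the clip length they imply. The example scales pixels to [0, 1], normalizes them with mean 0.45 and std 0.225 per channel, and computes the clip duration as num_frames * sampling_rate / 30, i.e. assuming 30 FPS video. The helpers below are hypothetical and only mirror that arithmetic. Since my face-cropped clips run at 10 FPS, the same 2.6 s window for x3d_s holds 26 frames instead of 78, which are then subsampled to 13.

```python
# Per-model settings as listed in PyTorchVideo's X3D Torch Hub example
X3D_PARAMS = {
    "x3d_xs": {"side_size": 182, "crop_size": 182, "num_frames": 4, "sampling_rate": 12},
    "x3d_s": {"side_size": 182, "crop_size": 182, "num_frames": 13, "sampling_rate": 6},
    "x3d_m": {"side_size": 256, "crop_size": 256, "num_frames": 16, "sampling_rate": 5},
}
MEAN, STD = 0.45, 0.225  # same value for each RGB channel
HUB_FPS = 30  # frame rate assumed by the hub example


def clip_duration(model_name, fps=HUB_FPS):
    """Seconds of video covered by one model input: num_frames * sampling_rate / fps."""
    p = X3D_PARAMS[model_name]
    return p["num_frames"] * p["sampling_rate"] / fps


def normalize_pixel(value):
    """Scale a uint8 pixel value to [0, 1], then standardize it."""
    return (value / 255.0 - MEAN) / STD


def uniform_indices(total_frames, num_samples):
    """Evenly spaced frame indices, the idea behind UniformTemporalSubsample."""
    if num_samples == 1:
        return [0]
    return [int(i * (total_frames - 1) / (num_samples - 1)) for i in range(num_samples)]


duration = clip_duration("x3d_s")  # 13 * 6 / 30 = 2.6 s
frames_at_10fps = round(duration * 10)  # 26 frames in a 10 FPS face clip
print(uniform_indices(frames_at_10fps, X3D_PARAMS["x3d_s"]["num_frames"]))
# [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 25]
```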

As for #2, the following snippet does the job. The last block of the network (layers[-1]) is the X3D head, which ends with a linear projection (proj) mapping 2048 features to the 400 Kinetics classes. It is just a matter of reading that layer’s in_features (2048 for X3D) and defining a new nn.Linear layer with the same input size and 2 output classes.

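import torch
import torch.nn as nn

# Inside the LightningModule's __init__:
# 1. Load the pre-trained network from Torch Hub
model = torch.hub.load("facebookresearch/pytorchvideo:main", model="x3d_s", pretrained=True)
layers = list(model.blocks.children())
# Everything except the head becomes the feature extractor (frozen at first)
self.feature_extractor = nn.Sequential(*layers[:-1])
# 2. Classifier: reuse the pre-trained head, swapping its 400-class projection for a 2-class one
self.fc = layers[-1]
self.fc.proj = nn.Linear(in_features=self.fc.proj.in_features, out_features=2, bias=True)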

Fine tuning a model in PyTorch Lightning

I opted for PyTorch Lightning (PL) to train my model. There are two ways of fine-tuning a network in PL:

  1. Using the Lightning Flash API, which ships a fine-tuning capability offering a couple of interesting training strategies. You can define your backbone and head blocks, decide which part to freeze and for how many epochs.
  2. Using plain PL, subclassing the BaseFinetuning callback, as proposed here. This is the route I went for.

Everything happens here, inside the MilestonesFinetuning class, and is illustrated in the above slide. The idea is to:

  1. Freeze the model’s feature_extractor, before the training begins. feature_extractor is defined here and includes everything except the last layer (fc), re-initialised from scratch to output 2 classes instead of 400.
  2. When configure_optimizers is invoked (after step #1), the trainable_parameters it will find are only the ones from the top layer, as the rest is frozen.
  3. Train fc up until a milestone epoch is hit. Concretely this means that (using the default value of 5) the top layer is trained for the first 5 epochs at the learning rate passed to the optimizer. Once again, using the default value, this would be lr=1e-3. In the meantime the feature_extractor‘s weights are left untouched.
  4. When we reach the milestone epoch, 4 things happen:
    1. Inside the finetune callback (unfreeze_and_add_param_group), we unfreeze the top unfreeze_top_layers layers of the feature_extractor…
    2. … and add them as a second param group to the optimizer. feature_extractor[-self.unfreeze_top_layers:] defines this new group. Using the default value of 3 for unfreeze_top_layers, we split the network into 3 parts: fc, the top 3 layers of the feature_extractor, and the rest. unfreeze_top_layers is a parameter, so the depth of this group of layers is something we can experiment with.
    3. The MultiStepLR scheduler kicks in: it reduces the learning rate of fc (the first param group) by a factor of 10, from 1e-3 to 1e-4.
    4. unfreeze_and_add_param_group also assigns the added param group a new learning rate, 10x smaller than the one for the head, so 1e-4/10=1e-5 (10 is the default denom_lr argument). This reflects the fine-tuning strategy: start training the head with a reasonably high LR, then unfreeze the rest of the network (or part of it, as in our case) and keep going with a lower LR, making sure the backbone always gets trained with a lower LR than the head (we don’t want to wreck the pre-trained weights of the backbone).
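The resulting learning-rate schedule can be simulated in a few lines of plain Python. This is a toy model of what MultiStepLR (one milestone at epoch 5, gamma 0.1) and unfreeze_and_add_param_group do with the default values discussed above, not a call into Lightning; the function name is hypothetical, and only a single milestone is modeled.

```python
def lr_schedule(num_epochs, lr=1e-3, milestone=5, gamma=0.1, denom_lr=10.0):
    """Per-epoch learning rates of the two param groups: (epoch, head_lr, backbone_lr).

    backbone_lr is None while the feature_extractor is still frozen.
    """
    schedule = []
    backbone_lr = None
    for epoch in range(num_epochs):
        # MultiStepLR multiplies the LR by gamma once the milestone epoch is reached
        head_lr = lr * gamma if epoch >= milestone else lr
        if epoch == milestone:
            # New param group for the unfrozen layers: current head LR / denom_lr
            backbone_lr = head_lr / denom_lr
        schedule.append((epoch, head_lr, backbone_lr))
    return schedule


for epoch, head_lr, backbone_lr in lr_schedule(8):
    backbone = "frozen" if backbone_lr is None else f"{backbone_lr:.0e}"
    print(f"epoch {epoch}: fc {head_lr:.0e}, top feature_extractor layers {backbone}")
```

Epochs 0 to 4 train fc alone at 1e-3; from epoch 5 onwards fc runs at 1e-4 and the unfrozen layers at 1e-5.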

Logging to Weights & Biases (W&B)

This is incredibly easy. You just need to:

  1. create a W&B account
  2. create an API key (documentation)
  3. pip install wandb
  4. wandb login in a terminal to log in to your account (this is when you paste the API key created in step 2). Note: I mainly used scripts to train my models, which is what wandb login is for. If you prefer Jupyter notebooks, check this approach instead.
  5. The integration with PL is ridiculously simple. It is sufficient to define a WandbLogger and pass it to the pl.Trainer. Then invoke self.log(...) (where self is a pl.LightningModule), passing any metric you want to track in the dashboard. As an example, this is how I logged the validation loss and this is how I kept track of the learning rate instead. Note: I didn’t use the LearningRateMonitor callback, as I wanted finer-grained control over what to log and how. I have used it in the past though and it works great.
  6. The result is this W&B report showing the most interesting metrics and variables’ evolution over epochs.

Model performance

The few experiments I have run are summarised in the W&B report I linked earlier. This is by no means an exhaustive list of all possible tests to perform! I have just executed a few off the top of my head, following my gut feeling. Nevertheless, I managed to reach 0.9271 accuracy on the validation set, which is not bad!

As you can see I have tested:

  • 3 architectures from the X3D family: XS, S and M. XS comes out on top (the smallest!).
  • Unfreezing the top 1 or 3 layers of the feature_extractor, i.e. unfreeze_top_layers. 3 works best.
  • Exploring a couple of initial learning rates. 1e-3 works best.
  • Playing with the batch size. 16 seems to rank the highest.

Something I didn’t check at all is the effect of data augmentation, e.g. brightness, contrast, flipping, etc. The results here are coming from a completely un-augmented dataset.

Figures: validation accuracy and validation loss of the runs above.

For reference, this is the script I used to train the above models on an EC2 machine (p3.2xlarge).

Training in Amazon SageMaker with a fully custom Docker image

To conclude this post I wanted to add a quick SageMaker twist to it. The models I discuss in the previous section have been trained with a standard python script on a GPU-powered EC2 instance. This is fine but it doesn’t really scale. Let’s see how we can achieve the same with Amazon SageMaker. Unlocking this functionality opens the door to the true power of the AWS cloud.

To do so we will go down the road of a fully custom Docker image. The reason behind this (a little painful) decision is simple. I took all the easy routes first and none worked.

Why fully custom?

The first strategy was to use a standard PyTorch Estimator. This consists of simply passing an entry point script together with its dependencies to a sagemaker.pytorch.PyTorch class (great post about this approach by AWS Hero Luca Bianchi here). SM takes care of running everything on top of a pre-built Docker image. In my case, PyTorchVideo wouldn’t cooperate with the pre-installed SM dependencies, throwing all sorts of errors.

Therefore, I moved to the extending-a-SM-PyTorch-container approach. This offers more flexibility than using a standard image, as we get to tinker with the Dockerfile and install any additional dependency ourselves, gaining more control over the process. Spoiler alert: same issues as before. I am not going to list all of them in detail as, in all honesty, I don’t even recall everything I tried. In a nutshell, they were essentially incompatible dependencies between the pre-built SM image and the libraries I had to install, namely PyTorchVideo and PyTorch Lightning.

Which brought me to the fully custom solution. The problem, I reckoned, would never go away as long as I kept using Amazon images as a base. I needed to write my Dockerfile from scratch. This is nice as you have complete control over what gets into your environment. AWS provides a tutorial on how to bring a custom container. I started from there but quickly realized it was not enough. To be clear, I managed to get my training job to run successfully on SM (i.e. python main.py worked), but I kept stumbling on another problem: the hyperparameters I was providing to the Estimator class in main.py wouldn’t get passed along as arguments to the entry point script train.py. So basically, no matter what I did, the training job would always run with the default arguments. This clearly had something to do with SM not being able to communicate properly with my code base, i.e. something was wrong in the Docker image. Sean Morgan, an AWS ML SA, kindly pointed me to the solution: I was missing the line RUN pip3 install sagemaker-training in the Dockerfile. The SageMaker Training Toolkit is a library that makes an image compatible with the SM ecosystem, allowing, among other things, hyperparameters to be passed to the entry point as command-line arguments. Bingo!
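On the receiving end, the toolkit runs the entry point named by the SAGEMAKER_PROGRAM environment variable (set in the Dockerfile) and forwards each Estimator hyperparameter as a --name value pair on the command line, so train.py only needs a regular argument parser. A minimal sketch follows. The argument names and defaults are hypothetical, loosely mirroring the values discussed in this post; SM_MODEL_DIR and SM_CHANNEL_TRAINING are environment variables SageMaker sets inside the container (the latter for an input channel named "training"), with local fallbacks for running outside SageMaker.

```python
import argparse
import os


def parse_args(argv=None):
    """Parse the arguments of a hypothetical train.py entry point."""
    parser = argparse.ArgumentParser(description="Fine-tune X3D on face-cropped DFDC clips")
    # Hyperparameters: the SageMaker Training Toolkit passes the Estimator's
    # `hyperparameters` dict to the entry point as `--name value` pairs
    parser.add_argument("--model_name", type=str, default="x3d_xs")
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--milestone", type=int, default=5)
    parser.add_argument("--unfreeze_top_layers", type=int, default=3)
    parser.add_argument("--max_epochs", type=int, default=20)
    # Paths: set by SageMaker inside the container, local defaults otherwise
    parser.add_argument("--model_dir", type=str, default=os.environ.get("SM_MODEL_DIR", "./model"))
    parser.add_argument("--data_dir", type=str, default=os.environ.get("SM_CHANNEL_TRAINING", "./data"))
    return parser.parse_args(argv)
```

Without the toolkit, nothing passes these flags, which is exactly the "always the default arguments" symptom described above.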

In practice, you have to:

  1. Prepare your Dockerfile and the code to wrap into it, namely the entry point script and its dependencies.
  2. Write the file to invoke your training job (main.py). It is not part of the image: it is the script we use to launch the container and trigger model training.
  3. Build and tag the custom Docker image locally, then upload it to Amazon ECR. Here is the script I used to do that.
  4. Fetch the URI of the image from ECR, to be used here.
  5. Fetch the ARN of the IAM Role with the right permissions (SM + S3), to be used here.
  6. Get your personal API key from W&B, if you want to log to that platform. To be used here (this is how SM sets ENV variables).
  7. Run python main.py to trigger the remote SM training.
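Steps 4 to 7 could translate into a main.py along the lines of the sketch below. The helper functions, account ID, region, repository, role and bucket are placeholders of mine, not my actual main.py; the Estimator arguments (image_uri, role, instance_count, instance_type, hyperparameters, environment) are those of the SageMaker Python SDK v2. The input channel name "training" is what becomes SM_CHANNEL_TRAINING inside the container, and wandb picks up WANDB_API_KEY from the environment. The SDK is imported lazily, so the helpers can be used without it.

```python
def ecr_image_uri(account_id, region, repository, tag="latest"):
    """URI of an image pushed to Amazon ECR (step 4)."""
    return f"{account_id}.dkr.ecr.{region}.amazonaws.com/{repository}:{tag}"


def estimator_kwargs(image_uri, role_arn, wandb_api_key, hyperparameters):
    """Arguments for sagemaker.estimator.Estimator (steps 4 to 6)."""
    return {
        "image_uri": image_uri,
        "role": role_arn,  # IAM Role with SageMaker + S3 permissions
        "instance_count": 1,
        "instance_type": "ml.p3.2xlarge",
        "hyperparameters": hyperparameters,  # forwarded to train.py as --name value
        "environment": {"WANDB_API_KEY": wandb_api_key},  # env var inside the container
    }


def launch(kwargs, training_s3_uri):
    """Step 7: start the remote training job (needs the sagemaker SDK and AWS credentials)."""
    from sagemaker.estimator import Estimator  # imported lazily on purpose

    estimator = Estimator(**kwargs)
    estimator.fit({"training": training_s3_uri})


# Usage, with placeholder values:
# kwargs = estimator_kwargs(
#     ecr_image_uri("123456789012", "eu-west-1", "ptv-deepfake"),
#     role_arn="arn:aws:iam::123456789012:role/MySageMakerRole",
#     wandb_api_key="<your W&B API key>",
#     hyperparameters={"lr": 1e-3, "batch_size": 16},
# )
# launch(kwargs, "s3://my-bucket/dfdc-faces/")
```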

That’s it. Hopefully my journey is going to be helpful for someone else. Happy hacking!
