
Fastai Course Part 2 2022: Understanding CallBacks


Context

In October 2022 I started attending the new version of part 2 of the fastai course. In it, Jeremy builds a Deep Learning training and evaluation framework from scratch, starting from matrix multiplication and climbing all the way up to torch. The result is a smaller, somewhat simplified version of the fastai library itself, called miniai. He codes the entire thing from the ground up, so, for a Python enthusiast, it’s pure joy for the eyes and the ears. Not only because he walks students through the ML implications of the work but, at least for me, mostly because he shares all sorts of tips and tricks about the Python programming language. How do you build a flexible framework that lets you test hypotheses as quickly as possible? You don’t just write code. You pause, think, design, and then write GOOD code. Expert-level code. If I tried to cover all the tips Jeremy pulled out of his magic hat, this post would turn into a short book, so I decided to pick the one that has stuck with me the most so far: callbacks. Of course.

The Learner

To dig into callbacks, I’ll peel the onion of the Learner class introduced in notebook 9 here (look for the “Updated versions since the lesson” section). A Learner is an object that encapsulates everything we need to train and evaluate a model: the data, the loss function, the evaluation metrics and, well, the model itself. It looks like this:

class Learner():
    def __init__(self, model, dls=(0,), loss_func=F.mse_loss, lr=0.1, cbs=None, opt_func=optim.SGD):
        cbs = fc.L(cbs)
        fc.store_attr()

    def cb_ctx(self, nm): return _CbCtxInner(self, nm)
                
    def one_epoch(self, train):
        self.model.train(train)
        self.dl = self.dls.train if train else self.dls.valid
        with self.cb_ctx('epoch'):
            for self.iter,self.batch in enumerate(self.dl):
                with self.cb_ctx('batch'):
                    self.predict()
                    self.callback('after_predict')
                    self.get_loss()
                    self.callback('after_loss')
                    if self.training:
                        self.backward()
                        self.callback('after_backward')
                        self.step()
                        self.callback('after_step')
                        self.zero_grad()
    
    def fit(self, n_epochs=1, train=True, valid=True, cbs=None, lr=None):
        cbs = fc.L(cbs)
        # `add_cb` and `rm_cb` were added in lesson 18
        for cb in cbs: self.cbs.append(cb)
        try:
            self.n_epochs = n_epochs
            self.epochs = range(n_epochs)
            if lr is None: lr = self.lr
            if self.opt_func: self.opt = self.opt_func(self.model.parameters(), lr)
            with self.cb_ctx('fit'):
                for self.epoch in self.epochs:
                    if train: self.one_epoch(True)
                    if valid: torch.no_grad()(self.one_epoch)(False)
        finally:
            for cb in cbs: self.cbs.remove(cb)

    def __getattr__(self, name):
        if name in ('predict','get_loss','backward','step','zero_grad'): return partial(self.callback, name)
        raise AttributeError(name)

    def callback(self, method_nm): run_cbs(self.cbs, method_nm, self)
    
    @property
    def training(self): return self.model.training

Here is how it works in practice. We instantiate an object and call fit

Screenshot 1

When initializing the Learner, we pass:

  • model: the PyTorch model we are going to train
  • dls: the DataLoaders wrapping the training and validation sets (we are using the FashionMNIST dataset downloaded from Hugging Face)
  • loss_func (F.cross_entropy here): we are running a multi-class classification task (10 classes of clothing to predict), so we use the standard cross-entropy loss
  • lr: the learning rate passed to the optimizer
  • cbs: a list of callbacks

First things first. What are callbacks and why are they useful? A callback is a function that gets executed at a specific point in the code and can read or modify the attributes of the object it runs in (here, the Learner). Let’s put it this way. Without callbacks, if you wanted to print “This is a great batch named x!” after loading a batch x and before running model(x), you’d have to do something like this:

x, y = batch
print("This is a great batch named x!")
pred = model(x)

Now, what if you want to also print “Btw, y is part of the batch too!”? Easy. I just add another print statement

x, y = batch
print("This is a great batch named x!")
print("Btw, y is part of the batch too!")
pred = model(x)

I know what you are thinking: what’s your point here? My point is that none of this is elegant. We are cluttering a core component of the training loop (x, y = batch; pred = model(x)) with additional, noisy code. Sooner rather than later it’s going to become very hard to maintain, especially if the logic is more complex than a couple of print statements (hopefully it is). What we would like instead is something like this:

x, y = batch
run_anything_you_want_after_loading_a_batch() # <-- this is a callback
pred = model(x)

Here run_anything_you_want_after_loading_a_batch is a function in charge of doing whatever needs to happen after loading the batch and before feeding it to the model, like our print statements above. Want to add another print statement? Do it inside run_anything_you_want_after_loading_a_batch. That function is defined elsewhere, so you don’t clutter the main training loop and you keep things clean. Hopefully you can see how this becomes even more critical once you start thinking about running code before/after calculating the loss, before/after the backward pass, before/after each epoch, etc. The training loop turns into a mess if you don’t stop for a second and think hard about how to keep it tidy.
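To make the idea concrete, here is a minimal plain-Python sketch of the pattern (not how miniai implements it): the loop only names the point it has reached and runs whatever functions have been registered there. The registry `hooks` and the helpers `register`, `fire` and `training_step` are hypothetical names chosen for illustration, the messages are appended to a list instead of printed, and the “model” is just a lambda.

```python
from collections import defaultdict

# Hypothetical registry: name of a point in the loop -> functions to run there
hooks = defaultdict(list)

def register(point, fn):
    """Attach fn to a named point of the training loop."""
    hooks[point].append(fn)

def fire(point, **state):
    """Run every function registered for `point`, in registration order."""
    for fn in hooks[point]:
        fn(**state)

log = []
register('after_load_batch', lambda x, y: log.append("This is a great batch named x!"))
register('after_load_batch', lambda x, y: log.append("Btw, y is part of the batch too!"))

def training_step(batch, model):
    x, y = batch
    fire('after_load_batch', x=x, y=y)  # the only line the loop needs, whatever the hooks do
    return model(x)

pred = training_step(([1.0, 2.0], [0, 1]), model=lambda x: [2 * v for v in x])
print(log)   # ['This is a great batch named x!', 'Btw, y is part of the batch too!']
print(pred)  # [2.0, 4.0]
```

Adding a third message means one more register call; training_step itself never changes.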

The Callback magic

Ok, this is cool in theory, but how are we going to make this happen in practice? This is what callbacks are for. This is what self.cb_ctx('something') and self.callback('after_something') do (see the Learner). This is what allowed us to keep the entire Learner class below 50 lines of code. Everything else happens outside. In separate callback classes. Let’s see how.

We start from the beginning. From the Learner initialization: learn = Learner(model, dls, F.cross_entropy, lr=0.2, cbs=cbs)

Learner.__init__ does nothing more than store the arguments passed to the constructor as attributes of self (so that, e.g., self.lr is available anywhere in the class) and convert cbs from a Python list into a fastcore (fc) L object. If you are not familiar with it, fc.L is basically a list on steroids that adds convenient functionality to the standard Python list; among other things, it turns None into an empty list and a single object into a one-element list.
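For readers new to fastcore, here is a rough, hand-written approximation of what those two lines achieve. `listify` and `TinyLearner` are hypothetical simplifications: the real fc.L and fc.store_attr handle many more cases (fc.store_attr, for instance, stores every argument automatically instead of listing them by hand).

```python
def listify(o):
    """Simplified stand-in for fc.L's conversion: None -> [], list/tuple -> list, anything else -> [o]."""
    if o is None: return []
    if isinstance(o, (list, tuple)): return list(o)
    return [o]

class TinyLearner:
    """Hypothetical, trimmed-down Learner.__init__ written without fastcore."""
    def __init__(self, model, dls=(0,), lr=0.1, cbs=None):
        # What `cbs = fc.L(cbs); fc.store_attr()` does for us, spelled out by hand
        self.model, self.dls, self.lr = model, dls, lr
        self.cbs = listify(cbs)

learn = TinyLearner('my_model', cbs='single_cb')
print(learn.cbs)                    # ['single_cb']
print(learn.lr)                     # 0.1
print(TinyLearner('my_model').cbs)  # []
```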

Once initialized, we can call fit on it (learn.fit).

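Inside fit we first append any additional callbacks passed to the method to the Learner’s list (they are removed again in the finally block once fit returns, even if training fails). We then store the number of epochs, fall back to self.lr if no learning rate was passed, and instantiate the optimizer (SGD by default). Then the callback magic starts to unfold.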
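The add-then-remove pattern in fit is worth isolating. A minimal sketch, assuming callbacks are plain strings and using a hypothetical TempCbsDemo class in place of the Learner, shows that the extra callbacks are only active during one call and are removed by the finally block even when training raises:

```python
class TempCbsDemo:
    """Hypothetical mini-learner showing how fit() adds callbacks for one call only."""
    def __init__(self, cbs=None):
        self.cbs = list(cbs or [])

    def fit(self, cbs=None, fail=False):
        cbs = list(cbs or [])
        for cb in cbs: self.cbs.append(cb)      # extra callbacks are active during this fit only
        try:
            self.seen = list(self.cbs)          # snapshot of what was active while "training"
            if fail: raise RuntimeError("training blew up")
        finally:
            for cb in cbs: self.cbs.remove(cb)  # always removed, even on error

learn = TempCbsDemo(cbs=['metrics'])
learn.fit(cbs=['progress'])
print(learn.seen)  # ['metrics', 'progress']
print(learn.cbs)   # ['metrics']

try:
    learn.fit(cbs=['progress'], fail=True)
except RuntimeError:
    pass
print(learn.cbs)   # ['metrics'], even though fit raised
```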

Context managers

What is this with self.cb_ctx('fit')?

  • It calls cb_ctx(self, nm), which in turn returns _CbCtxInner(self, nm).
  • _CbCtxInner is a custom context manager. A context manager is a Python construct that controls what gets executed before and after the block of code inside a with statement. A classic example is file handling, e.g. with open(filename, "r+") as f: do_stuff(f), which opens the file before do_stuff runs and closes it afterwards (even if do_stuff raises an exception) without us having to think about it.
  • To define a context manager from scratch, we write a class with an __enter__ and an __exit__ method. They are invoked before (__enter__) and after (__exit__) the code in the with statement. Let’s check out _CbCtxInner:
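
class CancelFitException(Exception): pass    # raise to stop the whole fit
class CancelEpochException(Exception): pass  # raise to skip the rest of the current epoch
class CancelBatchException(Exception): pass  # raise to skip the rest of the current batch

class _CbCtxInner:
    def __init__(self, outer, nm): self.outer,self.nm = outer,nm
    def __enter__(self): self.outer.callback(f'before_{self.nm}')  # e.g. before_fit
    def __exit__ (self, exc_type, exc_val, traceback):
        chk_exc = globals()[f'Cancel{self.nm.title()}Exception']  # e.g. CancelFitException
        try:
            # after_{nm} runs only if the with block finished without an exception
            if not exc_type: self.outer.callback(f'after_{self.nm}')
            # returning True suppresses the exception: only our own Cancel exception is swallowed
            return exc_type==chk_exc
        except chk_exc: pass  # a matching Cancel exception raised by an after_{nm} callback
        finally: self.outer.callback(f'cleanup_{self.nm}')  # cleanup_{nm} always runs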

The __enter__ method invokes self.outer.callback(f'before_{self.nm}'). The outer argument to __init__ is the self of the Learner class, i.e. outer is the Learner object. This means that self.outer.callback translates into learn.callback (the callback method, line 45 of the Learner class). nm is 'fit', so we are basically invoking learn.callback('before_fit').

What does this do? It runs run_cbs(learn.cbs, 'before_fit', learn). Let’s check it out

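from operator import attrgetter

def run_cbs(cbs, method_nm, learn=None):
    for cb in sorted(cbs, key=attrgetter('order')):  # lower `order` runs first
        method = getattr(cb, method_nm, None)        # None if this callback doesn't define the hook
        if method is not None: method(learn)         # hooks receive the Learner as their only argument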

run_cbs:

  1. loops through all the callbacks defined at the Learner level, sorted by their order attribute
  2. checks, with getattr, whether each callback has the requested method_nm (before_fit in this case)
  3. if it doesn’t, does nothing; if it does, calls that method passing the Learner object (method(learn))
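To see run_cbs in action without any model, the sketch below reuses its logic verbatim and feeds it a few hypothetical callbacks (A, B, C and Late) together with a Recorder object standing in for the Learner. It shows both rules: callbacks lacking the hook are skipped, and callbacks are sorted by order, with ties keeping their list order because Python’s sorted is stable.

```python
from operator import attrgetter

def run_cbs(cbs, method_nm, learn=None):
    # Same logic as run_cbs above
    for cb in sorted(cbs, key=attrgetter('order')):
        method = getattr(cb, method_nm, None)
        if method is not None: method(learn)

class Callback: order = 0

class Recorder:
    """Hypothetical stand-in for the Learner: just records which hooks ran."""
    def __init__(self): self.calls = []

class A(Callback):
    def before_fit(self, learn): learn.calls.append('A.before_fit')

class B(Callback):
    pass  # no before_fit: silently skipped

class C(Callback):
    def before_fit(self, learn): learn.calls.append('C.before_fit')

class Late(Callback):
    order = 1  # sorted after every order-0 callback, wherever it sits in the list
    def before_fit(self, learn): learn.calls.append('Late.before_fit')

learn = Recorder()
run_cbs([Late(), A(), B(), C()], 'before_fit', learn)
print(learn.calls)  # ['A.before_fit', 'C.before_fit', 'Late.before_fit']
```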

Going through the TrainCB(), DeviceCB() and metrics callbacks

The callbacks we defined (screenshot 1) are cbs = [TrainCB(), DeviceCB(), metrics], where metrics = MetricsCB(accuracy=MulticlassAccuracy()). Here they are in code:

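from copy import copy
import fastcore.all as fc
from torcheval.metrics import Mean
# def_device, to_device and to_cpu are small helpers defined in earlier miniai notebooks

class Callback(): order = 0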

class TrainCB(Callback):
    def __init__(self, n_inp=1): self.n_inp = n_inp
    def predict(self, learn): learn.preds = learn.model(*learn.batch[:self.n_inp])
    def get_loss(self, learn): learn.loss = learn.loss_func(learn.preds, *learn.batch[self.n_inp:])
    def backward(self, learn): learn.loss.backward()
    def step(self, learn): learn.opt.step()
    def zero_grad(self, learn): learn.opt.zero_grad()

class DeviceCB(Callback):
    def __init__(self, device=def_device): fc.store_attr()
    def before_fit(self, learn):
        if hasattr(learn.model, 'to'): learn.model.to(self.device)
    def before_batch(self, learn): learn.batch = to_device(learn.batch, device=self.device)

class MetricsCB(Callback):
    def __init__(self, *ms, **metrics):
        for o in ms: metrics[type(o).__name__] = o
        self.metrics = metrics
        self.all_metrics = copy(metrics)
        self.all_metrics['loss'] = self.loss = Mean()

    def _log(self, d): print(d)
    def before_fit(self, learn): learn.metrics = self
    def before_epoch(self, learn): [o.reset() for o in self.all_metrics.values()]

    def after_epoch(self, learn):
        log = {k:f'{v.compute():.3f}' for k,v in self.all_metrics.items()}
        log['epoch'] = learn.epoch
        log['train'] = 'train' if learn.model.training else 'eval'
        self._log(log)

    def after_batch(self, learn):
        x,y,*_ = to_cpu(learn.batch)
        for m in self.metrics.values(): m.update(to_cpu(learn.preds), y)
        self.loss.update(to_cpu(learn.loss), weight=len(x))

All three inherit from the Callback class without overriding the order attribute. Therefore they all have an order of 0 and, since Python’s sorted is stable, they are executed in the same order they appear in the cbs list (TrainCB first, DeviceCB second, and MetricsCB third). The loop goes as follows:

  1. Does the TrainCB callback have a before_fit method? No. Do nothing and keep going.
  2. Does the DeviceCB callback have a before_fit method? Yes! Let’s invoke it on the Learner. What does it do? if hasattr(learn.model, 'to'): learn.model.to(self.device). It puts the model on def_device (cuda in my case, as I am running on a GPU machine).
  3. Does the MetricsCB callback have a before_fit method? Yes! Let’s invoke it on the Learner. What does it do? learn.metrics = self. It sets the metrics attribute of the Learner to the MetricsCB instance itself (self).

What happens next? The code inside the with self.cb_ctx('fit'): statement gets executed, i.e.

for self.epoch in self.epochs:
    if train: self.one_epoch(True)
    if valid: torch.no_grad()(self.one_epoch)(False)

Exiting the context manager

Before stepping into the for loop, let’s reflect on how the context manager exits. Once all the epochs have run, the __exit__ method of the _CbCtxInner class kicks in. As you can see from the code, it looks up CancelFitException among the global variables and then, if the with block finished without an exception, applies the same logic we explained for before_fit, this time to after_fit: it loops through the callbacks, looks for an after_fit method and, if it finds one, runs it. Finally, whatever happened, it runs the cleanup_fit methods. If CancelFitException is raised at any point, fitting stops, but not by throwing an exception and erroring out: __exit__ recognizes it and handles the “exit” gracefully. Let’s see how with an additional callback.

class SingleBatchCB(Callback):
    order = 1
    def after_batch(self, learn): raise CancelFitException()

Let’s assume we had added SingleBatchCB to cbs. In this case this callback would be invoked last (order = 1). Specifically, its after_batch method would be invoked after the after_batch method of all the other callbacks. When the runtime hits SingleBatchCB.after_batch, it raises a CancelFitException.

  1. This happens inside the __exit__ method of the context manager at the batch level. Its except clause only catches CancelBatchException, so the CancelFitException is not swallowed: the cleanup_batch methods run in the finally clause, the exception propagates, the loop over batches stops, and we move one level up, to the epoch level.
  2. Here __exit__ receives the exception, so the after_epoch methods are skipped and only the cleanup_epoch methods run. __exit__ returns False (return exc_type==chk_exc → return CancelFitException==CancelEpochException), which tells Python not to suppress the exception. The epoch loop is broken too.
  3. We move up to the last and highest level: the fit level. Again after_fit is skipped and only cleanup_fit runs, but this time __exit__ returns True (return exc_type==chk_exc → return CancelFitException==CancelFitException). Returning True suppresses the exception, so fit ends cleanly instead of erroring out.
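The cancellation path can be checked with a small, torch-free sketch. CbCtx reproduces the __enter__/__exit__ logic of _CbCtxInner, but ToyLearner, CbCtx and the CANCEL dict are hypothetical: the dict replaces the globals() lookup, the Learner’s callback method is replaced by one that simply logs hook names, and raising CancelFitException from the after_batch hook mimics SingleBatchCB.

```python
class CancelFitException(Exception): pass
class CancelEpochException(Exception): pass
class CancelBatchException(Exception): pass

# miniai looks these up via globals()[f'Cancel{nm.title()}Exception']; a dict is used here instead
CANCEL = {'fit': CancelFitException, 'epoch': CancelEpochException, 'batch': CancelBatchException}

class CbCtx:
    """Same __enter__/__exit__ logic as _CbCtxInner."""
    def __init__(self, outer, nm): self.outer, self.nm = outer, nm
    def __enter__(self): self.outer.callback(f'before_{self.nm}')
    def __exit__(self, exc_type, exc_val, traceback):
        chk_exc = CANCEL[self.nm]
        try:
            if not exc_type: self.outer.callback(f'after_{self.nm}')
            return exc_type == chk_exc  # True suppresses the exception
        except chk_exc: pass
        finally: self.outer.callback(f'cleanup_{self.nm}')

class ToyLearner:
    """Hypothetical learner that logs hook names instead of running callbacks."""
    def __init__(self, stop_after_first_batch=False):
        self.stop, self.events = stop_after_first_batch, []

    def callback(self, nm):
        self.events.append(nm)
        # Mimics SingleBatchCB: cancel the whole fit from the after_batch hook
        if self.stop and nm == 'after_batch': raise CancelFitException()

    def fit(self, n_epochs=2, n_batches=3):
        with CbCtx(self, 'fit'):
            for _ in range(n_epochs):
                with CbCtx(self, 'epoch'):
                    for _ in range(n_batches):
                        with CbCtx(self, 'batch'):
                            pass  # predict / get_loss / backward / step would run here

normal = ToyLearner()
normal.fit(n_epochs=1, n_batches=2)
print(normal.events)

cancelled = ToyLearner(stop_after_first_batch=True)
cancelled.fit()  # no error escapes: CancelFitException is swallowed at the fit level
print(cancelled.events)
# ['before_fit', 'before_epoch', 'before_batch', 'after_batch', 'cleanup_batch', 'cleanup_epoch', 'cleanup_fit']
```

In the normal run every before_, after_ and cleanup_ hook fires; in the cancelled run, once after_batch raises, only the cleanup_ hooks fire on the way out, and fit returns without an error.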

Stop for a second to appreciate the beauty of what we just achieved. Imagine how complex this logic would be had we not used context managers and the dynamic flexibility of callbacks. And we are only at the fit stage! We were about to dive into epochs, where the approach is the same: with self.cb_ctx('epoch') wraps the before_epoch, after_epoch and cleanup_epoch operations, looping through the 3 callbacks we have defined and executing the relevant methods. with self.cb_ctx('batch') works the same way, at the batch level.

What’s this self.predict() thing?

Inside the batch operations, something unusual and very interesting happens. We invoke self.predict() (line 14 of the Learner class), yet no predict method is defined anywhere in the class. The same holds, a few lines below, for get_loss, backward, step and zero_grad. This is the point where an AttributeError should be thrown, unless we override the __getattr__ method, which we do at line 41. Python calls __getattr__ only when normal attribute lookup fails, so it acts as a fallback for exactly these missing names:

def __getattr__(self, name):
    if name in ('predict','get_loss','backward','step','zero_grad'): return partial(self.callback, name)
    raise AttributeError(name)

Here is what happens. When any of 'predict', 'get_loss', 'backward', 'step' or 'zero_grad' is looked up, __getattr__ returns self.callback with that name already bound as method_nm (via partial). So when we call self.predict(), we actually call self.callback('predict'), which calls run_cbs, which loops through all the callbacks of the Learner, checking each one for a predict method and executing it if found. In our example, only the TrainCB callback defines predict. It does what we expect: it runs the batch through the model and stores the predictions in learn.preds. Why go through all this trouble? Because, on top of everything happening AROUND the training loop (which is how we introduced callbacks in the first place), we have now abstracted the training loop itself. Every bit of it.
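The same dispatch can be reproduced in a few lines of plain Python. MiniLearner and ToyTrainCB below are hypothetical, simplified stand-ins for the Learner and TrainCB (the “model” just doubles the batch). The CustomLearner subclass shows the flip side: because __getattr__ is only consulted when normal lookup fails, a method defined on the class always wins over the callbacks.

```python
from functools import partial

class Callback: order = 0

class ToyTrainCB(Callback):
    """Hypothetical stand-in for TrainCB: the 'model' just doubles each value in the batch."""
    def predict(self, learn): learn.preds = [2 * v for v in learn.batch]

class MiniLearner:
    """Hypothetical, stripped-down Learner keeping only the callback dispatch."""
    def __init__(self, cbs): self.cbs = cbs

    def callback(self, method_nm):
        for cb in sorted(self.cbs, key=lambda c: c.order):
            method = getattr(cb, method_nm, None)
            if method is not None: method(self)

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails
        if name in ('predict', 'get_loss', 'backward', 'step', 'zero_grad'):
            return partial(self.callback, name)
        raise AttributeError(name)

class CustomLearner(MiniLearner):
    # A real method is found by normal lookup, so __getattr__ (and the callbacks) are bypassed
    def predict(self): self.preds = ['custom']

learn = MiniLearner([ToyTrainCB()])
learn.batch = [1, 2, 3]
learn.predict()      # equivalent to learn.callback('predict')
print(learn.preds)   # [2, 4, 6]
learn.step()         # no callback defines step, so this is a silent no-op

custom = CustomLearner([ToyTrainCB()])
custom.batch = [1, 2, 3]
custom.predict()
print(custom.preds)  # ['custom']
```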

Adding momentum in 3 lines of code

Why do we care? Look at how simple it becomes to implement SGD with momentum.

def zero_grad(self):
    with torch.no_grad():
        for p in self.model.parameters(): p.grad *= self.mom  # decay the gradients instead of zeroing them

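We just have to change the zero_grad method, either inside TrainCB (with the signature zero_grad(self, learn) and learn.model in place of self.model) or inside a newly defined Learner class (see TrainLearner and MomentumLearner). How cool is that? Apart from storing the mom hyperparameter, that’s literally the only change we have to make to get momentum to work.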
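Why does scaling the gradients reproduce momentum? After zero_grad, p.grad holds mom times its previous value, and the next backward pass adds the new gradient on top, so p.grad follows v = mom * v + grad, and step() then applies w -= lr * v. That is the classic heavy-ball update, the same rule torch.optim.SGD applies with momentum > 0 and dampening=0. The sketch below checks the equivalence on a hypothetical one-parameter loss f(w) = (w - 3)**2, using plain Python floats in place of tensors; the function names are illustrative.

```python
def grad(w):
    # Gradient of the toy loss f(w) = (w - 3)**2
    return 2 * (w - 3)

def train_with_grad_decay(w, lr=0.1, mom=0.9, steps=50):
    """miniai-style: zero_grad() scales the stored gradient by mom instead of zeroing it."""
    g = 0.0              # plays the role of p.grad
    for _ in range(steps):
        g += grad(w)     # backward(): PyTorch accumulates into the existing .grad
        w -= lr * g      # step(): plain SGD without momentum
        g *= mom         # the modified zero_grad()
    return w

def train_with_velocity(w, lr=0.1, mom=0.9, steps=50):
    """Textbook heavy-ball momentum with an explicit velocity v."""
    v = 0.0
    for _ in range(steps):
        v = mom * v + grad(w)
        w -= lr * v
    return w

print(train_with_grad_decay(0.0), train_with_velocity(0.0))  # the two values agree
```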

All right, there is quite a bit of info to digest here. What I recommend is cloning the repo, opening a notebook, and starting to experiment with miniai. The implementation is transparent and concise enough to let you edit parts of the code and immediately see what happens. Give it a go and get familiar with the magic of callbacks! It’s worth it!
