Context
In October 2022 I started attending the new version of part 2 of the fastai course. In this one, Jeremy builds up a Deep Learning training and evaluation framework from scratch, literally starting out from matrix multiplication and climbing all the way up to torch. It's actually a somewhat simplified version of the fastai library itself. A smaller one, called miniai. He codes the entire thing from the ground up, so, as a Python enthusiast, it's pure joy for the eyes and the ears. Not only because he walks students through the ML implications of the work, but, at least for me, mostly because he shares all sorts of tips and tricks around the Python programming language. How do you build a flexible framework that allows you to test hypotheses as quickly as possible? If you want to do that, you don't just write code. You pause, think, design, and then write GOOD code. Expert-level code. If I had to cover all the tips Jeremy pulled out of his magic hat, this post would turn into a short book, so I decided to pick the one that stuck with me the most so far: callbacks. Of course.
The Learner
In order to dig into callbacks, I'll peel the onion of the Learner class introduced in notebook 9 here (look for the "Updated versions since the lesson" section). A Learner is an object that encapsulates everything we need to know about a model: the data, the loss function, the evaluation metrics, and, well, the model itself. It looks like this:
import fastcore.all as fc
import torch
import torch.nn.functional as F
from functools import partial
from torch import optim

# run_cbs and _CbCtxInner are defined further down in the post
class Learner():
    def __init__(self, model, dls=(0,), loss_func=F.mse_loss, lr=0.1, cbs=None, opt_func=optim.SGD):
        cbs = fc.L(cbs)
        fc.store_attr()

    def cb_ctx(self, nm): return _CbCtxInner(self, nm)

    def one_epoch(self, train):
        self.model.train(train)
        self.dl = self.dls.train if train else self.dls.valid
        with self.cb_ctx('epoch'):
            for self.iter,self.batch in enumerate(self.dl):
                with self.cb_ctx('batch'):
                    self.predict()
                    self.callback('after_predict')
                    self.get_loss()
                    self.callback('after_loss')
                    if self.training:
                        self.backward()
                        self.callback('after_backward')
                        self.step()
                        self.callback('after_step')
                        self.zero_grad()

    def fit(self, n_epochs=1, train=True, valid=True, cbs=None, lr=None):
        cbs = fc.L(cbs)
        # `add_cb` and `rm_cb` were added in lesson 18
        for cb in cbs: self.cbs.append(cb)
        try:
            self.n_epochs = n_epochs
            self.epochs = range(n_epochs)
            if lr is None: lr = self.lr
            if self.opt_func: self.opt = self.opt_func(self.model.parameters(), lr)
            with self.cb_ctx('fit'):
                for self.epoch in self.epochs:
                    if train: self.one_epoch(True)
                    if valid: torch.no_grad()(self.one_epoch)(False)
        finally:
            for cb in cbs: self.cbs.remove(cb)

    def __getattr__(self, name):
        if name in ('predict','get_loss','backward','step','zero_grad'): return partial(self.callback, name)
        raise AttributeError(name)

    def callback(self, method_nm): run_cbs(self.cbs, method_nm, self)

    @property
    def training(self): return self.model.training
Here is how it works in practice: we instantiate an object and call fit.
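Something along these lines (a sketch mirroring the instantiation used throughout the post; it assumes model and the FashionMNIST dls have already been built, and that the callbacks dissected further down are in scope):

from torcheval.metrics import MulticlassAccuracy

metrics = MetricsCB(accuracy=MulticlassAccuracy())
cbs = [TrainCB(), DeviceCB(), metrics]
learn = Learner(model, dls, F.cross_entropy, lr=0.2, cbs=cbs)
learn.fit(n_epochs=1)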
When initializing the Learner, we pass:
- model: this is the PyTorch model we are going to train
- dls: these are the DataLoaders wrapping the train and validation sets (we are using the FashionMNIST dataset downloaded from HuggingFace)
- F.cross_entropy: that's the loss function. We are running a multi-class classification task (10 classes of clothing to be predicted), so we use a standard CE loss
- lr: the learning rate to be passed to the optimizer
- cbs: a list of callbacks
First things first. What are callbacks and why are they useful? A callback is a function that gets executed at specific points in the code, typically reading or modifying attributes of the object it is handed. Let's put it this way. Without callbacks, if you wanted to print "This is a great batch named x!" after loading a batch x and before running model(x), you'd have to do something like
x, y = batch
print("This is a great batch named x!")
pred = model(x)
Now, what if you also want to print "Btw, y is part of the batch too!"? Easy. Just add another print statement:
x, y = batch
print("This is a great batch named x!")
print("Btw, y is part of the batch too!")
pred = model(x)
I know what you are thinking. What's your point here? My point is that all of the above is not elegant. What we are doing is modifying a core component of the training loop (x, y = batch; pred = model(x)) with additional noisy code. Sooner rather than later, it's going to become very hard to maintain, especially if the logic is more complex than just print statements (hopefully it is). What we would like to have instead is something like
x, y = batch
run_anything_you_want_after_loading_a_batch() # <-- this is a callback
pred = model(x)
Where run_anything_you_want_after_loading_a_batch is a function in charge of running whatever we need after loading the batch and before feeding it to the model. Like our print statements above. Want to add another print statement? Do it inside run_anything_you_want_after_loading_a_batch. That's a function defined elsewhere, so you don't clutter the main training loop and you keep things clean. You can hopefully see how this becomes even more critical and useful once you start thinking about running code before/after calculating the loss, before/after the backward pass, before/after each epoch, and so on. The training loop is going to turn into a mess if you don't stop for a second and give some deep thought to how to tidy things up.
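To make this concrete, here is a minimal, hypothetical sketch of the idea: keep a list of functions and call them all at one well-defined point of the loop, instead of inlining the prints.

after_batch_cbs = [
    lambda: print("This is a great batch named x!"),
    lambda: print("Btw, y is part of the batch too!"),
]

def run_anything_you_want_after_loading_a_batch():
    # run every registered hook; adding behavior means appending to the list,
    # not touching the training loop
    for cb in after_batch_cbs: cb()

x, y = batch
run_anything_you_want_after_loading_a_batch()
pred = model(x)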
The Callback magic
Ok, this is cool in theory, but how are we going to make this happen in practice? This is what callbacks are for. This is what self.cb_ctx('something') and self.callback('after_something') do (see the Learner). This is what allowed us to keep the entire Learner class below 50 lines of code. Everything else happens outside. In separate callback classes. Let's see how.
We start from the beginning. From the Learner initialization: learn = Learner(model, dls, F.cross_entropy, lr=0.2, cbs=cbs)
The Learner.__init__ does nothing more than store the attributes we pass to the constructor on self (so that we can call, e.g., self.lr across the class) and convert cbs from a Python list into a fastcore (fc) L object. If you are not familiar with it, fc.L is basically a list on steroids, adding nice functionality to the standard Python object.
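A quick illustration of those two fastcore helpers (the Demo class is made up purely for this example):

import fastcore.all as fc

class Demo:
    def __init__(self, a, b=2):
        fc.store_attr()              # stores every __init__ argument on self

d = Demo(1)
print(d.a, d.b)                      # 1 2

xs = fc.L(1, 2, 3)                   # a list "on steroids"
print(xs.map(lambda o: o * 10))      # (#3) [10,20,30]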
Once initialized, we can call fit on it (learn.fit).
Inside fit we first append the additional callbacks (if passed to the method) to the original list. We then store the number of epochs and instantiate the optimizer (SGD by default). Then the callback magic starts to unfold.
Context managers
What is this with self.cb_ctx('fit')?
- It invokes cb_ctx(self, nm), which in turn returns _CbCtxInner(self, nm). _CbCtxInner is a custom-defined context manager. A context manager is a Python tool making it possible to control what gets executed before and after the code snippet included in the with statement. A classic example is IO handling, e.g. with open(filename, "r+") as f: do_stuff(f), which opens and closes the file for us without us even realizing it. It does that before and after do_stuff.
- In order to define a context manager from scratch, we have to write a class with an __enter__ and an __exit__ method. Those are the methods invoked before (__enter__) and after (__exit__) our code in the with statement. Let's first look at a toy example below, then go check out _CbCtxInner.
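A minimal, hypothetical context manager (not part of miniai), just to make the two hooks concrete:

class Shout:
    def __enter__(self): print("entering!")
    def __exit__(self, exc_type, exc_val, traceback): print("exiting!")

with Shout():
    print("doing stuff")
# prints: entering!, then doing stuff, then exiting!

Now back to _CbCtxInner: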
class _CbCtxInner:
    def __init__(self, outer, nm): self.outer,self.nm = outer,nm
    def __enter__(self): self.outer.callback(f'before_{self.nm}')
    def __exit__ (self, exc_type, exc_val, traceback):
        chk_exc = globals()[f'Cancel{self.nm.title()}Exception']
        try:
            if not exc_type: self.outer.callback(f'after_{self.nm}')
            return exc_type==chk_exc
        except chk_exc: pass
        finally: self.outer.callback(f'cleanup_{self.nm}')
The __enter__ method invokes self.outer.callback(f'before_{self.nm}'). The outer argument to the __init__ corresponds to self in the Learner class, i.e. outer refers to the Learner object. This means that self.outer.callback translates into learn.callback (the callback method at the bottom of the Learner class). nm is instead 'fit', so we are basically invoking learn.callback('before_fit').
What does this do? It runs run_cbs(learn.cbs, 'before_fit', learn). Let's check it out:
from operator import attrgetter

def run_cbs(cbs, method_nm, learn=None):
    for cb in sorted(cbs, key=attrgetter('order')):
        method = getattr(cb, method_nm, None)
        if method is not None: method(learn)
run_cbs:
- loops through all the callbacks defined at the Learner level (sorted by order)
- checks if each callback has the requested method_nm (before_fit in this case) by using getattr
- if not, it does nothing. If yes, it applies the method to the Learner object (method(learn)) — see the tiny illustration right after this list
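Here is run_cbs in isolation, with a hypothetical callback (the PrintCB name is made up for illustration):

class PrintCB:
    order = 0
    def before_fit(self, learn): print("about to start fitting!")

run_cbs([PrintCB()], 'before_fit')   # prints "about to start fitting!"
run_cbs([PrintCB()], 'after_fit')    # PrintCB has no after_fit, so nothing happens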
Going through the TrainCB(), DeviceCB() and metrics callbacks
The callbacks we defined when instantiating the Learner are cbs = [TrainCB(), DeviceCB(), metrics] (with metrics = MetricsCB(accuracy=MulticlassAccuracy())). Here they are in code:
# Mean and MulticlassAccuracy come from torcheval.metrics, copy is the stdlib copy.copy,
# and to_device, to_cpu and def_device are miniai helpers
class Callback(): order = 0

class TrainCB(Callback):
    def __init__(self, n_inp=1): self.n_inp = n_inp
    def predict(self, learn): learn.preds = learn.model(*learn.batch[:self.n_inp])
    def get_loss(self, learn): learn.loss = learn.loss_func(learn.preds, *learn.batch[self.n_inp:])
    def backward(self, learn): learn.loss.backward()
    def step(self, learn): learn.opt.step()
    def zero_grad(self, learn): learn.opt.zero_grad()

class DeviceCB(Callback):
    def __init__(self, device=def_device): fc.store_attr()
    def before_fit(self, learn):
        if hasattr(learn.model, 'to'): learn.model.to(self.device)
    def before_batch(self, learn): learn.batch = to_device(learn.batch, device=self.device)

class MetricsCB(Callback):
    def __init__(self, *ms, **metrics):
        for o in ms: metrics[type(o).__name__] = o
        self.metrics = metrics
        self.all_metrics = copy(metrics)
        self.all_metrics['loss'] = self.loss = Mean()
    def _log(self, d): print(d)
    def before_fit(self, learn): learn.metrics = self
    def before_epoch(self, learn): [o.reset() for o in self.all_metrics.values()]
    def after_epoch(self, learn):
        log = {k:f'{v.compute():.3f}' for k,v in self.all_metrics.items()}
        log['epoch'] = learn.epoch
        log['train'] = 'train' if learn.model.training else 'eval'
        self._log(log)
    def after_batch(self, learn):
        x,y,*_ = to_cpu(learn.batch)
        for m in self.metrics.values(): m.update(to_cpu(learn.preds), y)
        self.loss.update(to_cpu(learn.loss), weight=len(x))
The three of them inherit from the Callback class without altering the order attribute. Therefore they all have an order of 0, meaning they will be accessed and executed in the same order they appear in the cbs list (TrainCB first, DeviceCB second, and MetricsCB third). The loop goes as follows:
- Does the TrainCB callback have a before_fit method? No. Do nothing and keep going.
- Does the DeviceCB callback have a before_fit method? Yes! Let's invoke it on the Learner. What does it do? if hasattr(learn.model, 'to'): learn.model.to(self.device). It puts the model on def_device (cuda in my case, as I am running on a GPU machine).
- Does the MetricsCB callback have a before_fit method? Yes! Let's invoke it on the Learner. What does it do? learn.metrics = self. It defines the metrics attribute of the Learner by pointing to itself (self, the MetricsCB object).
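Nothing in this loop is specific to the built-in callbacks: any object of ours exposing the right method names gets picked up in exactly the same way. For example, a hypothetical timing callback (not part of miniai) only needs to be appended to cbs:

import time

class TimingCB(Callback):
    def before_fit(self, learn): self.start = time.perf_counter()
    def after_fit(self, learn): print(f"fit took {time.perf_counter() - self.start:.1f}s")

# cbs = [TrainCB(), DeviceCB(), metrics, TimingCB()]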
What happens next? The code inside the with self.cb_ctx('fit'): statement gets executed, i.e.
for self.epoch in self.epochs:
    if train: self.one_epoch(True)
    if valid: torch.no_grad()(self.one_epoch)(False)
Exiting the context manager
Before stepping into the for loop, let's reflect on how the context manager exits. Once we have looped through all the epochs, the __exit__ method of the _CbCtxInner class kicks in. As you can see from the code, it looks up the CancelFitException among the globally defined names and then applies the same logic we explained before for before_fit, this time to after_fit: it loops through the callbacks, looks for an after_fit method, and if it finds one, it runs it. If at any point a CancelFitException is raised, the loop exits. Or better: it doesn't just throw an exception and error out. The "exit" is handled gracefully. Let's see how with an additional callback.
class SingleBatchCB(Callback):
    order = 1
    def after_batch(self, learn): raise CancelFitException()
Let's assume we had added SingleBatchCB to cbs. This callback would then get invoked last (order = 1). Specifically, its after_batch method would run after the after_batch methods of all the other callbacks. When the runtime hits SingleBatchCB.after_batch, it raises a CancelFitException.
- The exception is raised while the batch-level context manager is running the after_batch callbacks inside its __exit__ method. It is not a CancelBatchException, so the except clause doesn't swallow it: the finally block still runs the cleanup_batch callbacks, but the exception propagates out of the batch loop and moves a level up. To the epoch level.
- Here the __exit__ method of the epoch-level context manager receives a non-empty exc_type, so the after_epoch callbacks are skipped, cleanup_epoch runs in the finally block, and return exc_type==chk_exc evaluates to return CancelFitException==CancelEpochException, i.e. False. The exception keeps propagating, which means the epoch loop is broken too.
- We move to another level. The last and highest. The fit level. Here return exc_type==chk_exc evaluates to return CancelFitException==CancelFitException, i.e. True, so after the cleanup_fit callbacks run, the exception is finally swallowed and the fit ends cleanly instead of erroring out (a toy demonstration of this "__exit__ returns True" rule follows below).
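Here is that rule in isolation, with a made-up context manager and exception (purely illustrative, not part of miniai): returning True from __exit__ suppresses the exception, returning False (or None) lets it propagate.

class StopEverything(Exception): pass

class Swallow:
    def __init__(self, exc): self.exc = exc
    def __enter__(self): pass
    def __exit__(self, exc_type, exc_val, traceback): return exc_type == self.exc

with Swallow(StopEverything):
    raise StopEverything()
print("still alive")   # reached: __exit__ returned True, so the exception was swallowed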
Stop for one second to appreciate the beauty of what we just achieved. Imagine how complex this logic would be, had we not used context managers and the dynamic flexibility of callbacks. And we are only at the fit stage! We were about to dive into epochs. There the approach is the same: with self.cb_ctx('epoch') encapsulates the before_epoch and after_epoch operations, looping through the three callbacks we have defined and executing the relevant methods. with self.cb_ctx('batch') works the same way, at the batch level.
What's this self.predict() thing?
Inside the batch operations, something unusual and very interesting happens. We invoke self.predict() (inside one_epoch in the Learner class). Nowhere do we have a predict method defined though. The same holds, a couple of lines below, for get_loss, backward, step and zero_grad. The Learner class indeed does not have those methods. This is the point where an AttributeError should be thrown. Unless we override the __getattr__ method, which we do (remember that __getattr__ is only called when regular attribute lookup fails, which is exactly our case):
def __getattr__(self, name):
    if name in ('predict','get_loss','backward','step','zero_grad'): return partial(self.callback, name)
    raise AttributeError(name)
What happens here is the following. When looking up any of the 'predict', 'get_loss', 'backward', 'step', and 'zero_grad' methods, we invoke self.callback (via partial), passing the method in question as method_nm. So when we hit predict, we actually call self.callback('predict'). Which calls run_cbs. Which loops through all the callbacks the Learner has, checking (and executing, if found) the predict method in each one of them. In our specific example, only the TrainCB callback has a predict method defined. It does what we expect it to do: run the batch through the model and store the predictions in learn.preds. Why are we going through all this trouble? Because, together with everything happening AROUND the training loop (this is how we introduced callbacks in the first place), we have now abstracted the training loop itself. Every bit of it.
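The trick is easier to digest on a stripped-down example. The Mini class below is made up purely for illustration:

from functools import partial

class Mini:
    def callback(self, method_nm): print(f"running callbacks for: {method_nm}")
    def __getattr__(self, name):
        # only reached when normal attribute lookup fails
        if name in ('predict', 'get_loss'): return partial(self.callback, name)
        raise AttributeError(name)

Mini().predict()   # prints "running callbacks for: predict"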
Adding momentum in 3 lines of code
Why do we care? Look at how simple it becomes to implement SGD with momentum.
def zero_grad(self):
    with torch.no_grad():
        for p in self.model.parameters(): p.grad *= self.mom
We just have to change the zero_grad method, either inside TrainCB or inside a newly defined Learner subclass (see TrainLearner and MomentumLearner). Instead of zeroing the gradients after each step, we simply scale them down by self.mom, so the next backward() accumulates the new gradients on top of a decayed version of the old ones, which is exactly momentum. How cool is that? That's literally the only change we have to make to get momentum to work.
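For reference, this is roughly what such a class could look like: a sketch, assuming TrainLearner is a Learner subclass that implements predict/get_loss/backward/step/zero_grad directly (so no TrainCB is needed) and that 0.85 is a reasonable default momentum.

class MomentumLearner(TrainLearner):
    def __init__(self, model, dls, loss_func, lr=None, cbs=None, opt_func=optim.SGD, mom=0.85):
        self.mom = mom
        super().__init__(model, dls, loss_func, lr=lr, cbs=cbs, opt_func=opt_func)

    def zero_grad(self):
        # decay the gradients instead of zeroing them
        with torch.no_grad():
            for p in self.model.parameters(): p.grad *= self.mom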
All right, there is quite a bit of info to digest here. What I recommend doing is cloning the repo, opening up a notebook, and experimenting with miniai. The implementation is transparent and concise enough to let us edit parts of the code and immediately see what happens. Give it a go and familiarize yourself with the magic of callbacks! It's worth it!