Had you asked me to explain the output of a Random Forest model six months ago, without thinking twice, I would have replied the following: “This is impossible. How do you suggest figuring out why an average of N trees, each trained on bagged data and a random selection of features, decided the way it did? These kinds of algorithms are clearly not meant to be interpretable. They are meant to be accurate. If you want something interpretable, go for a linear model.”
Fortunately, I was terribly wrong.
The newly educated me would reply, instead: “This is totally feasible. A decision tree is one of the most interpretable ML algorithms out there. Its learning process nicely mimics the way the human brain parses information, via IF-ELSE scenarios. A Random Forest is nothing more than an average of many of these IF-ELSE decisions, so it is relatively easy to understand what is going on”.
The one resource which opened my eyes was the fast.ai Introduction to Machine Learning for Coders course. If you are expecting it to cover the usual range of statistical learning algorithms, from Linear Regression to KNN, strolling through SVM, then be ready to be disappointed (or not). The first half of the course is devoted to Random Forest, the second half to Neural Networks. These two approaches, according to Jeremy Howard, are the only ones any ML practitioner would ever need to get the job done.

Setting Deep Learning aside, watching seven 2-hour-long lectures just on Random Forest was initially a big surprise, and then completely mind-blowing. I was amazed by how little I knew about a method I thought I was very familiar with, and by how powerful it is. The main idea Jeremy wants to evangelize is that, even though it is true that linear models are easy to interpret, in most cases they are still fundamentally useless. That’s because linear relationships are (almost) never true, hence we end up drawing conclusions from coefficients which tell us pretty much nothing about what is really going on.

As a Data Scientist, I am a strong advocate of linear algorithms. They train in no time and they provide good enough answers in a lot of situations. It is also true, though, that they suffer from very high bias, and that usually represents a big problem. It would be nice to have another approach with the same level of interpretability but more flexible than a straight line. Tree ensembles are what we need! As soon as I started exploring this path I was completely sold on the idea.
Below I will cover the main points I learned from this journey, things I was simply not aware of before the summer! First, though, let me share the additional resources I found useful. In no specific order, here they are:
- This blog post, which demystifies ensemble methods with amazing visualizations.
- Interpretable Machine Learning: this is an easy-to-read and super clear online book diving into the details of how to interpret ML algorithms (whenever possible).
- Ando Saabas’ blog. This is a must. For those who don’t know him, Ando is a Senior Data Scientist at Taxify (yes, I have the immense honor of working with him!) and the author of the treeinterpreter python package (whose underlying math and logic are actually explained in his blog). This was the first package addressing the broad issue of opening the ML black box. It still represents a reference in the field and its approach underpins other packages such as eli5.
As I usually do when I want to experiment with something new, I grabbed a dataset from Kaggle and played around with it. In this case, I went for the House Prices one, coming from a playground competition in which Kagglers are challenged to predict the price of a real estate property given its attributes. Here is what I learned from this exercise.
(Link to the Jupyter Notebook)
eli5 and feature PermutationImportance
When using Tree Ensemble methods in scikit-learn, the feature importance attribute we can extract from the model instance is calculated by looking at how much, on average, a split on a specific variable reduces a metric of interest; in the case of regressors, the MSE. Another way of checking the impact of a feature is to shuffle it, predict on the broken dataset and note down the decrease in accuracy: the higher the loss, the higher the importance. It turns out eli5 supports exactly this approach, via its PermutationImportance class. As the routine runs the random shuffling N times, it also comes with a nice distribution of “importances”, which can be conveniently plotted. Here is one of the charts I generated during this exercise.
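For reference, here is a minimal sketch of how this can be done, assuming `model` is a Random Forest already fitted on the housing data and `X_val`, `y_val` are a held-out validation split:

```python
import eli5
from eli5.sklearn import PermutationImportance

# shuffle each column 10 times on held-out data and record the score drop;
# model is assumed to be an already fitted RandomForestRegressor
perm = PermutationImportance(model, n_iter=10, random_state=42).fit(X_val, y_val)

# in a Jupyter Notebook, this renders a table of importances with std deviations
eli5.show_weights(perm, feature_names=X_val.columns.tolist())
```

The raw numbers are also available via `perm.feature_importances_` and `perm.feature_importances_std_`, which is what I used to build the distribution plots.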
Removing redundant features via clustering columns
Simplify whenever possible. This is a golden rule which always applies, and ML models are no exception. It is no doubt preferable to train models on fewer variables, if that comes at a reasonable cost in accuracy. Lighter models are faster to train, consume less memory and are easier to interpret, as we are extracting signal from a condensed set of variables. The way I usually deal with reducing the feature space is simply building a correlation matrix and trashing variables with a correlation higher than an arbitrary threshold. I might go with PCA if model interpretation is not relevant, or simply ditch any variable with a tree importance lower than 0.05, or some value low compared to the top ones. Another interesting approach is to apply clustering. This technique is generally employed to find patterns across rows, i.e. to locate points resembling each other according to their features; it is rarely used to cluster the features themselves. The results are pretty interesting though. Specifically, hierarchical clustering builds dendrograms offering a concise visualization of the feature space. Variables close to each other in terms of branch splits are “similar”, hence most likely carrying redundant information. Here is how the dendrogram looks for the housing dataset.
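A minimal sketch of this trick, along the lines of what the fast.ai course does, assuming `X` is the DataFrame of features: compute a Spearman rank correlation matrix, turn it into a distance matrix and feed it to scipy’s hierarchical clustering:

```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.cluster import hierarchy
from scipy.spatial.distance import squareform
from scipy.stats import spearmanr

# rank correlation between columns (X is assumed to hold the features)
corr = np.round(spearmanr(X).correlation, 4)

# convert similarity into a condensed distance matrix and cluster the features
z = hierarchy.linkage(squareform(1 - corr), method="average")

plt.figure(figsize=(12, 8))
hierarchy.dendrogram(z, labels=X.columns.tolist(), orientation="left")
plt.show()
```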
Feature contributions to model predictions, aka treeinterpreter’s magic!
This is, by far, the most important piece of ML-related knowledge I recently acquired. I won’t try to go deep into the details of what happens behind the scenes here. Ando Saabas does an incredible job on his blog, both here and here. He invented the method, so no surprise! I will just recap at a high level what these figures mean and show an example from the housing dataset.
Disclaimer: for the sake of simplicity I will refer to a regression problem and to a single tree with a max depth of 3. The extension to classification is trivial. As for a Forest, it is just a matter of averaging results from multiple trees. No big deal.
The idea of associating a numeric contribution with a specific variable comes from noticing how the average value of the dependent variable changes as we follow a path down the tree. Let’s see what this means.
- At the root, before splitting anything at all, the best guess for the model is the straight mean of the output. Let’s say this value is +100.
- Leaving the root, we encounter the first split, say on feature X1. We pick the right branch and fall into a region where the mean of the dependent variable is +80. There is no harm in claiming that X1 contributed -20 to the output.
- We go ahead. Now the tree splits on X2 and, taking the right branch, lands us on an average of +90. Good: X2 contributed a +10 increase.
- Then it is X3’s turn, whose left branch decreases the output’s mean to +60. X3 provided a further contribution of -30.
- The final prediction, at this leaf, is +60.
To recap, the prediction can be broken into
+100 [Y average] + (-20) [from X1] + (+10) [from X2] + (-30) [from X3] = +60 [final prediction]
This is extremely powerful as we get two valuable outputs:
- Row-level feature contributions, i.e. we can literally answer the question “why does THAT specific house have THAT specific price?” A waterfall chart is particularly well suited here, as it nicely visualizes the bridge between the dataset average and the final price. Below is an example from the dataset. I also recommend checking out this python package, which makes drawing waterfall charts trivial.
- Average feature contributions, i.e. “across the entire dataset, how is X1 affecting price compared to average?” Again, you can find a pretty self-explanatory visualization of this point, across train and validation sets, below.
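For completeness, here is a minimal sketch of how to extract these decompositions with treeinterpreter, assuming `model` is the fitted forest and `X_val` the validation features:

```python
from treeinterpreter import treeinterpreter as ti

# each prediction gets decomposed as bias (train-set mean) + per-feature contributions
prediction, bias, contributions = ti.predict(model, X_val.values)

# sanity check on the first row: the pieces add back up to the prediction
row = 0
print(bias[row] + contributions[row].sum(), prediction[row])

# feature contributions for that single house, sorted by absolute impact
for name, contrib in sorted(zip(X_val.columns, contributions[row]),
                            key=lambda t: -abs(t[1])):
    print(f"{name}: {float(contrib):+.2f}")
```

The per-row arrays are exactly what feeds the waterfall chart; averaging `contributions` over all rows gives the dataset-level view.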
Partial dependence plots, aka stop looking at univariate relationships!
After treeinterpreter, this is for sure the second most important learning. The question we have is: “how does feature X1 affect the dependent variable? Are they positively correlated? Maybe the relationship is fairly linear, but only in certain X1 ranges, and becomes quadratic after a certain threshold?”. One might be tempted to draw a scatter plot of price (our dependent variable) versus X1 and answer all the previous questions in that way. DON’T DO THIS! Relying on univariate plots is very risky, as you are not accounting for the effect of all the other variables. How do you know that the curve you are looking at is, in fact, not affected by interactions with other features? The truth is you don’t.
Partial Dependence Plots (PDP) to the rescue. Before doing anything else, train a Random Forest on your dataset. Now, let’s go back to our feature X1. Suppose it is a numeric variable which takes only 5 distinct values in the data: literally 1, 2, 3, 4 and 5. Easy. What we want to know is the relationship between X1 and price, all other variables being equal. This detail is of paramount importance, as it means we are accounting for the interactions with the other columns as well. To do this we proceed in the following way (in Python):
```python
import numpy as np
import matplotlib.pyplot as plt

X1_values = [1, 2, 3, 4, 5]
prices_at_different_X1_all_rest_equal = []
for X1_value in X1_values:
    modified = data.copy()
    modified["X1"] = X1_value              # replace the X1 column with a single value
    new_price = model.predict(modified)    # predict on the modified dataset
    prices_at_different_X1_all_rest_equal.append(np.mean(new_price))

plt.plot(X1_values, prices_at_different_X1_all_rest_equal)
plt.xlabel("X1"); plt.ylabel("average predicted price")
```
Basically, we simulate house prices in different scenarios, fixing the entire dataset except for the column we want to measure the sensitivity for. Neat.
The plot we get allows us to safely answer the above questions, as what we are looking at is how price changes with X1, all the rest being equal. Below is an example from the housing dataset, generated with the pdpbox library. I am adding a two-way interaction plot too, constructed following the same principles, but in two dimensions.
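As a reference, a minimal sketch using the classic pdpbox API (versions up to 0.2; newer releases changed the interface), assuming `df_val` holds the validation data, `feats` is the list of model features, and picking two actual columns from the House Prices dataset:

```python
from pdpbox import pdp

# one-dimensional partial dependence of price on above-ground living area
pdp_iso = pdp.pdp_isolate(model=model, dataset=df_val,
                          model_features=feats, feature='GrLivArea')
pdp.pdp_plot(pdp_iso, 'GrLivArea')

# two-way interaction between living area and overall quality
pdp_inter = pdp.pdp_interact(model=model, dataset=df_val, model_features=feats,
                             features=['GrLivArea', 'OverallQual'])
pdp.pdp_interact_plot(pdp_inter, ['GrLivArea', 'OverallQual'])
```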