Interpreting Random Forest and other black box models like XGBoost

In machine learning there is a recurring trade-off between performance and interpretability. Usually, the better the model, the more complex and the less understandable it is.

Generally speaking, there are two ways to interpret a model:

  1. Overall interpretation: determine which variables (or combinations of variables) have the most predictive power, and which have the least
  2. Local interpretation: for a given data point and associated prediction, determine which variables (or combinations of variables) explain this specific prediction

Depending on the type of model you use, model-specific interpretation methods may exist. For instance, Decision Tree models can be interpreted simply by plotting the tree and seeing how the splits are made and what the leaves are composed of.

However, there is no such direct way to do this with RandomForest or XGBoost, which are usually better at making predictions.

Overall interpretation

The overall interpretation already comes out of the box in most tree-based models in Python (scikit-learn, XGBoost, LightGBM), through the “feature_importances_” attribute. Example below:

Example of feature importances for a given model, what I call “feature importances table” in this article (sorted by feature importance in descending order)
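The sketch below shows how such a feature importances table can be built from any fitted model that exposes feature_importances_. It is a minimal example that assumes scikit-learn and pandas, and it uses simulated data with hypothetical variable names Var_1 to Var_6. For scikit-learn forests these importances are impurity-based and sum to 1.

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Simulated data standing in for a real dataset (hypothetical column names)
X, y = make_classification(n_samples=500, n_features=6, n_informative=3, random_state=0)
X = pd.DataFrame(X, columns=[f"Var_{i}" for i in range(1, 7)])

model = RandomForestClassifier(n_estimators=100, random_state=0).fit(X, y)

# "Feature importances table": one row per variable, sorted in descending order
importances = (
    pd.DataFrame({"variable": X.columns, "importance": model.feature_importances_})
    .sort_values("importance", ascending=False)
    .reset_index(drop=True)
)
print(importances)
```

The same code should work unchanged with xgboost.XGBClassifier, which also exposes feature_importances_.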

Interpreting this output is quite straightforward: the higher the importance, the more relevant the variable is, according to the model. This is a great way to:

  1. identify the variables with the best predictive power
  2. raise issues and correct bugs: look for variables that have far more importance than the others.
    Example: In a previous project, we worked with biased data: data from class 1 had many missing values in one variable, whereas data from class 0 did not. We didn’t realize it until we looked at the feature importances table. The model had learnt that if the value is missing, the data point belongs to class 1. We solved it by filling the missing values with values sampled from class 0.
  3. update your model with new variables. To see whether a new variable is relevant to your model, compute the feature importances both for the model before (without the new variable) and for the model after (with the new variable), then analyze the shifts the new variable produces in the feature importances table.
    Example: when doing feature engineering, you may come up with a more relevant feature, but introducing it into your data will probably reduce the importances of the features it is directly correlated with.
  4. compare different models: compare the feature importances of two different models (RandomForest vs XGBoost, for instance) to see whether each model grasps the predictive power of a given variable.
    Example: comparing XGBoost models with different depths can show that a specific variable only becomes useful at a specific depth.
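Points 3 and 4 both come down to putting two importance tables side by side. The sketch below does this as a before/after comparison. It assumes scikit-learn and pandas, and importance_table, compare_importances and the engineered variable Var_new are hypothetical names used only for illustration.

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier


def importance_table(model, columns):
    """Feature importances of a fitted model, as a Series indexed by variable name."""
    return pd.Series(model.feature_importances_, index=columns)


def compare_importances(before, after):
    """Side-by-side importances; 'shift' is after - before (NaN for new variables)."""
    table = pd.concat({"before": before, "after": after}, axis=1)
    table["shift"] = table["after"] - table["before"]
    return table.sort_values("after", ascending=False)


X, y = make_classification(n_samples=500, n_features=5, n_informative=3, random_state=0)
X = pd.DataFrame(X, columns=[f"Var_{i}" for i in range(1, 6)])

# Hypothetical engineered variable, directly correlated with Var_1 and Var_2
X_new = X.assign(Var_new=X["Var_1"] + X["Var_2"])

rf_before = RandomForestClassifier(n_estimators=100, random_state=0).fit(X, y)
rf_after = RandomForestClassifier(n_estimators=100, random_state=0).fit(X_new, y)

comparison = compare_importances(
    importance_table(rf_before, X.columns),
    importance_table(rf_after, X_new.columns),
)
print(comparison)
```

To compare two different models on the same data (point 4), pass the two importance tables to compare_importances in the same way.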

So far so good for an overall understanding of the model.

Now how do you explain a prediction for a given data point?

Local interpretation

Here I will define what local interpretation is and propose a methodology with a workaround to do it easily with any model you have.

How to define local interpretation?

What inspired me here is a DataRobot demo on predicting loan defaults. In their demo, for each individual prediction, they also output the 3 variables that increased the probability of default the most, and the 3 variables that decreased it the most.

Let’s keep this example (with simulated data) and, for readability, represent local interpretation with only 3 variables, as shown below:

Illustration of local interpretation: for each data point, we identify the 3 variables with the most impact on the prediction of default. Variable Var_1 increases the probability of default in the first 2 predictions (resp. +34% and +25%), but decreases it in the 3rd (-12%)
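The sketch below produces this kind of output. It assumes you already have a matrix of signed contributions to the probability of default, with one row per prediction and one column per variable (for instance from treeinterpreter). top_k_contributors is a hypothetical helper, it ranks variables by absolute contribution, and the numbers are made up for illustration.

```python
import numpy as np
import pandas as pd


def top_k_contributors(contributions, feature_names, k=3):
    """For each prediction, the k variables with the largest absolute contribution.

    contributions: array (n_samples, n_features) of signed contributions to the
    probability of default (positive increases it, negative decreases it).
    """
    contributions = np.asarray(contributions, dtype=float)
    names = np.asarray(feature_names)
    # Per row, variable indices sorted by decreasing |contribution|
    order = np.argsort(-np.abs(contributions), axis=1, kind="stable")[:, :k]
    rows = np.arange(contributions.shape[0])
    columns = {}
    for r in range(k):
        columns[f"var_{r + 1}"] = names[order[:, r]]
        columns[f"contrib_{r + 1}"] = contributions[rows, order[:, r]]
    return pd.DataFrame(columns)


# Made-up contributions for 3 predictions and 4 variables
names = ["Var_1", "Var_2", "Var_3", "Var_4"]
contribs = [
    [0.34, -0.05, 0.10, 0.02],
    [0.25, 0.08, -0.15, 0.01],
    [-0.12, 0.30, 0.04, -0.20],
]
top3 = top_k_contributors(contribs, names)
print(top3)
```

To get the DataRobot-style "top 3 increasing / top 3 decreasing" split instead, sort by the signed contribution in both directions rather than by absolute value.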

The most interesting takeaways from this interpretation

Interpreting each individual prediction can be used to:

  1. understand the reasons behind the prediction for an individual case.
    Example: two individuals can have a high probability of default but for completely different reasons (i.e. different variables)
  2. understand the most frequent reasons behind the predictions in a filtered population.
    Example: among all individuals with a probability of default above 50%, which variables most frequently explain the default?

For all predictions of default (probability > 50%), we rank the variables by how often they appear in the top 3, to understand which variables best explain the default. Variable Var_1 is the most contributing variable for the prediction of default in 43 cases, the second in 15 cases, and the third in 12 cases.
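The counting described in this caption can be sketched in pandas as follows. The sketch assumes a signed contribution matrix and the predicted probabilities of default. rank_frequencies is a hypothetical helper, and it ranks variables by their signed contribution towards default; ranking by absolute value would be the other natural choice. The data is made up.

```python
import numpy as np
import pandas as pd


def rank_frequencies(contributions, proba_default, feature_names, threshold=0.5, k=3):
    """Count how often each variable is the 1st, 2nd, ..., k-th contributor to
    default, among predictions with a probability of default above `threshold`."""
    contributions = np.asarray(contributions, dtype=float)
    selected = contributions[np.asarray(proba_default) > threshold]
    names = np.asarray(feature_names)
    # Per row, variable indices sorted by decreasing (signed) contribution
    order = np.argsort(-selected, axis=1, kind="stable")[:, :k]
    table = pd.DataFrame(
        {f"rank_{r + 1}": pd.Series(names[order[:, r]]).value_counts() for r in range(k)}
    )
    # Variables that never reach the top k get a count of 0
    table = table.reindex(names).fillna(0).astype(int)
    return table.sort_values(list(table.columns), ascending=False)


names = ["Var_1", "Var_2", "Var_3"]
contribs = [
    [0.34, -0.05, 0.10],
    [0.25, 0.08, -0.15],
    [-0.12, 0.30, 0.04],
    [0.20, 0.01, 0.02],  # filtered out: probability of default is 0.3
]
proba = [0.8, 0.7, 0.6, 0.3]
freq = rank_frequencies(contribs, proba, names)
print(freq)
```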

Implementation in Python

The treeinterpreter library in Python computes the exact contribution of each feature to each prediction of a Random Forest model. I invite the curious reader to check the two excellent articles (1 and 2) by the author of the package.

For other models, we will use a quick-and-dirty solution: train a Random Forest model, and do local interpretations only where the predictions of your model and of the Random Forest model match (when both predict default, or both predict non-default). This is the solution I chose in a client project where I had an XGBoost model. In that case, the local interpretations from the Random Forest made a lot of sense, but it is still a frustrating workaround, for lack of a dedicated framework for XGBoost.
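The agreement filter can be sketched as follows, assuming scikit-learn. To keep the sketch free of extra dependencies, scikit-learn's GradientBoostingClassifier stands in for the XGBoost model; with xgboost installed, an XGBClassifier could take its place. The variable names are illustrative.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1000, n_features=8, n_informative=4, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Stand-in for the model to explain (the article's project used XGBoost)
main_model = GradientBoostingClassifier(random_state=0).fit(X_train, y_train)

# Surrogate Random Forest, used only for local interpretation
surrogate_rf = RandomForestClassifier(n_estimators=200, random_state=0).fit(X_train, y_train)

main_pred = main_model.predict(X_test)
rf_pred = surrogate_rf.predict(X_test)

# Keep only the points where both models predict the same class
agree = main_pred == rf_pred
X_to_explain = X_test[agree]
print(f"{agree.mean():.1%} of test predictions agree; {agree.sum()} points kept")
```

predict uses each model's default decision rule (the most probable class). If your model uses a custom default threshold, compare thresholded predict_proba outputs instead.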

As it takes some time to compute (depending on the number of trees in the Random Forest model), I recommend using a subset of your predictions for this exercise. For instance, the 10 individuals most likely to default according to the model.
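Selecting that subset takes only a few lines of NumPy. most_likely_to_default is a hypothetical helper. Its optional mask can restrict the choice to, for example, points where the main model and the surrogate Random Forest agree.

```python
import numpy as np


def most_likely_to_default(proba_default, k=10, mask=None):
    """Indices of the k rows with the highest probability of default.

    mask: optional boolean array; rows where it is False are never selected.
    """
    proba_default = np.asarray(proba_default, dtype=float)
    candidates = np.arange(len(proba_default))
    if mask is not None:
        candidates = candidates[np.asarray(mask, dtype=bool)]
    # Sort the candidates by decreasing probability and keep the first k
    order = np.argsort(-proba_default[candidates], kind="stable")
    return candidates[order[:k]]


proba = [0.1, 0.9, 0.4, 0.95, 0.7]
print(most_likely_to_default(proba, k=2))  # [3 1]
print(most_likely_to_default(proba, k=2, mask=[True, False, True, True, True]))  # [3 4]
```

With real models, proba_default would typically be main_model.predict_proba(X)[:, 1] (hypothetical names).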

For each individual prediction, we compute the individual contribution of each variable to the prediction with the treeinterpreter package.

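Note that treeinterpreter returns both the variable contributions and an overall bias: for each prediction, the predicted value equals the bias (the value at the root of the trees, before any split) plus the sum of the variable contributions. For a more in-depth understanding, I recommend the original blog post.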
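With treeinterpreter installed, the call is `prediction, bias, contributions = ti.predict(rf, X)` after `from treeinterpreter import treeinterpreter as ti`. The sketch below re-implements the same decomposition for a scikit-learn RandomForestClassifier, to show what is being computed:

- walk each tree from the root to the leaf;
- credit every change in class proportions to the variable used in that split;
- average over the trees.

forest_contributions is a hypothetical helper, not part of any library. It reads scikit-learn's fitted tree internals (tree_.value, tree_.children_left, tree_.children_right, tree_.feature, tree_.threshold).

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier


def forest_contributions(forest, X):
    """Decompose forest.predict_proba(X) into a bias plus per-variable contributions.

    Returns bias (n_classes,) and contributions (n_samples, n_features, n_classes),
    such that bias + contributions[i].sum(axis=0) == forest.predict_proba(X)[i].
    """
    X = np.asarray(X, dtype=np.float32)  # scikit-learn trees split on float32 values
    n_samples, n_features = X.shape
    bias = np.zeros(forest.n_classes_)
    contributions = np.zeros((n_samples, n_features, forest.n_classes_))
    for estimator in forest.estimators_:
        tree = estimator.tree_
        left, right = tree.children_left, tree.children_right
        feature, threshold = tree.feature, tree.threshold
        values = tree.value[:, 0, :]  # single-output classifier
        probas = values / values.sum(axis=1, keepdims=True)  # class proportions per node
        bias += probas[0]  # root node: this tree's prediction before any split
        for i in range(n_samples):
            node = 0
            while left[node] != -1:  # -1 marks a leaf
                f = feature[node]
                child = left[node] if X[i, f] <= threshold[node] else right[node]
                # Credit the change in class proportions to the split variable
                contributions[i, f] += probas[child] - probas[node]
                node = child
    n_trees = len(forest.estimators_)
    return bias / n_trees, contributions / n_trees


X, y = make_classification(n_samples=500, n_features=6, n_informative=3, random_state=0)
rf = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

X_explain = X[:10]  # a small subset of rows, since the loop is slow
bias, contributions = forest_contributions(rf, X_explain)
# Contribution of each variable to the probability of class 1, first row
print(contributions[0, :, 1])
```

Along each path the differences telescope from the root value to the leaf value. This is why bias plus the summed contributions reproduces predict_proba exactly, up to floating-point error.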

One thing to keep in mind

Imagine you have a variable “age”, and its contribution is high enough to put it in the top 3 contributors to a specific prediction. You would probably want to know the actual age, because you (or a domain expert) would interpret an age of 18 differently from an age of 35. So it’s a good habit to look at the contribution of a variable along with its value (for instance, by reporting two columns side by side, ‘value_variable’ and ‘contribution_variable’).
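The sketch below builds such a value-plus-contribution table for a single prediction, assuming pandas. explain_row is a hypothetical helper, and the variable names, values and contributions are made up.

```python
import numpy as np
import pandas as pd


def explain_row(x_row, contributions_row, feature_names, k=3):
    """Top-k contributors for one prediction, each shown with the variable's value."""
    table = pd.DataFrame(
        {
            "variable": feature_names,
            "value_variable": np.asarray(x_row),
            "contribution_variable": np.asarray(contributions_row, dtype=float),
        }
    )
    # Rank by absolute contribution and keep the k most impactful variables
    table = table.sort_values(
        "contribution_variable", key=lambda s: s.abs(), ascending=False, kind="stable"
    )
    return table.head(k).reset_index(drop=True)


names = ["age", "income", "n_loans", "seniority"]
explanation = explain_row([18, 1200, 3, 1], [0.21, 0.09, -0.02, 0.12], names)
print(explanation)
```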

To go beyond treeinterpreter

Interpreting black-box models has been the subject of many research papers and still is, especially when it comes to deep learning interpretation. Different methods have been tested and adopted: LIME, partial dependence plots, defragTrees.
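Of the methods listed, partial dependence is the simplest to sketch. For each grid value of one variable, set that variable to the value for every row, predict, and average the predicted probability. partial_dependence_1d is a hypothetical helper; scikit-learn ships this functionality as sklearn.inspection.partial_dependence, with PartialDependenceDisplay for plotting.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier


def partial_dependence_1d(model, X, feature, grid):
    """Average predicted probability of class 1 when `feature` is forced to each grid value."""
    X = np.asarray(X, dtype=float)
    averages = []
    for value in grid:
        X_mod = X.copy()
        X_mod[:, feature] = value  # same value for every row, other variables unchanged
        averages.append(model.predict_proba(X_mod)[:, 1].mean())
    return np.array(averages)


X, y = make_classification(n_samples=300, n_features=4, n_informative=2, random_state=0)
rf = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)
grid = np.linspace(X[:, 0].min(), X[:, 0].max(), num=10)
pd_curve = partial_dependence_1d(rf, X, feature=0, grid=grid)
print(np.round(pd_curve, 3))
```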

For treeinterpreter, it would be great to support other tree-based models, such as XGBoost, LightGBM, CatBoost, or other gradient boosting methods.

I recently stumbled upon a great article on interpretation of random forest models: https://medium.com/usf-msds/intuitive-interpretation-of-random-forest-2238687cae45, with waterfall plots and partial dependence plots.

Thanks for reading!
