4 Easiest Ways to Visualize Decision Trees Using Scikit-Learn and Python – 2024

So guys, in this blog we will see how to visualize Decision Trees using Scikit-Learn in Python, so that we can actually see how a Decision Tree makes its decisions. So without further ado, let's do it…

A quick overview of Decision Trees

  • A Decision Tree is a supervised machine learning algorithm, which means it requires both features and targets for training.
  • Decision Trees can be used for both classification and regression tasks.
  • During training, it builds a binary tree in which every internal (non-leaf) node has 2 children: the left subtree is followed when the node's condition is True, and the right subtree is followed when it is False.
  • In this blog, we will see 4 ways to visualize these trees.
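To make the left-is-True, right-is-False convention concrete, here is a toy sketch in plain Python. The `Node` class and `predict_one` function are hypothetical illustrations, not how scikit-learn stores its trees, and the two thresholds are borrowed from the tree printed later in this post (feature 2 is petal length, feature 3 is petal width).

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class Node:
    """Hypothetical toy node of a binary decision tree (not scikit-learn's internals)."""
    feature: Optional[int] = None      # index of the feature tested at this node
    threshold: Optional[float] = None  # go left if x[feature] <= threshold
    left: Optional["Node"] = None      # subtree followed when the condition is True
    right: Optional["Node"] = None     # subtree followed when the condition is False
    label: Optional[int] = None        # predicted class, set only on leaves


def predict_one(node: Node, x: Sequence[float]) -> int:
    # Walk down until we reach a leaf (a node without children).
    while node.left is not None and node.right is not None:
        node = node.left if x[node.feature] <= node.threshold else node.right
    return node.label


# Illustrative two-level tree; thresholds taken from the tree printed later in the post.
toy_tree = Node(
    feature=2, threshold=2.45,      # petal length (cm) <= 2.45 ?
    left=Node(label=0),             # True  -> class 0
    right=Node(
        feature=3, threshold=1.75,  # petal width (cm) <= 1.75 ?
        left=Node(label=1),         # True  -> class 1
        right=Node(label=2),        # False -> class 2
    ),
)

print(predict_one(toy_tree, [5.1, 3.5, 1.4, 0.2]))  # 0
print(predict_one(toy_tree, [6.4, 3.2, 4.5, 1.5]))  # 1
print(predict_one(toy_tree, [6.3, 3.3, 6.0, 2.5]))  # 2
```

Each prediction is just a walk from the root to a leaf, which is exactly what the visualizations below let us follow by eye.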

Step 1 – Training a basic Decision Tree

from matplotlib import pyplot as plt
from sklearn import datasets, tree
from sklearn.tree import DecisionTreeClassifier

# Prepare the data
iris = datasets.load_iris()
X = iris.data    # 150 samples x 4 features (sepal/petal length and width)
y = iris.target  # class labels 0, 1, 2

# Fit the classifier with default hyper-parameters (random_state fixed for reproducibility)
clf = DecisionTreeClassifier(random_state=1234)
model = clf.fit(X, y)
  • Here we are simply loading the Iris data from sklearn.datasets and training a very simple Decision Tree, which we will visualize in the following steps.
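Before drawing anything, we can already follow a sample through the fitted tree. A minimal sketch, assuming only the arrays scikit-learn exposes on `clf.tree_` (`children_left`, `children_right`, `feature`, `threshold`), walks one sample from the root to its leaf using the left-is-True rule; the helper name `trace_decision_path` is hypothetical. Comparing its result with `clf.apply`, which returns the leaf index of each sample, checks that the manual walk matches the library.

```python
import numpy as np
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier

iris = datasets.load_iris()
X, y = iris.data, iris.target
clf = DecisionTreeClassifier(random_state=1234).fit(X, y)


def trace_decision_path(fitted_clf, x):
    """Hypothetical helper: follow one sample from the root to its leaf.

    Returns the leaf node id and a list of (node, feature, threshold, went_left) steps.
    """
    t = fitted_clf.tree_
    # scikit-learn compares features as float32, so cast to match its arithmetic
    x = np.asarray(x, dtype=np.float32)
    node, steps = 0, []
    while t.children_left[node] != -1:  # -1 marks a leaf (no children)
        feature, threshold = t.feature[node], t.threshold[node]
        went_left = bool(x[feature] <= threshold)  # condition True -> left child
        steps.append((int(node), int(feature), float(threshold), went_left))
        node = t.children_left[node] if went_left else t.children_right[node]
    return int(node), steps


leaf, steps = trace_decision_path(clf, X[0])
for node, feature, threshold, went_left in steps:
    print(f"node {node}: {iris.feature_names[feature]} <= {threshold:.2f} -> {went_left}")
print("leaf:", leaf, "| depth:", clf.get_depth(), "| leaves:", clf.get_n_leaves())
```

`get_depth()` and `get_n_leaves()` are also a quick way to judge whether a tree is small enough to read as text.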

Step 2 – Types of Tree Visualizations

We can visualize the Decision Tree in the following 4 ways:

  1. Printing Text Representation of the tree.
  2. Plot Tree with plot_tree.
  3. Visualize the Decision Tree with graphviz.
  4. Plot Decision Tree with dtreeviz Package.

Let’s visualize Decision trees…

1. Text Representation of the tree

  • The first visualization is the text representation, which, as the name suggests, prints the Decision Tree as text.
  • This type of visualization should not be used for trees deeper than 4-5 levels, because the output quickly becomes very difficult to interpret.
  • This representation is handy when we want to write the tree to logs.
  • Read more about the export_text method.
text_representation = tree.export_text(clf)
print(text_representation)
|--- feature_2 <= 2.45
|   |--- class: 0
|--- feature_2 >  2.45
|   |--- feature_3 <= 1.75
|   |   |--- feature_2 <= 4.95
|   |   |   |--- feature_3 <= 1.65
|   |   |   |   |--- class: 1
|   |   |   |--- feature_3 >  1.65
|   |   |   |   |--- class: 2
|   |   |--- feature_2 >  4.95
|   |   |   |--- feature_3 <= 1.55
|   |   |   |   |--- class: 2
|   |   |   |--- feature_3 >  1.55
|   |   |   |   |--- feature_0 <= 6.95
|   |   |   |   |   |--- class: 1
|   |   |   |   |--- feature_0 >  6.95
|   |   |   |   |   |--- class: 2
|   |--- feature_3 >  1.75
|   |   |--- feature_2 <= 4.85
|   |   |   |--- feature_1 <= 3.10
|   |   |   |   |--- class: 2
|   |   |   |--- feature_1 >  3.10
|   |   |   |   |--- class: 1
|   |   |--- feature_2 >  4.85
|   |   |   |--- class: 2
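Since deep trees are hard to read as text, export_text can print only the top levels. A minimal sketch, assuming a scikit-learn version whose export_text accepts the `feature_names`, `max_depth`, and `show_weights` keyword arguments, retrains the same tree and prints its first 2 levels with the real feature names instead of feature_0 to feature_3; branches below the cut-off are summarized rather than printed in full.

```python
from sklearn import datasets, tree
from sklearn.tree import DecisionTreeClassifier

iris = datasets.load_iris()
clf = DecisionTreeClassifier(random_state=1234).fit(iris.data, iris.target)

# Real feature names instead of feature_0..feature_3, and only the top 2 levels;
# deeper branches are summarized instead of printed in full.
short_text = tree.export_text(
    clf,
    feature_names=list(iris.feature_names),
    max_depth=2,
    show_weights=True,  # also print the per-class sample weights at each leaf
)
print(short_text)
```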

Save the Text Representation of the tree…

with open("decision_tree.log", "w") as fout:
    fout.write(text_representation)
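When the tree text should go into an application log rather than a standalone file, Python's standard logging module can carry it. This is a sketch; the logger name `decision_tree_training` and the file `training.log` are hypothetical choices, and a project with an existing logging setup would simply reuse its own logger.

```python
import logging

from sklearn import datasets, tree
from sklearn.tree import DecisionTreeClassifier

iris = datasets.load_iris()
clf = DecisionTreeClassifier(random_state=1234).fit(iris.data, iris.target)
text_representation = tree.export_text(clf, feature_names=list(iris.feature_names))

# Hypothetical logger and file names, for illustration only.
logger = logging.getLogger("decision_tree_training")
logger.setLevel(logging.INFO)
handler = logging.FileHandler("training.log", mode="w", encoding="utf-8")
handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
logger.addHandler(handler)

# %-style arguments: the message is only formatted if the record is actually emitted.
logger.info("Trained decision tree (depth=%d):\n%s", clf.get_depth(), text_representation)

# Flush and detach the file handler once we are done.
handler.close()
logger.removeHandler(handler)
```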

2. Plot Tree with plot_tree

  • The plot_tree method uses matplotlib under the hood to create these tree visualizations of Decision Trees.
  • The only required argument is the fitted classifier; passing the feature names and class (target) names makes the plot much easier to read.
  • Each node shows the split condition (for non-leaf nodes), the Gini impurity (or entropy, depending on the criterion), the number of samples reaching the node, the class distribution of those samples (value), and the majority class.
  • Read more about the plot_tree method.
fig = plt.figure(figsize=(25,20))
_ = tree.plot_tree(clf, 
                   feature_names=iris.feature_names,  
                   class_names=iris.target_names,
                   filled=True)
[Figure: decision tree visualized with plot_tree]
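To see where the gini number in each box comes from, here is a worked check of the root node. Gini impurity is 1 minus the sum of the squared class proportions. The Iris training set has 50 samples per class, so at the root each proportion is 1/3 and Gini = 1 - 3 * (1/3)^2 = 2/3, about 0.667. The sketch reads `clf.tree_.value`, `clf.tree_.impurity`, and `clf.tree_.n_node_samples`; because, depending on the scikit-learn version, `value` may hold class counts or class fractions, it normalizes before applying the formula.

```python
import numpy as np
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier

iris = datasets.load_iris()
clf = DecisionTreeClassifier(random_state=1234).fit(iris.data, iris.target)

t = clf.tree_
root = 0
# value[node][0] is the per-class distribution at that node (counts or fractions,
# depending on the scikit-learn version); normalizing gives proportions either way.
class_dist = t.value[root][0]
p = class_dist / class_dist.sum()
gini_manual = 1.0 - np.sum(p ** 2)

print("samples:", t.n_node_samples[root])  # training rows reaching the root
print("gini (manual):", round(gini_manual, 4))
print("gini (sklearn):", round(t.impurity[root], 4))
```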

Save the Tree Representation of the plot_tree method…

fig.savefig("decision_tree.png")
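For trees deeper than this Iris example, drawing every node produces an unreadable figure. A minimal sketch, assuming a headless environment (hence the non-interactive Agg backend) and using plot_tree's `max_depth`, `rounded`, and `fontsize` options, draws only the top 2 levels and saves the PNG to an in-memory buffer; passing a filename to `savefig` instead writes it to disk.

```python
import io

import matplotlib

matplotlib.use("Agg")  # non-interactive backend: render without a display (e.g. on a server)
from matplotlib import pyplot as plt
from sklearn import datasets, tree
from sklearn.tree import DecisionTreeClassifier

iris = datasets.load_iris()
clf = DecisionTreeClassifier(random_state=1234).fit(iris.data, iris.target)

fig, ax = plt.subplots(figsize=(12, 6))
tree.plot_tree(
    clf,
    max_depth=2,                          # draw only the top 2 levels
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),
    filled=True,                          # color nodes by majority class
    rounded=True,                         # rounded node boxes
    fontsize=10,
    ax=ax,
)

buf = io.BytesIO()
fig.savefig(buf, format="png", dpi=100, bbox_inches="tight")
plt.close(fig)  # free the figure's memory
print(len(buf.getvalue()), "bytes of PNG")
```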

3. Visualize the Decision Tree with Graphviz

  • graphviz also helps create appealing visualizations of Decision Trees.
  • To plot or save the tree, we first need to export it to the DOT format with the export_graphviz method.
  • Read more about the export_graphviz method.
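import graphviz  # Python bindings; the Graphviz software itself must also be installed to render

# Export the tree as a DOT-format string
dot_data = tree.export_graphviz(clf, out_file=None, 
                                feature_names=iris.feature_names,  
                                class_names=iris.target_names,
                                filled=True)

# Draw graph
graph = graphviz.Source(dot_data, format="png") 
graph  # in a Jupyter notebook, this displays the tree inline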
[Figure: decision tree visualized with graphviz]

Save the Tree Representation of the graphviz method…

graph.render("decision_tree_graphviz")  # writes the DOT source file and decision_tree_graphviz.png
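If the graphviz Python package is not available, export_graphviz can also write the DOT file directly when a path is passed as `out_file`. A sketch, assuming the output name `decision_tree.dot` (a hypothetical choice); the file can then be rendered with the Graphviz command-line tool, for example `dot -Tpng decision_tree.dot -o decision_tree.png`, on any machine where Graphviz is installed.

```python
from sklearn import datasets, tree
from sklearn.tree import DecisionTreeClassifier

iris = datasets.load_iris()
clf = DecisionTreeClassifier(random_state=1234).fit(iris.data, iris.target)

# Write the DOT description straight to a file instead of returning a string.
tree.export_graphviz(
    clf,
    out_file="decision_tree.dot",  # hypothetical output path
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    filled=True,
    rounded=True,                  # rounded node boxes
)

with open("decision_tree.dot") as f:
    dot_source = f.read()
print(dot_source.splitlines()[0])  # first line of the DOT graph definition
```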

4. Plot Decision Tree with dtreeviz Package

  • The 4th and last method to plot Decision Trees uses the dtreeviz package.
  • Just provide the classifier, features, targets, feature names, and class names to generate the tree.
  • This visualization looks different from the previous ones: decision nodes show histograms of the training samples for the split feature, colored by class, and leaf nodes show the class distribution as pie charts.
  • Read more about the dtreeviz method.
from dtreeviz.trees import dtreeviz  # dtreeviz 1.x-style import (pip install dtreeviz); dtreeviz 2.x introduced dtreeviz.model(...)

viz = dtreeviz(clf, X, y,
                target_name="target",
                feature_names=iris.feature_names,
                class_names=list(iris.target_names))

viz
[Figure: decision tree visualized with dtreeviz]

Save the Tree Representation of the dtreeviz method…

viz.save("decision_tree.svg")

If you have any questions about visualizing Decision Trees, do let me know via email or LinkedIn.

So this is all for this blog, folks. Thanks for reading, I hope you take something away from it, and till next time…

Check out my other machine learning projects, deep learning projects, computer vision projects, NLP projects, and Flask projects at machinelearningprojects.net.
