4 Easiest ways to visualize Decision Trees using Scikit-Learn and Python – 2023

Machine Learning Projects

So guys, in this blog we will see how to visualize decision trees using Scikit-Learn in Python. We will be able to see exactly how a decision tree makes its decisions. So without further ado, let’s do it…

A quick overview of Decision Trees

Step 1 – Training a basic Decision Tree

from matplotlib import pyplot as plt
from sklearn import datasets,tree
from sklearn.tree import DecisionTreeClassifier 

# Prepare the data
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Fit the classifier with default hyper-parameters
clf = DecisionTreeClassifier(random_state=1234)
model = clf.fit(X, y)
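
Before picking a visualization, it can help to check how big the fitted tree actually is, since a very deep tree will not fit nicely on one figure. Below is a minimal sketch, assuming the same iris data and `random_state` as above. `get_depth()` and `get_n_leaves()` are scikit-learn methods on the fitted classifier; the printed summary line is just an illustration and not part of the original post.

```python
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier

# Same data and settings as in Step 1
iris = datasets.load_iris()
X, y = iris.data, iris.target
clf = DecisionTreeClassifier(random_state=1234).fit(X, y)

# Size of the fitted tree: useful for choosing figure size or a depth limit
depth = clf.get_depth()          # longest path from the root to a leaf
n_leaves = clf.get_n_leaves()    # number of terminal nodes
n_nodes = clf.tree_.node_count   # all nodes, internal and leaf

print(f"depth={depth}, leaves={n_leaves}, nodes={n_nodes}")
```

If the depth is large, the `max_depth` arguments of the plotting functions shown later can be used to draw only the top of the tree.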

Step 2 – Types of Tree Visualizations

We can visualize the Decision Tree in the following 4 ways:

  1. Printing Text Representation of the tree.
  2. Plot Tree with plot_tree.
  3. Visualize the Decision Tree with graphviz.
  4. Plot Decision Tree with dtreeviz Package.

Let’s visualize Decision trees…

1. Text Representation of the tree

text_representation = tree.export_text(clf)
print(text_representation)
|--- feature_2 <= 2.45
|   |--- class: 0
|--- feature_2 >  2.45
|   |--- feature_3 <= 1.75
|   |   |--- feature_2 <= 4.95
|   |   |   |--- feature_3 <= 1.65
|   |   |   |   |--- class: 1
|   |   |   |--- feature_3 >  1.65
|   |   |   |   |--- class: 2
|   |   |--- feature_2 >  4.95
|   |   |   |--- feature_3 <= 1.55
|   |   |   |   |--- class: 2
|   |   |   |--- feature_3 >  1.55
|   |   |   |   |--- feature_0 <= 6.95
|   |   |   |   |   |--- class: 1
|   |   |   |   |--- feature_0 >  6.95
|   |   |   |   |   |--- class: 2
|   |--- feature_3 >  1.75
|   |   |--- feature_2 <= 4.85
|   |   |   |--- feature_1 <= 3.10
|   |   |   |   |--- class: 2
|   |   |   |--- feature_1 >  3.10
|   |   |   |   |--- class: 1
|   |   |--- feature_2 >  4.85
|   |   |   |--- class: 2
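
The output above uses generic names like `feature_2`, which are hard to read. As a small sketch (assuming the same iris setup as in Step 1), `export_text` also accepts `feature_names`, and `max_depth` to limit how many levels are printed. The variable name `named_text` is only for illustration.

```python
from sklearn import datasets, tree
from sklearn.tree import DecisionTreeClassifier

# Same data and settings as in Step 1
iris = datasets.load_iris()
clf = DecisionTreeClassifier(random_state=1234).fit(iris.data, iris.target)

# Use the real column names instead of feature_0, feature_1, ...
# and print only the first levels; deeper branches are summarised
named_text = tree.export_text(
    clf,
    feature_names=list(iris.feature_names),
    max_depth=2,
)
print(named_text)
```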

Save the Text Representation of the tree…

with open("decision_tree.log", "w") as fout:
    fout.write(text_representation)

2. Plot Tree with plot_tree

fig = plt.figure(figsize=(25,20))
_ = tree.plot_tree(clf, 
                   feature_names=iris.feature_names,  
                   class_names=list(iris.target_names),  # NumPy array of names converted to a plain list
                   filled=True)

Save the Tree Representation of the plot_tree method…

fig.savefig("decision_tree.png")
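
For bigger trees the full plot gets crowded, so it may be easier to draw only the top few levels. Here is a rough sketch, assuming the same iris classifier as in Step 1. `max_depth`, `fontsize` and `ax` are existing `plot_tree` parameters; the output file name and the figure size are my own choices, not from the original post.

```python
from pathlib import Path

import matplotlib
matplotlib.use("Agg")  # render off-screen so the script also runs without a display
from matplotlib import pyplot as plt
from sklearn import datasets, tree
from sklearn.tree import DecisionTreeClassifier

# Same data and settings as in Step 1
iris = datasets.load_iris()
clf = DecisionTreeClassifier(random_state=1234).fit(iris.data, iris.target)

fig, ax = plt.subplots(figsize=(12, 8))
annotations = tree.plot_tree(
    clf,
    max_depth=2,                              # draw only the top levels of the tree
    feature_names=list(iris.feature_names),
    class_names=list(iris.target_names),
    filled=True,
    fontsize=10,
    ax=ax,
)

out_path = Path("decision_tree_top_levels.png")  # hypothetical file name
fig.savefig(out_path, dpi=150, bbox_inches="tight")  # trim empty margins
plt.close(fig)
```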

3. Visualize the Decision Tree with Graphviz

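import graphviz  # needs the graphviz Python package plus the Graphviz system binaries
# DOT data
dot_data = tree.export_graphviz(clf, out_file=None, 
                                feature_names=iris.feature_names,  
                                class_names=list(iris.target_names),  # NumPy array of names converted to a plain list
                                filled=True)

# Draw graph
graph = graphviz.Source(dot_data, format="png") 
graph  # in a notebook, the last expression displays the rendered tree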

Save the Tree Representation of the graphviz method…

graph.render("decision_tree_graphviz")
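
If the `graphviz` Python package is not installed, scikit-learn can still produce the DOT description on its own, and it can be rendered later. The sketch below assumes the same iris classifier as in Step 1. The file name `decision_tree.dot` is hypothetical, and `rounded=True` is an optional `export_graphviz` styling flag that the original code does not use.

```python
from pathlib import Path

from sklearn import datasets, tree
from sklearn.tree import DecisionTreeClassifier

# Same data and settings as in Step 1
iris = datasets.load_iris()
clf = DecisionTreeClassifier(random_state=1234).fit(iris.data, iris.target)

# out_file=None makes export_graphviz return the DOT source as a string
dot_data = tree.export_graphviz(
    clf,
    out_file=None,
    feature_names=list(iris.feature_names),
    class_names=list(iris.target_names),
    filled=True,
    rounded=True,  # rounded node boxes, purely cosmetic
)

dot_path = Path("decision_tree.dot")  # hypothetical file name
dot_path.write_text(dot_data)
```

Assuming the Graphviz command-line tools are installed, the saved file can then be converted with `dot -Tpng decision_tree.dot -o decision_tree.png`.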

4. Plot Decision Tree with dtreeviz Package

from dtreeviz.trees import dtreeviz  # dtreeviz 1.x API; 2.x releases use dtreeviz.model(...).view() instead

viz = dtreeviz(clf, X, y,
                target_name="target",
                feature_names=iris.feature_names,
                class_names=list(iris.target_names))

viz

Save the Tree Representation of the dtreeviz method…

viz.save("decision_tree.svg")

If you have any questions about visualizing a decision tree, do let me know by contacting me via email or LinkedIn.

So that’s all for this blog, folks. Thanks for reading, and I hope you take something useful away from it. Till next time…

Read my previous post: How to build OpenCV with Cuda and cuDNN support in Windows

Check out my other machine learning projects, deep learning projects, computer vision projects, NLP projects, and Flask projects at machinelearningprojects.net.
