Machine Learning from Scratch: Decision Tree

Introduction

In this post, I’ll be implementing a decision tree classifier from scratch in Python. This is the fifth post in the “Machine Learning from Scratch” series.

Decision trees are intuitive, interpretable models that make decisions by learning simple rules from the training data. They form the foundation for powerful ensemble methods like Random Forests and Gradient Boosting.

Decision Tree

A decision tree is a tree-structured classifier in which each internal node tests a single feature against a threshold, each branch corresponds to one outcome of that test, and each leaf node holds a class label. The tree is built by recursively splitting the data on the feature and threshold that best separate the classes.
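To make the structure concrete, here is a minimal sketch of a hand-written tree stored as nested dictionaries. The feature indices, thresholds, and class names are made up for illustration; nothing here is learned from data.

```python
# A hand-built tree (hypothetical thresholds and labels, purely illustrative).
# Internal nodes hold a feature index and a threshold; leaves hold a class label.
toy_tree = {
    "feature": 0, "threshold": 2.5,
    "left": {"label": "A"},              # taken when x[0] <= 2.5
    "right": {                           # taken when x[0] > 2.5
        "feature": 1, "threshold": 1.0,
        "left": {"label": "B"},          # taken when x[1] <= 1.0
        "right": {"label": "C"},         # taken when x[1] > 1.0
    },
}


def classify(x, node):
    """Follow the decision rules from the root until a leaf is reached."""
    while "label" not in node:
        branch = "left" if x[node["feature"]] <= node["threshold"] else "right"
        node = node[branch]
    return node["label"]


print(classify([1.0, 5.0], toy_tree))  # A
print(classify([3.0, 0.5], toy_tree))  # B
print(classify([3.0, 4.0], toy_tree))  # C
```

Each prediction is just a short chain of comparisons, which is why a small tree can be read as a set of if/else rules.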

The quality of a split is measured using metrics like Information Gain, which is based on entropy. Entropy measures the impurity of a set of labels, where p(x) is the fraction of samples in the set that belong to class x:

Entropy = -Σ p(x) * log₂(p(x))
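As a quick worked example of the formula, the sketch below computes entropy for three small label lists using only the standard library. The function name `entropy` and the toy label lists are illustrative choices, separate from the class implementation later in the post.

```python
import math
from collections import Counter


def entropy(labels):
    """Shannon entropy (in bits) of a list of class labels."""
    n = len(labels)
    # count / n is the fraction of samples carrying each distinct label.
    return sum(-(count / n) * math.log2(count / n)
               for count in Counter(labels).values())


print(entropy([0, 0, 1, 1]))            # 1.0: two classes, evenly mixed
print(entropy([0, 0, 0, 0]))            # 0.0: a pure set
print(round(entropy([0, 0, 0, 1]), 3))  # 0.811: mostly one class
```

Entropy is highest when the classes are evenly mixed and drops to zero when every sample has the same label.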

We choose splits that maximize information gain, which is the reduction in entropy after a split: the entropy of the parent node minus the size-weighted average entropy of the two child nodes.
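A small sketch of information gain on toy data, assuming a binary split into a left and a right group. The helper names `entropy` and `information_gain` and the label lists are illustrative only.

```python
import math
from collections import Counter


def entropy(labels):
    """Shannon entropy (in bits) of a list of class labels."""
    n = len(labels)
    return sum(-(c / n) * math.log2(c / n) for c in Counter(labels).values())


def information_gain(parent, left, right):
    """Parent entropy minus the size-weighted entropy of the two children."""
    n = len(parent)
    weighted_child = ((len(left) / n) * entropy(left)
                      + (len(right) / n) * entropy(right))
    return entropy(parent) - weighted_child


parent = [0, 0, 1, 1]
print(information_gain(parent, [0, 0], [1, 1]))  # 1.0: the split separates the classes
print(information_gain(parent, [0, 1], [0, 1]))  # 0.0: the split tells us nothing
```

The first split removes all uncertainty about the label, so it recovers the full parent entropy; the second leaves both children as mixed as the parent.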

Implementation

I’m using numpy for numerical computations and Counter from collections for counting class occurrences. For testing, I’ll use train_test_split and datasets from scikit-learn.

The implementation includes a Node class to represent tree nodes and a DecisionTree class with the following methods:

  • __init__: Constructor to set the maximum depth and minimum samples for splitting.
  • fit: Method to build the tree recursively.
  • predict: Method to traverse the tree and make predictions.

import numpy as np
from collections import Counter
from sklearn.model_selection import train_test_split
from sklearn import datasets

class Node:
    def __init__(self, feature=None, threshold=None, left=None, right=None, value=None):
        self.feature = feature
        self.threshold = threshold
        self.left = left
        self.right = right
        self.value = value

    def is_leaf_node(self):
        return self.value is not None


class DecisionTree:
    def __init__(self, max_depth=10, min_samples_split=2):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.root = None

    def fit(self, X, y):
        self.root = self._grow_tree(X, y)

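    def _grow_tree(self, X, y, depth=0):
        num_samples, num_features = X.shape
        num_labels = len(np.unique(y))

        # Stop when the depth limit is reached, the node is pure,
        # or there are too few samples left to split.
        if depth >= self.max_depth or num_labels == 1 or num_samples < self.min_samples_split:
            leaf_value = self._most_common_label(y)
            return Node(value=leaf_value)

        # Every feature is considered; the shuffle only changes the order
        # in which ties between equally good splits are resolved.
        feat_idxs = np.random.choice(num_features, num_features, replace=False)
        best_feature, best_threshold = self._best_split(X, y, feat_idxs)

        left_idxs = np.argwhere(X[:, best_feature] <= best_threshold).flatten()
        right_idxs = np.argwhere(X[:, best_feature] > best_threshold).flatten()

        # If no threshold separates the samples (e.g. identical rows with
        # different labels), one side is empty. Return a leaf instead of
        # recursing on an empty array, which would raise an IndexError.
        if len(left_idxs) == 0 or len(right_idxs) == 0:
            return Node(value=self._most_common_label(y))

        left = self._grow_tree(X[left_idxs, :], y[left_idxs], depth + 1)
        right = self._grow_tree(X[right_idxs, :], y[right_idxs], depth + 1)

        return Node(best_feature, best_threshold, left, right)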

    def _best_split(self, X, y, feat_idxs):
        best_gain = -1
        split_idx, split_threshold = None, None

        for feat_idx in feat_idxs:
            X_column = X[:, feat_idx]
            thresholds = np.unique(X_column)

            for threshold in thresholds:
                gain = self._information_gain(y, X_column, threshold)

                if gain > best_gain:
                    best_gain = gain
                    split_idx = feat_idx
                    split_threshold = threshold

        return split_idx, split_threshold

    def _information_gain(self, y, X_column, threshold):
        parent_entropy = self._entropy(y)

        left_idxs = np.argwhere(X_column <= threshold).flatten()
        right_idxs = np.argwhere(X_column > threshold).flatten()

        if len(left_idxs) == 0 or len(right_idxs) == 0:
            return 0

        n = len(y)
        n_l, n_r = len(left_idxs), len(right_idxs)
        e_l, e_r = self._entropy(y[left_idxs]), self._entropy(y[right_idxs])
        child_entropy = (n_l / n) * e_l + (n_r / n) * e_r

        return parent_entropy - child_entropy

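    def _entropy(self, y):
        # np.bincount requires the labels to be non-negative integers (0, 1, 2, ...).
        hist = np.bincount(y)
        ps = hist / len(y)
        # Skip empty classes, since log2(0) is undefined.
        return -np.sum([p * np.log2(p) for p in ps if p > 0])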

    def _most_common_label(self, y):
        counter = Counter(y)
        return counter.most_common(1)[0][0]

    def predict(self, X):
        return np.array([self._traverse_tree(x, self.root) for x in X])

    def _traverse_tree(self, x, node):
        if node.is_leaf_node():
            return node.value

        if x[node.feature] <= node.threshold:
            return self._traverse_tree(x, node.left)
        return self._traverse_tree(x, node.right)

Now let’s test the model on a classification dataset.

def accuracy(y_test, predictions):
    return np.sum(y_test == predictions) / len(y_test)


if __name__ == '__main__':
    X, y = datasets.make_classification(
        n_samples=1000, n_features=10, n_classes=3,
        n_informative=8, random_state=42
    )
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    model = DecisionTree(max_depth=10)
    model.fit(X_train, y_train)
    predictions = model.predict(X_test)

    acc = accuracy(y_test, predictions)
    print(f"Accuracy: {acc}")

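Running the script prints the accuracy on the held-out test set; the exact value depends on the synthetic data and on the max_depth and min_samples_split settings. Decision trees are powerful because they can capture non-linear relationships and interactions between features without explicit feature engineering.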
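One way to sanity-check a from-scratch tree is to fit scikit-learn's DecisionTreeClassifier on the same split and compare accuracies. The sketch below is a rough reference point, assuming criterion="entropy" and matching depth settings. The two trees need not be identical (scikit-learn places thresholds midway between neighbouring feature values, for instance), so small differences in accuracy are expected.

```python
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Same synthetic data and split as in the test script above.
X, y = datasets.make_classification(
    n_samples=1000, n_features=10, n_classes=3,
    n_informative=8, random_state=42
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# criterion="entropy" makes scikit-learn score splits by information gain.
reference = DecisionTreeClassifier(
    criterion="entropy", max_depth=10, min_samples_split=2, random_state=0
)
reference.fit(X_train, y_train)

# score() returns the mean accuracy on the given test data.
reference_acc = reference.score(X_test, y_test)
print(f"scikit-learn accuracy: {reference_acc}")
```

If the from-scratch accuracy is far below this reference, that usually points to a bug in the split search or in the traversal logic.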

That’s all for this post. Thanks for reading!