Decision Trees
A decision tree classifies observations by asking a series of binary questions about the predictor variables. At each step the data are split into two groups based on a threshold for one variable, and the process repeats within each group until a stopping criterion is met. The result is a tree of nodes and branches that can be read from top to bottom to understand how any observation is classified.
Decision trees are intuitive and easy to explain, but a single tree is unstable: small changes in the data can produce a very different tree. This weakness motivates random forests, which average across many trees to produce a robust classifier.
Key Concepts
Root node: the first split, applied to all observations. The variable and threshold chosen here are the ones that reduce impurity the most among all candidate splits.
Internal nodes: subsequent splits applied to subsets of the data.
Leaves (terminal nodes): the end points of the tree. Each leaf is assigned the majority class of the observations it contains.
Depth: the number of splits from root to leaf. Deeper trees fit the training data more closely but are more prone to overfitting.
Splitting criterion: the measure used to choose the best split at each node. For classification trees, the most common criterion is the Gini impurity, which measures how often a randomly chosen observation from the node would be misclassified if it were labelled at random according to the class distribution at that node:
$$G = 1 - \sum_{k} p_k^2$$
where p_k is the proportion of observations in class k at the node. A Gini impurity of 0 means all observations at the node belong to one class (a pure node); 0.5 is the maximum impurity for a two-class problem, reached when the classes are evenly balanced.
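To make the splitting criterion concrete, the sketch below computes the Gini impurity by hand and searches for the best threshold on a single predictor of the iris data. The helper names gini() and split_impurity() are hypothetical, introduced only for this illustration, and the candidate thresholds are assumed to be the midpoints between consecutive distinct values. It is not how rpart is implemented internally.

```r
# Gini impurity of a node, given the class labels of its observations.
gini <- function(labels) {
  p <- table(labels) / length(labels)  # class proportions p_k
  1 - sum(p^2)
}

# Weighted Gini impurity of the two child nodes produced by x < threshold.
split_impurity <- function(x, y, threshold) {
  left <- x < threshold
  n <- length(y)
  (sum(left) / n) * gini(y[left]) + (sum(!left) / n) * gini(y[!left])
}

data(iris)
x <- iris$Petal.Length
y <- iris$Species

# Candidate thresholds: midpoints between consecutive distinct values of x.
vals <- sort(unique(x))
candidates <- (head(vals, -1) + tail(vals, -1)) / 2

impurities <- sapply(candidates, function(t) split_impurity(x, y, t))
best_threshold <- candidates[which.min(impurities)]

root_gini <- gini(y)  # impurity before any split: 2/3 for three balanced classes
best_threshold        # 2.45
min(impurities)       # 1/3: one pure child (setosa) and one 50/50 child
```

A full tree-growing routine would repeat this search over every predictor at every node and keep the split with the lowest weighted impurity.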
Fitting a Decision Tree in R
library(rpart)
library(rpart.plot)
data(iris)
# Fit a classification tree
tree_model <- rpart(Species ~ .,
                    data = iris,
                    method = "class")
# Print the tree rules
print(tree_model)
#> n= 150
#>
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#>
#> 1) root 150 100 setosa (0.33 0.33 0.33)
#>   2) Petal.Length< 2.45 50 0 setosa (1.00 0.00 0.00) *
#>   3) Petal.Length>=2.45 100 50 versicolor (0.00 0.50 0.50)
#>     6) Petal.Width< 1.75 54 5 versicolor (0.00 0.91 0.09) *
#>     7) Petal.Width>=1.75 46 1 virginica (0.00 0.02 0.98) *
Reading the output: each row describes a node. The first number is the node index, followed by the split rule, the number of observations at that node, the number misclassified (loss), the majority class, and the class probabilities in parentheses. Rows ending with * are leaves.
This tree makes only two splits: first on Petal.Length to isolate setosa perfectly, then on Petal.Width to separate versicolor from virginica. Six observations are misclassified in total: five in the versicolor leaf and one in the virginica leaf.
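Because the fitted tree is just two nested threshold rules, it can be rewritten as plain if/else logic. The sketch below does that and checks the result against predict(). The function classify_by_hand() is a hypothetical helper written for this illustration, with the thresholds copied from the printed output above.

```r
library(rpart)

data(iris)
tree_model <- rpart(Species ~ ., data = iris, method = "class")

# The two splits from the printed tree, written as nested rules.
classify_by_hand <- function(petal_length, petal_width) {
  ifelse(petal_length < 2.45, "setosa",
         ifelse(petal_width < 1.75, "versicolor", "virginica"))
}

by_hand <- classify_by_hand(iris$Petal.Length, iris$Petal.Width)
by_tree <- as.character(predict(tree_model, type = "class"))

all(by_hand == by_tree)                       # TRUE: the rules reproduce the tree
sum(by_hand != as.character(iris$Species))    # 6 training observations misclassified
```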
# Plot the tree
rpart.plot(tree_model,
           type = 4,
           extra = 104, # show class probabilities and % of observations
           main = "Decision Tree: Iris Species")
Each node in the plot shows the predicted class, the class probabilities, and the percentage of observations reaching that node. With type = 4, each branch is labelled with its own condition, so follow the branch whose condition the observation satisfies.
Overfitting and Pruning
A tree grown without constraints will keep splitting until every leaf is pure, perfectly fitting the training data but generalising poorly to new data. Two parameters control tree complexity:
maxdepth: the maximum number of splits from root to leaf.
cp (complexity parameter): a split is attempted only if it decreases the overall lack of fit by a factor of at least cp (the default is 0.01). Higher values produce simpler trees.
# Constrain depth
tree_shallow <- rpart(Species ~ .,
                      data = iris,
                      method = "class",
                      control = rpart.control(maxdepth = 2))
# Or prune after fitting using cross-validated cp
printcp(tree_model)
#>        CP nsplit rel error xerror  xstd
#>   0.50000      0   1.00000   1.16 0.051
#>   0.44000      1   0.50000   0.66 0.046
#>   0.01000      2   0.06000   0.09 0.029
#
# The last CP value is the cp used when fitting (0.01 by default).
# xerror is the cross-validated error; it varies slightly between runs
# unless a seed is set. Choose the cp where xerror is lowest.
# Prune to the cp with lowest cross-validated error
best_cp <- tree_model$cptable[which.min(tree_model$cptable[, "xerror"]), "CP"]
tree_pruned <- prune(tree_model, cp = best_cp)
rpart.plot(tree_pruned, main = "Pruned Tree")
The printcp() output shows the cross-validated error (xerror) for each tree size. The tree with two splits already achieves a very low cross-validated error on this dataset, confirming that the additional complexity of a deeper tree is not needed.
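To see the overfitting described at the start of this section, the sketch below grows a deliberately unconstrained tree and compares it with the default one. It assumes that setting cp = 0, minsplit = 2 and minbucket = 1 in rpart.control() is enough to remove the usual stopping rules. The helpers n_leaves() and train_accuracy() are hypothetical names introduced for this example.

```r
library(rpart)

data(iris)
set.seed(1)  # the cross-validation inside rpart is random

default_tree <- rpart(Species ~ ., data = iris, method = "class")

# Remove the usual stopping rules so the tree keeps splitting.
full_tree <- rpart(Species ~ ., data = iris, method = "class",
                   control = rpart.control(cp = 0, minsplit = 2, minbucket = 1))

# Count terminal nodes: leaves are marked "<leaf>" in the frame component.
n_leaves <- function(tree) sum(tree$frame$var == "<leaf>")

# Share of training observations classified correctly.
train_accuracy <- function(tree) {
  mean(predict(tree, type = "class") == iris$Species)
}

c(default = n_leaves(default_tree), full = n_leaves(full_tree))
c(default = train_accuracy(default_tree), full = train_accuracy(full_tree))

# Compare xerror across tree sizes: extra splits that do not lower it
# are fitting noise rather than structure.
printcp(full_tree)
```

The larger tree can only match or improve the training accuracy, so the useful comparison is the xerror column, which shows whether the extra splits help on held-out folds.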
Training Accuracy
# Predictions on training data
pred_class <- predict(tree_model, type = "class")
# Confusion matrix
table(Predicted = pred_class, Actual = iris$Species)
#>             Actual
#> Predicted    setosa versicolor virginica
#>   setosa         50          0         0
#>   versicolor      0         49         5
#>   virginica       0          1        45
# Accuracy
mean(pred_class == iris$Species)
#> [1] 0.96
The tree classifies 96% of the training observations correctly (144 of 150). Training accuracy is optimistic, because the model is evaluated on the same data it was fitted to. For a realistic estimate, use cross-validation or a held-out test set, as you would for any classifier.
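A held-out estimate can be sketched as follows. The 70/30 split and the seed are arbitrary choices made for this illustration, and the resulting accuracies depend on them, so no expected values are shown.

```r
library(rpart)

data(iris)
set.seed(42)  # arbitrary seed so the split is reproducible

# Hold out 30% of the rows for testing.
train_idx <- sample(nrow(iris), size = 105)
train <- iris[train_idx, ]
test  <- iris[-train_idx, ]

# Fit on the training rows only.
fit <- rpart(Species ~ ., data = train, method = "class")

# Accuracy on the data the tree has seen and on the data it has not.
train_acc <- mean(predict(fit, type = "class") == train$Species)
test_acc  <- mean(predict(fit, newdata = test, type = "class") == test$Species)

c(train = train_acc, test = test_acc)
```

With only 45 test observations, the test accuracy is itself noisy, so repeating the split with different seeds or using cross-validation gives a steadier picture.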
Limitations of a Single Tree
A single decision tree has two important weaknesses:
Instability: a small change in the training data, such as removing a few observations, can produce a completely different tree structure. The tree is highly sensitive to the particular sample it was trained on.
High variance: because a single tree can fit any training set perfectly by growing deep enough, it tends to overfit unless heavily constrained.
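One way to probe the instability is to refit the tree after dropping a few random observations and compare the splits that come out. The sketch below does this five times. The helper describe_splits() is a hypothetical name, and dropping 15 rows is an arbitrary choice. Because the iris classes are well separated, the trees may differ only slightly here; noisier data typically shows larger changes.

```r
library(rpart)

data(iris)
set.seed(123)  # arbitrary seed for reproducibility

# Summarise a tree by the variables it splits on (leaves excluded),
# in the order the nodes are stored in the frame component.
describe_splits <- function(tree) {
  vars <- as.character(tree$frame$var)
  paste(vars[vars != "<leaf>"], collapse = ", ")
}

# Refit five times, each time removing 15 randomly chosen observations.
structures <- sapply(1:5, function(i) {
  dropped <- sample(nrow(iris), 15)
  fit <- rpart(Species ~ ., data = iris[-dropped, ], method = "class")
  describe_splits(fit)
})

print(structures)
```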
Random forests address both problems by training many trees on different bootstrap samples of the data, each considering a random subset of predictors at each split, and combining their predictions (a majority vote for classification, an average for regression). Combining many trees reduces variance, and the randomness between trees makes it less likely that they all make the same mistakes.
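The bootstrap-and-vote part of this idea can be sketched with rpart alone. Note that this is plain bagging, not a full random forest: rpart considers every predictor at each split, so the random predictor subsets are missing. The number of trees, the train/test split and the seed are all assumptions made for the illustration.

```r
library(rpart)

data(iris)
set.seed(1)  # arbitrary seed for reproducibility

# Hold out 45 rows so the ensemble is judged on unseen data.
train_idx <- sample(nrow(iris), size = 105)
train <- iris[train_idx, ]
test  <- iris[-train_idx, ]

n_trees <- 50

# Each column holds one tree's predictions for the test rows.
votes <- sapply(seq_len(n_trees), function(b) {
  boot <- train[sample(nrow(train), replace = TRUE), ]  # bootstrap sample
  fit <- rpart(Species ~ ., data = boot, method = "class")
  as.character(predict(fit, newdata = test, type = "class"))
})

# Majority vote across trees for each test observation.
bagged <- apply(votes, 1, function(v) names(which.max(table(v))))

mean(bagged == as.character(test$Species))  # test accuracy of the ensemble
```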
Exercise
Using the iris dataset:
- Fit a decision tree with default settings and plot it
- How many splits does it make? Which variables are used?
- Check the printcp() output: is there evidence of overfitting?
- Prune the tree to the optimal cp and compare its structure to the unpruned version
Solution
library(rpart)
library(rpart.plot)
# 1. Fit and plot
tree <- rpart(Species ~ ., data = iris, method = "class")
rpart.plot(tree, type = 4, extra = 104)
# 2. Two splits: Petal.Length and Petal.Width
print(tree)
# 3. Cross-validated error
printcp(tree)
# xerror drops sharply as splits are added and is lowest at 2 splits,
# the largest tree in the table
# No evidence of overfitting: xerror never rises as the tree grows
# 4. Prune
best_cp <- tree$cptable[which.min(tree$cptable[, "xerror"]), "CP"]
tree_pruned <- prune(tree, cp = best_cp)
rpart.plot(tree_pruned, type = 4, extra = 104,
           main = "Pruned Tree")
# Pruned tree is identical here: the default tree is already near-optimal