Linear Discriminant Analysis (LDA)
Learning Objectives
By the end of this section, you should be able to:
- Understand the difference between supervised (LDA) and unsupervised (PCA) methods
- Explain how LDA maximizes group separation
- Apply LDA for classification and visualization
- Interpret discriminant functions and loadings
- Validate LDA models with cross-validation
What is LDA?
Linear Discriminant Analysis (LDA) finds linear combinations of features that best separate predefined classes by maximizing between-class variance and minimizing within-class variance. It is supervised because it uses known class labels during training to learn the separation.
Key Distinction
- PCA (unsupervised): Finds directions of maximum variance, ignoring group labels
- LDA (supervised): Finds directions of maximum group separation, using group labels
When to Use LDA
- Classification - Predicting group membership for new samples
- Visualization - Plotting samples in space that maximizes group separation
- Feature selection - Identifying variables that distinguish groups
- Dimension reduction - With class information included
Common applications:
- Classifying cancer subtypes from gene expression
- Predicting disease state from microbiome profiles
- Distinguishing plant species from morphological traits
- Separating experimental conditions in metabolomics
LDA vs. PCA: A Visual Comparison
Consider two groups with the same variance structure:
# Generate example data: both groups share the same within-group covariance;
# they are separated horizontally, but most within-group spread is vertical
set.seed(42)
group1 <- cbind(rnorm(50, 0, 0.3), rnorm(50, 0, 2))
group2 <- cbind(rnorm(50, 1.5, 0.3), rnorm(50, 0, 2))
data <- rbind(group1, group2)
groups <- factor(rep(c("A", "B"), each = 50))
# PCA - PC1 follows the direction of greatest total variance (vertical here),
# which carries no group information
pca <- prcomp(data, scale. = FALSE)
# LDA - LD1 follows the direction of greatest group separation (horizontal here)
library(MASS)
lda_model <- lda(data, grouping = groups)
Key difference:
| Method | Goal | Uses Groups? | Output |
|---|---|---|---|
| PCA | Maximize variance | No | Principal Components |
| LDA | Maximize separation | Yes | Discriminant Functions |
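A short sketch to make the contrast concrete (assuming the pca and lda_model objects just created): project the samples onto PC1 and LD1 and compare the groups along each direction.
proj <- data.frame(
  PC1 = pca$x[, 1],                                # scores on the first principal component
  LD1 = predict(lda_model, newdata = data)$x[, 1], # scores on the single discriminant
  group = groups
)
par(mfrow = c(1, 2))
boxplot(PC1 ~ group, data = proj, main = "PC1: groups overlap")
boxplot(LD1 ~ group, data = proj, main = "LD1: groups separate")
par(mfrow = c(1, 1))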
The Mathematical Idea
LDA finds directions that:
- Maximize between-group variance (groups far apart)
- Minimize within-group variance (groups compact)
Fisher's Linear Discriminant
For two groups, LDA finds the linear combination w that maximizes Fisher's criterion:

$$J(w) = \frac{w^\top S_B\, w}{w^\top S_W\, w}$$

Where:
- S_B = between-group scatter matrix
- S_W = within-group scatter matrix
- w = discriminant weights (loadings)
Intuition
Think of LDA as finding the best "projection line" where group means are far apart, but points within each group are close together.
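A minimal by-hand sketch of this criterion, using the two simulated groups from the comparison above (group1, group2): for two classes, the maximizing direction is proportional to S_W^{-1}(m2 - m1).
m1 <- colMeans(group1)
m2 <- colMeans(group2)
# Pooled within-group scatter matrix S_W
S_W <- (nrow(group1) - 1) * cov(group1) + (nrow(group2) - 1) * cov(group2)
# Fisher's discriminant direction (defined only up to scale)
w <- solve(S_W, m2 - m1)
w / sqrt(sum(w^2))   # compare with lda_model$scaling, which is also scale-free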
LDA Assumptions
Critical Assumptions
LDA makes stronger assumptions than PCA:
- Multivariate normality - Each group follows a multivariate normal distribution
- Homogeneity of variance - All groups have the same covariance matrix
- Independence - Observations are independent
- No multicollinearity - Predictor variables not highly correlated
Checking Assumptions
# 1. Visual check of normality
library(ggplot2)
library(reshape2)
# Check each variable by group
data_long <- melt(data.frame(data, group = groups), id.vars = "group")
ggplot(data_long, aes(x = value, fill = group)) +
geom_density(alpha = 0.5) +
facet_wrap(~variable, scales = "free")
# 2. Test homogeneity of covariance (Box's M test)
library(biotools)
boxM(data, groups)
# A non-significant p-value gives no evidence against equal covariances
# 3. Check multicollinearity
cor(data)
# High correlations (|r| > 0.8) indicate problems
What if Assumptions Are Violated?
If the data are not normal:
- Transform data (log, sqrt, Box-Cox)
- Use robust estimators, or try logistic regression or random forests
If the groups have different covariance matrices:
- Use Quadratic Discriminant Analysis (QDA), which fits a separate covariance matrix per group (see the sketch below)
- Apply regularized LDA
If there are more variables than samples (p > n):
- Use regularized LDA
- Reduce features with PCA first
- Validate with cross-validation especially carefully
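A minimal sketch of the QDA alternative (using the built-in iris data; qda() shares the lda() interface in MASS):
library(MASS)
# QDA estimates a separate covariance matrix for each group
qda_model <- qda(Species ~ ., data = iris)
# Leave-one-out cross-validated accuracy, for comparison with LDA
qda_cv <- qda(Species ~ ., data = iris, CV = TRUE)
mean(qda_cv$class == iris$Species)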
LDA in R: Complete Example
Basic Implementation
library(MASS)
# Use iris dataset
data(iris)
df <- iris[, 1:4] # Predictors
species <- iris$Species # Groups
# 1. Fit LDA model
lda_model <- lda(Species ~ ., data = iris)
# 2. Examine output
print(lda_model)
#> Prior probabilities of groups:
#> setosa versicolor virginica
#> 0.333 0.333 0.333
#>
#> Group means: (shows centroids)
#>
#> Coefficients of linear discriminants:
#> (shows loadings for each LD)
# 3. See how well groups separate
plot(lda_model)
Key Components
# Prior probabilities (can be set manually)
lda_model$prior
# Group means (centroids in original space)
lda_model$means
# Scaling (discriminant coefficients/loadings)
lda_model$scaling
#> LD1 LD2
#> Sepal.Length 0.8293776 0.02410215
#> Sepal.Width 1.5344731 2.16452123
#> Petal.Length -2.2012117 -0.93192121
#> Petal.Width -2.8104603 2.83918785
# Proportion of trace (variance explained)
lda_model$svd^2 / sum(lda_model$svd^2)
#> LD1: 99.1% LD2: 0.9%
Making Predictions
# Predict on training data
predictions <- predict(lda_model)
# Components:
# - $class: predicted group
# - $posterior: probability of belonging to each group
# - $x: coordinates in discriminant space
# Confusion matrix
table(Predicted = predictions$class, Actual = species)
# Classification accuracy
mean(predictions$class == species)
#> [1] 0.98 # 98% accuracy
Visualization
library(ggplot2)
# Extract LD scores
lda_data <- data.frame(
LD1 = predictions$x[, 1],
LD2 = predictions$x[, 2],
Species = species
)
# Plot in discriminant space
ggplot(lda_data, aes(x = LD1, y = LD2, color = Species)) +
geom_point(size = 3, alpha = 0.6) +
stat_ellipse(level = 0.95) +
labs(title = "LDA: Iris Species Separation",
x = "LD1 (99.1% of separation)",
y = "LD2 (0.9% of separation)") +
theme_minimal()
Cross-Validation
Never Test on Training Data!
Always validate with independent data or cross-validation.
Leave-One-Out Cross-Validation (LOOCV)
# Built-in LOOCV in lda()
lda_cv <- lda(Species ~ ., data = iris, CV = TRUE)
# Cross-validated predictions
table(Predicted = lda_cv$class, Actual = species)
# CV accuracy
mean(lda_cv$class == species)
#> [1] 0.98
k-Fold Cross-Validation
library(caret)
# Set up 10-fold CV
train_control <- trainControl(method = "cv", number = 10)
# Train with CV
set.seed(123)
cv_model <- train(Species ~ .,
data = iris,
method = "lda",
trControl = train_control)
# Results
print(cv_model)
#> Accuracy: 0.98 (averaged over 10 folds)
Train-Test Split
# Split data
set.seed(123)
train_idx <- sample(1:nrow(iris), 0.7 * nrow(iris))
train_data <- iris[train_idx, ]
test_data <- iris[-train_idx, ]
# Train on training set
lda_train <- lda(Species ~ ., data = train_data)
# Predict on test set
test_pred <- predict(lda_train, newdata = test_data)
# Test accuracy
mean(test_pred$class == test_data$Species)
Interpretation
1. Discriminant Loadings
Which variables drive separation?
# Get loadings
loadings <- lda_model$scaling
# Visualize
library(reshape2)
loadings_long <- melt(loadings)
names(loadings_long) <- c("Variable", "LD", "Loading")
ggplot(loadings_long, aes(x = Variable, y = Loading, fill = LD)) +
geom_bar(stat = "identity", position = "dodge") +
coord_flip() +
labs(title = "LDA Loadings") +
theme_minimal()
Interpretation:
- Large absolute values → variable strongly influences that LD
- Positive loading → higher values push toward positive LD scores
- Negative loading → higher values push toward negative LD scores
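A complementary view (a hedged sketch using the objects above): correlate the original variables with the LD scores. These "structure correlations" can differ from the raw coefficients in sign and size, because the coefficients are partial effects adjusted for the other variables.
# Correlations between each original variable and the discriminant scores
cor(iris[, 1:4], predictions$x)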
2. Group Centroids
Where are groups located in discriminant space?
# Calculate group means in LD space
library(dplyr)
group_means <- predictions$x %>%
as.data.frame() %>%
mutate(Species = species) %>%
group_by(Species) %>%
summarise(across(everything(), mean))
print(group_means)
3. Posterior Probabilities
How confident are predictions?
# Posterior probabilities
head(predictions$posterior)
#> setosa versicolor virginica
#> [1,] 1.000 0.000 0.000
#> [2,] 1.000 0.000 0.000
#> [3,] 0.998 0.002 0.000
# Plot confidence
library(tidyverse)
post_df <- as.data.frame(predictions$posterior) %>%
mutate(Sample = 1:n(),
True_Species = species,
Predicted = predictions$class)
# Samples with low confidence
uncertain <- post_df %>%
rowwise() %>%
mutate(Max_Prob = max(c_across(setosa:virginica))) %>%
filter(Max_Prob < 0.9)
print(uncertain)
Biological Examples
Example 1: Cancer Classification
library(MASS)
# Simulated gene expression data
set.seed(42)
n_genes <- 20
n_samples <- 60
# Create data: 3 cancer subtypes
gene_expr <- matrix(rnorm(n_genes * n_samples), ncol = n_genes)
colnames(gene_expr) <- paste0("gene", 1:n_genes)
gene_expr[1:20, 1:5] <- gene_expr[1:20, 1:5] + 2       # Subtype 1: genes 1-5 up
gene_expr[21:40, 6:10] <- gene_expr[21:40, 6:10] + 2   # Subtype 2: genes 6-10 up
gene_expr[41:60, 11:15] <- gene_expr[41:60, 11:15] + 2 # Subtype 3: genes 11-15 up
subtypes <- factor(rep(c("TypeA", "TypeB", "TypeC"), each = 20))
# LDA (matrix interface: predictors first, then grouping factor)
lda_cancer <- lda(gene_expr, grouping = subtypes)
# Leave-one-out cross-validation
lda_cv <- lda(gene_expr, grouping = subtypes, CV = TRUE)
table(Predicted = lda_cv$class, Actual = subtypes)
# Identify discriminant genes
loadings <- lda_cancer$scaling
top_genes <- apply(abs(loadings), 2, function(ld) {
names(sort(ld, decreasing = TRUE)[1:5])
})
print(top_genes)
Example 2: Microbiome Disease Classification
library(phyloseq)
library(MASS)
# Example with GlobalPatterns dataset
data(GlobalPatterns)
gp <- GlobalPatterns
# Filter and transform
gp_filtered <- filter_taxa(gp, function(x) sum(x > 3) > (0.1 * length(x)), TRUE)
gp_clr <- microbiome::transform(gp_filtered, "clr")
# Extract data
otu_clr <- t(as(otu_table(gp_clr), "matrix"))
sample_type <- get_variable(gp_clr, "SampleType")
# LDA (expect collinearity warnings: CLR-transformed taxa are linearly dependent;
# with many taxa and few samples, consider reducing dimensions with PCA first)
lda_micro <- lda(otu_clr, grouping = sample_type)
# Predict
pred <- predict(lda_micro)
# Visualize
library(ggplot2)
lda_plot_data <- data.frame(
LD1 = pred$x[, 1],
LD2 = if(ncol(pred$x) > 1) pred$x[, 2] else 0,
SampleType = sample_type
)
ggplot(lda_plot_data, aes(x = LD1, y = LD2, color = SampleType)) +
geom_point(size = 3) +
theme_minimal()
LDA vs. Other Methods
| Method | Type | Uses Groups? | Assumes Equal Cov? | Best For |
|---|---|---|---|---|
| LDA | Linear | Yes | Yes | Balanced groups, normal data |
| QDA | Quadratic | Yes | No | Groups with different spreads |
| PCA | Linear | No | N/A | Exploration, dimension reduction |
| Logistic Regression | Linear | Yes | No | Binary classification, interpret coefs |
| Random Forest | Nonlinear | Yes | N/A | Complex patterns, no assumptions |
When to Use Each
Use LDA when:
- Groups have similar covariances
- Sample size is moderate (n > 5p)
- Data is approximately normal
- You want interpretable discriminants
- Classification is the goal
Use QDA when:
- Groups have different covariances
- More flexible boundaries are needed
- There is enough data per group (n > 10p)
- LDA's equal-covariance assumption is violated
Use PCA when:
- There are no group labels
- You want to explore variance structure
- Preprocessing for other methods
- Dimension reduction is needed
Use something else when:
- Very small sample size → regularized methods
- Nonlinear patterns → kernel methods, random forests
- Binary outcome → logistic regression
- Severe assumption violations → non-parametric methods
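To compare LDA and QDA empirically, one option is to run both on identical cross-validation folds with caret (a sketch assuming the iris data; caret's "lda" and "qda" methods wrap the MASS functions):
library(caret)
library(MASS)
set.seed(123)
# Use the same folds for both methods so the comparison is fair
idx <- createFolds(iris$Species, k = 10, returnTrain = TRUE)
ctrl <- trainControl(method = "cv", index = idx)
fit_lda <- train(Species ~ ., data = iris, method = "lda", trControl = ctrl)
fit_qda <- train(Species ~ ., data = iris, method = "qda", trControl = ctrl)
summary(resamples(list(LDA = fit_lda, QDA = fit_qda)))   # fold-wise accuracy and kappa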
Common Pitfalls
Mistakes to Avoid
1. Not checking assumptions
# BAD: Just fit without checking
lda(groups ~ ., data = data)
# GOOD: Check first
boxM(data, groups) # Homogeneity
# Visual check for normality
2. Too many variables for sample size
# BAD: p > n (more variables than samples)
# LDA will overfit!
# GOOD: Reduce dimensions first
pca <- prcomp(data, scale. = TRUE)
pca_scores <- pca$x[, 1:10]   # keep the first 10 PCs
lda(pca_scores, grouping = groups)
3. Interpreting coefficients without considering variable scales
# Note: scaling does not change LDA's predictions, but coefficients of
# variables in different units (cm vs kg) are not directly comparable
lda(groups ~ height_cm + weight_kg, data = data)
# GOOD: Standardize first if you want to compare coefficient magnitudes
data_scaled <- scale(data)
lda(groups ~ ., data = data.frame(data_scaled, groups))
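A quick check of this point on iris (a sketch; object names are illustrative): standardizing the predictors leaves the predicted classes unchanged, while the coefficients are rescaled.
raw_fit    <- lda(Species ~ ., data = iris)
scaled_fit <- lda(Species ~ ., data = data.frame(scale(iris[, 1:4]), Species = iris$Species))
all(predict(raw_fit)$class == predict(scaled_fit)$class)   # expect TRUE: same predictions
raw_fit$scaling      # coefficients on the original scales...
scaled_fit$scaling   # ...and after standardizing, now comparable across variables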
4. Using training accuracy only
# BAD: Overly optimistic
pred <- predict(lda_model)
mean(pred$class == groups) # ~100%!
# GOOD: Use cross-validation
lda_cv <- lda(groups ~ ., data = data, CV = TRUE)
mean(lda_cv$class == groups) # Realistic estimate
5. Unbalanced groups
# BAD: Group A (n=100), Group B (n=10)
# LDA will favor Group A
# GOOD: Adjust priors or balance data
lda(groups ~ ., data = data,
    prior = c(0.5, 0.5)) # Equal priors (given in the order of the factor levels)
Practical Workflow
Complete LDA analysis workflow:
library(MASS)
library(caret)
# 1. Prepare data
data <- your_data
groups <- your_groups
# 2. Check assumptions
# Normality (visual)
pairs(data, col = as.numeric(groups))
# Homogeneity of variance
library(biotools)
boxM(data, groups)
# Multicollinearity
cor(data)
# 3. Split data
set.seed(123)
train_idx <- createDataPartition(groups, p = 0.7, list = FALSE)
train_data <- data[train_idx, ]
train_groups <- groups[train_idx]
test_data <- data[-train_idx, ]
test_groups <- groups[-train_idx]
# 4. Fit LDA
lda_model <- lda(train_groups ~ ., data = train_data)
# 5. Cross-validate
cv_results <- lda(train_groups ~ ., data = train_data, CV = TRUE)
cv_accuracy <- mean(cv_results$class == train_groups)
cat("CV Accuracy:", cv_accuracy, "\n")
# 6. Test on holdout set
test_pred <- predict(lda_model, newdata = test_data)
test_accuracy <- mean(test_pred$class == test_groups)
cat("Test Accuracy:", test_accuracy, "\n")
# 7. Confusion matrix
confusionMatrix(test_pred$class, test_groups)
# 8. Visualize
lda_scores <- data.frame(
LD1 = test_pred$x[, 1],
LD2 = if(ncol(test_pred$x) > 1) test_pred$x[, 2] else 0,
Group = test_groups,
Predicted = test_pred$class
)
library(ggplot2)
ggplot(lda_scores, aes(x = LD1, y = LD2, color = Group, shape = Predicted)) +
geom_point(size = 3) +
theme_minimal()
# 9. Interpret loadings
loadings <- lda_model$scaling
print(loadings)
Exercises
Practice Problems
Exercise 1: Basic LDA
Use the iris dataset to classify species:
- Perform 10-fold cross-validation
- Calculate accuracy
- Which species is hardest to classify?
- Plot samples in LD1-LD2 space
Solution
library(MASS)
library(caret)
# 1. 10-fold CV
train_control <- trainControl(method = "cv", number = 10)
cv_model <- train(Species ~ .,
data = iris,
method = "lda",
trControl = train_control)
# 2. Accuracy
print(cv_model)
# 3. Confusion matrix
lda_cv <- lda(Species ~ ., data = iris, CV = TRUE)
table(lda_cv$class, iris$Species)
# Virginica and Versicolor overlap most
# 4. Plot
lda_model <- lda(Species ~ ., data = iris)
pred <- predict(lda_model)
plot_data <- data.frame(
LD1 = pred$x[, 1],
LD2 = pred$x[, 2],
Species = iris$Species
)
library(ggplot2)
ggplot(plot_data, aes(x = LD1, y = LD2, color = Species)) +
geom_point(size = 3) +
stat_ellipse()
Exercise 2: Comparing LDA and PCA
Using iris data:
- Perform both LDA and PCA
- Plot samples in both spaces
- Which method separates species better?
- Why?
Solution
# LDA
lda_model <- lda(Species ~ ., data = iris)
lda_pred <- predict(lda_model)
# PCA
pca_model <- prcomp(iris[, 1:4], scale. = TRUE)
# Plot LDA
plot1 <- data.frame(
LD1 = lda_pred$x[, 1],
LD2 = lda_pred$x[, 2],
Species = iris$Species
)
p1 <- ggplot(plot1, aes(LD1, LD2, color = Species)) +
geom_point() + ggtitle("LDA") + theme_minimal()
# Plot PCA
plot2 <- data.frame(
PC1 = pca_model$x[, 1],
PC2 = pca_model$x[, 2],
Species = iris$Species
)
p2 <- ggplot(plot2, aes(PC1, PC2, color = Species)) +
geom_point() + ggtitle("PCA") + theme_minimal()
library(gridExtra)
grid.arrange(p1, p2, ncol = 2)
# LDA separates better because it uses group information
Key Takeaways
Remember These Concepts
✓ LDA is supervised - uses group labels to find discriminant directions
✓ Maximizes separation between groups while minimizing within-group variance
✓ Assumes normality and homogeneity of covariance matrices
✓ Number of LDs = min(p, g-1), where p is the number of predictors and g the number of groups (see the check after this list)
✓ Always validate with cross-validation or independent test set
✓ Scale your data unless variables are already on similar scales
✓ Use QDA when covariances differ between groups
✓ Interpret loadings to understand which variables drive separation
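A one-line check of the min(p, g-1) rule on iris (4 predictors, 3 species, so at most 2 discriminants):
library(MASS)
ncol(lda(Species ~ ., data = iris)$scaling)   # 2 = min(4, 3 - 1)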
Further Resources
Textbooks
- Hastie, T. et al. (2009). The Elements of Statistical Learning. Chapter 4.
- Venables, W.N. & Ripley, B.D. (2002). Modern Applied Statistics with S. Chapter 12.
R Packages
- MASS - lda() and qda() functions
- caret - Unified interface with cross-validation
- klaR - Classification visualization
When to Use What
- Explore patterns without groups → PCA
- Find natural groups without labels → Clustering
- Classify with labels → LDA (or QDA, logistic regression)