
Linear Discriminant Analysis (LDA)

Learning Objectives

By the end of this section, you should be able to:

  • Understand the difference between supervised (LDA) and unsupervised (PCA) methods
  • Explain how LDA maximizes group separation
  • Apply LDA for classification and visualization
  • Interpret discriminant functions and loadings
  • Validate LDA models with cross-validation

What is LDA?

Linear Discriminant Analysis (LDA) finds linear combinations of features that best separate predefined classes by maximizing between-class variance and minimizing within-class variance. It is supervised because it uses known class labels during training to learn the separation.

Key Distinction

  • PCA (unsupervised): Finds directions of maximum variance, ignoring group labels
  • LDA (supervised): Finds directions of maximum group separation, using group labels

When to Use LDA

  • Classification - Predicting group membership for new samples
  • Visualization - Plotting samples in space that maximizes group separation
  • Feature selection - Identifying variables that distinguish groups
  • Dimension reduction - With class information included

Common applications:

  • Classifying cancer subtypes from gene expression
  • Predicting disease state from microbiome profiles
  • Distinguishing plant species from morphological traits
  • Separating experimental conditions in metabolomics

LDA vs. PCA: A Visual Comparison

Consider two groups with the same variance structure:

# Generate example data
set.seed(42)
group1 <- cbind(rnorm(50, 0, 1), rnorm(50, 0, 0.3))
group2 <- cbind(rnorm(50, 1.5, 1), rnorm(50, 0.3, 0.3))
data <- rbind(group1, group2)
groups <- factor(rep(c("A", "B"), each = 50))

# PCA - finds the direction of maximum overall variance, ignoring group labels
pca <- prcomp(data, scale. = FALSE)

# LDA - finds the direction that best separates the two groups
library(MASS)
sim_df <- data.frame(x1 = data[, 1], x2 = data[, 2], groups = groups)
lda_model <- lda(groups ~ x1 + x2, data = sim_df)
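
One way to see the difference is to project the simulated points onto PC1 and LD1 and compare how much the groups overlap along each axis. A brief sketch, reusing the `data`, `groups`, `pca`, `sim_df`, and `lda_model` objects defined above (the density plots are just one way to look at it):

library(ggplot2)

# Collect each sample's score on PC1 (from PCA) and LD1 (from LDA)
proj <- data.frame(
  PC1   = pca$x[, 1],
  LD1   = predict(lda_model, newdata = sim_df)$x[, 1],
  Group = groups
)

ggplot(proj, aes(x = PC1, fill = Group)) + geom_density(alpha = 0.5)  # groups tend to overlap more
ggplot(proj, aes(x = LD1, fill = Group)) + geom_density(alpha = 0.5)  # groups typically separate better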

Key difference:

| Method | Goal                | Uses groups? | Output                  |
|--------|---------------------|--------------|-------------------------|
| PCA    | Maximize variance   | No           | Principal components    |
| LDA    | Maximize separation | Yes          | Discriminant functions  |

The Mathematical Idea

LDA finds directions that:

  1. Maximize between-group variance (groups far apart)
  2. Minimize within-group variance (groups compact)
$$ \text{LD} = \arg\max \frac{\text{Between-group variance}}{\text{Within-group variance}} $$

Fisher's Linear Discriminant

For two groups, LDA finds the linear combination w that maximizes:

$$ J(w) = \frac{w^T S_B w}{w^T S_W w} $$

Where:

  • S_B = between-group scatter matrix
  • S_W = within-group scatter matrix
  • w = discriminant weights (loadings)

Intuition

Think of LDA as finding the best "projection line" where group means are far apart, but points within each group are close together.
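For two groups, Fisher's criterion has a closed-form solution: the discriminant direction is proportional to S_W⁻¹(μ₁ − μ₂). A minimal sketch of that calculation, reusing the simulated `data` and `groups` from the PCA/LDA comparison above:

# Split the simulated data by group
X1 <- data[groups == "A", ]
X2 <- data[groups == "B", ]

# Within-group scatter (pooled group covariances; proportional to S_W)
S_W <- cov(X1) + cov(X2)

# Fisher's direction: w proportional to S_W^{-1} (mu_1 - mu_2)
w <- solve(S_W) %*% (colMeans(X1) - colMeans(X2))
w / sqrt(sum(w^2))   # unit-length direction; matches lda()'s LD1 up to sign and scale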


LDA Assumptions

Critical Assumptions

LDA makes stronger assumptions than PCA:

  1. Multivariate normality - Each group follows a multivariate normal distribution
  2. Homogeneity of variance - All groups have the same covariance matrix
  3. Independence - Observations are independent
  4. No multicollinearity - Predictor variables not highly correlated

Checking Assumptions

# 1. Visual check of normality
library(ggplot2)
library(reshape2)

# Check each variable by group (reshape to long format, keeping group as the id variable)
data_long <- melt(data.frame(data, group = groups), id.vars = "group")
ggplot(data_long, aes(x = value, fill = group)) +
  geom_density(alpha = 0.5) +
  facet_wrap(~variable, scales = "free")

# 2. Test homogeneity of covariance (Box's M test)
library(biotools)
boxM(data, groups)
# A non-significant p-value gives no evidence against equal covariance matrices

# 3. Check multicollinearity
cor(data)
# High correlations (|r| > 0.8) indicate problems

What if Assumptions Are Violated?

  • Transform data (log, sqrt, Box-Cox)
  • Use Quadratic Discriminant Analysis (QDA), which allows a different covariance matrix per group (see the sketch below)
  • Apply regularized LDA, especially when p > n
  • Reduce features with PCA first
  • Try logistic regression, random forests, or other robust/non-parametric methods
  • Use cross-validation carefully in all cases
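
For example, QDA is a near drop-in replacement available in MASS alongside lda(). A minimal sketch using the built-in iris data (introduced more fully in the next section):

library(MASS)

# QDA fits a separate covariance matrix for each group
qda_model <- qda(Species ~ ., data = iris)

# Leave-one-out cross-validated accuracy, for comparison with LDA
qda_cv <- qda(Species ~ ., data = iris, CV = TRUE)
mean(qda_cv$class == iris$Species)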

LDA in R: Complete Example

Basic Implementation

library(MASS)

# Use iris dataset
data(iris)
df <- iris[, 1:4]  # Predictors
species <- iris$Species  # Groups

# 1. Fit LDA model (Species as the grouping variable, all four measurements as predictors)
lda_model <- lda(Species ~ ., data = iris)

# 2. Examine output
print(lda_model)
#> Prior probabilities of groups:
#>     setosa versicolor  virginica 
#>      0.333      0.333      0.333
#> 
#> Group means: (shows centroids)
#> 
#> Coefficients of linear discriminants:
#> (shows loadings for each LD)

# 3. See how well groups separate
plot(lda_model)

Key Components

# Prior probabilities (can be set manually)
lda_model$prior

# Group means (centroids in original space)
lda_model$means

# Scaling (discriminant coefficients/loadings)
lda_model$scaling
#>                    LD1        LD2
#> Sepal.Length  0.8293776  0.02410215
#> Sepal.Width   1.5344731  2.16452123
#> Petal.Length -2.2012117 -0.93192121
#> Petal.Width  -2.8104603  2.83918785

# Proportion of trace (variance explained)
lda_model$svd^2 / sum(lda_model$svd^2)
#> LD1: 99.1%  LD2: 0.9%

Making Predictions

# Predict on training data
predictions <- predict(lda_model)

# Components:
# - $class: predicted group
# - $posterior: probability of belonging to each group
# - $x: coordinates in discriminant space

# Confusion matrix
table(Predicted = predictions$class, Actual = species)

# Classification accuracy
mean(predictions$class == species)
#> [1] 0.98  # 98% accuracy
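
To classify a genuinely new sample, pass its measurements to predict() via newdata. A small sketch; the measurements below are made up purely for illustration:

# Hypothetical new flower (invented values)
new_flower <- data.frame(Sepal.Length = 6.1, Sepal.Width = 2.9,
                         Petal.Length = 4.7, Petal.Width = 1.4)

predict(lda_model, newdata = new_flower)$class      # predicted species
predict(lda_model, newdata = new_flower)$posterior  # class probabilities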

Visualization

library(ggplot2)

# Extract LD scores
lda_data <- data.frame(
  LD1 = predictions$x[, 1],
  LD2 = predictions$x[, 2],
  Species = species
)

# Plot in discriminant space
ggplot(lda_data, aes(x = LD1, y = LD2, color = Species)) +
  geom_point(size = 3, alpha = 0.6) +
  stat_ellipse(level = 0.95) +
  labs(title = "LDA: Iris Species Separation",
       x = "LD1 (99.1% of separation)",
       y = "LD2 (0.9% of separation)") +
  theme_minimal()

Cross-Validation

Never Test on Training Data!

Always validate with independent data or cross-validation.

Leave-One-Out Cross-Validation (LOOCV)

# Built-in LOOCV in lda()
lda_cv <- lda(Species ~ ., data = iris, CV = TRUE)

# Cross-validated predictions
table(Predicted = lda_cv$class, Actual = species)

# CV accuracy
mean(lda_cv$class == species)
#> [1] 0.98

k-Fold Cross-Validation

library(caret)

# Set up 10-fold CV
train_control <- trainControl(method = "cv", number = 10)

# Train with CV
set.seed(123)
cv_model <- train(Species ~ ., 
                  data = iris,
                  method = "lda",
                  trControl = train_control)

# Results
print(cv_model)
#> Accuracy: 0.98 (averaged over 10 folds)

Train-Test Split

# Split data
set.seed(123)
train_idx <- sample(1:nrow(iris), 0.7 * nrow(iris))
train_data <- iris[train_idx, ]
test_data <- iris[-train_idx, ]

# Train on training set
lda_train <- lda(Species ~ ., data = train_data)

# Predict on test set
test_pred <- predict(lda_train, newdata = test_data)

# Test accuracy
mean(test_pred$class == test_data$Species)

Interpretation

1. Discriminant Loadings

Which variables drive separation?

# Get loadings
loadings <- lda_model$scaling

# Visualize
library(reshape2)
loadings_long <- melt(loadings)
names(loadings_long) <- c("Variable", "LD", "Loading")

ggplot(loadings_long, aes(x = Variable, y = Loading, fill = LD)) +
  geom_bar(stat = "identity", position = "dodge") +
  coord_flip() +
  labs(title = "LDA Loadings") +
  theme_minimal()

Interpretation:

  • Large absolute values → variable strongly influences that LD
  • Positive loading → higher values push toward positive LD scores
  • Negative loading → higher values push toward negative LD scores

2. Group Centroids

Where are groups located in discriminant space?

# Calculate group means in LD space (uses dplyr for the pipe and summarise)
library(dplyr)
group_means <- predictions$x %>%
  as.data.frame() %>%
  mutate(Species = species) %>%
  group_by(Species) %>%
  summarise(across(everything(), mean))

print(group_means)

3. Posterior Probabilities

How confident are predictions?

# Posterior probabilities
head(predictions$posterior)
#>      setosa versicolor virginica
#> [1,]  1.000      0.000     0.000
#> [2,]  1.000      0.000     0.000
#> [3,]  0.998      0.002     0.000

# Plot confidence
library(tidyverse)
post_df <- as.data.frame(predictions$posterior) %>%
  mutate(Sample = 1:n(),
         True_Species = species,
         Predicted = predictions$class)

# Samples with low confidence
uncertain <- post_df %>%
  rowwise() %>%
  mutate(Max_Prob = max(c_across(setosa:virginica))) %>%
  filter(Max_Prob < 0.9)

print(uncertain)

Biological Examples

Example 1: Cancer Classification

library(MASS)

# Simulated gene expression data
set.seed(42)
n_genes <- 20
n_samples <- 60

# Create data: 3 cancer subtypes
gene_expr <- matrix(rnorm(n_genes * n_samples), ncol = n_genes)
gene_expr[1:20, 1:5] <- gene_expr[1:20, 1:5] + 2      # Subtype 1
gene_expr[21:40, 6:10] <- gene_expr[21:40, 6:10] + 2  # Subtype 2
gene_expr[41:60, 11:15] <- gene_expr[41:60, 11:15] + 2 # Subtype 3

subtypes <- factor(rep(c("TypeA", "TypeB", "TypeC"), each = 20))

# LDA
lda_cancer <- lda(subtypes ~ gene_expr)

# Cross-validation
lda_cv <- lda(subtypes ~ gene_expr, CV = TRUE)
table(Predicted = lda_cv$class, Actual = subtypes)

# Identify discriminant genes
loadings <- lda_cancer$scaling
top_genes <- apply(abs(loadings), 2, function(ld) {
  names(sort(ld, decreasing = TRUE)[1:5])
})
print(top_genes)

Example 2: Microbiome Disease Classification

library(phyloseq)
library(MASS)

# Example with GlobalPatterns dataset
data(GlobalPatterns)
gp <- GlobalPatterns

# Filter and transform
gp_filtered <- filter_taxa(gp, function(x) sum(x > 3) > (0.1 * length(x)), TRUE)
gp_clr <- microbiome::transform(gp_filtered, "clr")

# Extract data
otu_clr <- t(as(otu_table(gp_clr), "matrix"))
sample_type <- get_variable(gp_clr, "SampleType")

# LDA (note: with many taxa relative to samples, LDA can be unstable;
# consider reducing to a few principal components first)
lda_micro <- lda(sample_type ~ otu_clr)

# Predict
pred <- predict(lda_micro)

# Visualize
library(ggplot2)
lda_plot_data <- data.frame(
  LD1 = pred$x[, 1],
  LD2 = if(ncol(pred$x) > 1) pred$x[, 2] else 0,
  SampleType = sample_type
)

ggplot(lda_plot_data, aes(x = LD1, y = LD2, color = SampleType)) +
  geom_point(size = 3) +
  theme_minimal()

LDA vs. Other Methods

| Method              | Type      | Uses groups? | Assumes equal cov? | Best for                                      |
|---------------------|-----------|--------------|--------------------|-----------------------------------------------|
| LDA                 | Linear    | Yes          | Yes                | Balanced groups, approximately normal data    |
| QDA                 | Quadratic | Yes          | No                 | Groups with different spreads                 |
| PCA                 | Linear    | No           | N/A                | Exploration, dimension reduction              |
| Logistic regression | Linear    | Yes          | No                 | Binary classification, interpretable coefficients |
| Random forest       | Nonlinear | Yes          | N/A                | Complex patterns, few distributional assumptions |

When to Use Each

Use LDA when:

  • Groups have similar covariances
  • Sample size is moderate (n > 5p)
  • Data is approximately normal
  • You want interpretable discriminants and classification is the goal

Use QDA when:

  • Groups have different covariances
  • More flexible boundaries are needed
  • There is enough data per group (n > 10p)
  • LDA's equal-covariance assumption is violated

Use PCA when:

  • There are no group labels
  • You want to explore variance structure
  • You need dimension reduction or preprocessing for other methods

Consider alternatives when:

  • Very small sample size → regularized methods
  • Nonlinear patterns → kernel methods, random forests
  • Binary outcome → logistic regression
  • Severe assumption violations → non-parametric methods (see the comparison sketch below)

Common Pitfalls

Mistakes to Avoid

1. Not checking assumptions

# BAD: Just fit without checking
lda(groups ~ ., data = data)

# GOOD: Check first
boxM(data, groups)  # Homogeneity
# Visual check for normality

2. Too many variables for sample size

# BAD: p > n (more variables than samples)
# LDA will overfit!

# GOOD: Reduce dimensions first
pca <- prcomp(data, scale. = TRUE)
pca_data <- pca$x[, 1:10]  # Keep 10 PCs
lda(groups ~ pca_data)

3. Not scaling data

# BAD: Variables on different scales
lda(groups ~ height_cm + weight_kg, data = data)

# GOOD: Scale first so loadings are comparable across variables
# (scaling does not change LDA's class predictions, only the coefficients)
data_scaled <- scale(data)
lda(groups ~ ., data = data.frame(data_scaled, groups))

4. Using training accuracy only

# BAD: Overly optimistic
pred <- predict(lda_model)
mean(pred$class == groups)  # ~100%!

# GOOD: Use cross-validation
lda_cv <- lda(groups ~ ., data = data, CV = TRUE)
mean(lda_cv$class == groups)  # Realistic estimate

5. Unbalanced groups

# BAD: Group A (n=100), Group B (n=10)
# LDA will favor Group A

# GOOD: Adjust priors or balance data
lda(groups ~ ., data = data, 
    prior = c(0.5, 0.5))  # Equal priors


Practical Workflow

Complete LDA analysis workflow:

library(MASS)
library(caret)

# 1. Prepare data
data <- your_data
groups <- your_groups

# 2. Check assumptions
# Normality (visual)
pairs(data, col = groups)

# Homogeneity of variance
library(biotools)
boxM(data, groups)

# Multicollinearity
cor(data)

# 3. Split data
set.seed(123)
train_idx <- createDataPartition(groups, p = 0.7, list = FALSE)
train_data <- data[train_idx, ]
train_groups <- groups[train_idx]
test_data <- data[-train_idx, ]
test_groups <- groups[-train_idx]

# 4. Fit LDA
lda_model <- lda(train_groups ~ ., data = train_data)

# 5. Cross-validate
cv_results <- lda(train_groups ~ ., data = train_data, CV = TRUE)
cv_accuracy <- mean(cv_results$class == train_groups)
cat("CV Accuracy:", cv_accuracy, "\n")

# 6. Test on holdout set
test_pred <- predict(lda_model, newdata = test_data)
test_accuracy <- mean(test_pred$class == test_groups)
cat("Test Accuracy:", test_accuracy, "\n")

# 7. Confusion matrix
confusionMatrix(test_pred$class, test_groups)

# 8. Visualize
lda_scores <- data.frame(
  LD1 = test_pred$x[, 1],
  LD2 = if(ncol(test_pred$x) > 1) test_pred$x[, 2] else 0,
  Group = test_groups,
  Predicted = test_pred$class
)

library(ggplot2)
ggplot(lda_scores, aes(x = LD1, y = LD2, color = Group, shape = Predicted)) +
  geom_point(size = 3) +
  theme_minimal()

# 9. Interpret loadings
loadings <- lda_model$scaling
print(loadings)

Exercises

Practice Problems

Exercise 1: Basic LDA

Use the iris dataset to classify species:

  1. Perform 10-fold cross-validation
  2. Calculate accuracy
  3. Which species is hardest to classify?
  4. Plot samples in LD1-LD2 space
Solution
library(MASS)
library(caret)

# 1. 10-fold CV
train_control <- trainControl(method = "cv", number = 10)
cv_model <- train(Species ~ ., 
                 data = iris,
                 method = "lda",
                 trControl = train_control)

# 2. Accuracy
print(cv_model)

# 3. Confusion matrix
lda_cv <- lda(Species ~ ., data = iris, CV = TRUE)
table(lda_cv$class, iris$Species)
# Virginica and Versicolor overlap most

# 4. Plot
lda_model <- lda(Species ~ ., data = iris)
pred <- predict(lda_model)

plot_data <- data.frame(
  LD1 = pred$x[, 1],
  LD2 = pred$x[, 2],
  Species = iris$Species
)

library(ggplot2)
ggplot(plot_data, aes(x = LD1, y = LD2, color = Species)) +
  geom_point(size = 3) +
  stat_ellipse()

Exercise 2: Comparing LDA and PCA

Using iris data:

  1. Perform both LDA and PCA
  2. Plot samples in both spaces
  3. Which method separates species better?
  4. Why?
Solution
# LDA
lda_model <- lda(Species ~ ., data = iris)
lda_pred <- predict(lda_model)

# PCA
pca_model <- prcomp(iris[, 1:4], scale. = TRUE)

# Plot LDA
plot1 <- data.frame(
  LD1 = lda_pred$x[, 1],
  LD2 = lda_pred$x[, 2],
  Species = iris$Species
)

p1 <- ggplot(plot1, aes(LD1, LD2, color = Species)) +
  geom_point() + ggtitle("LDA") + theme_minimal()

# Plot PCA
plot2 <- data.frame(
  PC1 = pca_model$x[, 1],
  PC2 = pca_model$x[, 2],
  Species = iris$Species
)

p2 <- ggplot(plot2, aes(PC1, PC2, color = Species)) +
  geom_point() + ggtitle("PCA") + theme_minimal()

library(gridExtra)
grid.arrange(p1, p2, ncol = 2)

# LDA separates better because it uses group information

Key Takeaways

Remember These Concepts

✓ LDA is supervised - uses group labels to find discriminant directions
✓ Maximizes separation between groups while minimizing within-group variance
✓ Assumes normality and homogeneity of covariance matrices
✓ Number of LDs = min(p, g - 1), where p is the number of predictors and g the number of groups
✓ Always validate with cross-validation or an independent test set
✓ Scale your data unless variables are already on similar scales
✓ Use QDA when covariances differ between groups
✓ Interpret loadings to understand which variables drive separation


Further Resources

Textbooks

  • Hastie, T. et al. (2009). The Elements of Statistical Learning. Chapter 4.
  • Venables, W.N. & Ripley, B.D. (2002). Modern Applied Statistics with S. Chapter 12.

R Packages

  • MASS - lda() and qda() functions
  • caret - Unified interface with CV
  • klaR - Classification visualization

When to Use What

  • Explore patterns without groups → PCA
  • Find natural groups without labels → Clustering
  • Classify with labels → LDA (or QDA, logistic regression)