Introduction
Multi-head attention is an extension of scaled dot-product attention that allows the model to jointly attend to information from different representation subspaces. Instead of performing a single attention function, the model learns multiple "heads" that attend to different aspects of the sequence, and their outputs are combined.
Core Concept
Instead of doing one attention computation with d_model-dimensional keys/values/queries, we project them into lower dimensions and attend h times in parallel:
MultiHead(Q, K, V) = Concat(head₁, head₂, ..., headₕ) · W^O
where headᵢ = Attention(Q·Wᵢ^Q, K·Wᵢ^K, V·Wᵢ^V)
Multi-Head Attention Architecture:
Input X (d_model)
│
├──▶ Linear Q ──▶ Split into h heads ──▶ Q₁, Q₂, ..., Qₕ (each dₖ)
│
├──▶ Linear K ──▶ Split into h heads ──▶ K₁, K₂, ..., Kₕ (each dₖ)
│
└──▶ Linear V ──▶ Split into h heads ──▶ V₁, V₂, ..., Vₕ (each dᵥ)
Then for each head i:
headᵢ = Attention(Qᵢ, Kᵢ, Vᵢ) ∈ ℝ^{seq_len × dᵥ}
Concat all heads: [head₁, head₂, ..., headₕ] ∈ ℝ^{seq_len × h·dᵥ}
│
└──▶ Linear W^O ──▶ Output ∈ ℝ^{seq_len × d_model}
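A minimal NumPy sketch of the computation in the diagram above (the function and variable names are illustrative, not taken from any particular library):

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, h):
    # X: [seq_len, d_model]; W_Q/W_K/W_V/W_O: [d_model, d_model]
    seq_len, d_model = X.shape
    d_k = d_model // h

    def split_heads(M):
        # [seq_len, d_model] -> [seq_len, h, d_k] -> [h, seq_len, d_k]
        return M.reshape(seq_len, h, d_k).transpose(1, 0, 2)

    Q, K, V = split_heads(X @ W_Q), split_heads(X @ W_K), split_heads(X @ W_V)

    # Scaled dot-product attention for all heads in parallel
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)   # [h, seq_len, seq_len]
    heads = softmax(scores) @ V                        # [h, seq_len, d_k]

    # Concatenate the heads and apply the output projection W^O
    concat = heads.transpose(1, 0, 2).reshape(seq_len, h * d_k)
    return concat @ W_O                                # [seq_len, d_model]

# Example with the original Transformer sizes
rng = np.random.default_rng(0)
d_model, h, seq_len = 512, 8, 10
X = rng.normal(size=(seq_len, d_model))
W_Q, W_K, W_V, W_O = (rng.normal(size=(d_model, d_model)) for _ in range(4))
print(multi_head_attention(X, W_Q, W_K, W_V, W_O, h).shape)   # (10, 512)
```

Batching, masking, and bias terms are omitted to keep the sketch short.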
Hyperparameters
- h: Number of attention heads (typically 8, 12, 16)
- dₖ: Dimension per head = d_model / h
- dᵥ: Dimension per head = d_model / h (often same as dₖ)
d_model = 512, h = 8 → dₖ = dᵥ = 64
d_model = 768, h = 12 → dₖ = dᵥ = 64
d_model = 1024, h = 16 → dₖ = dᵥ = 64
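These per-head sizes follow directly from d_model and h, as a quick illustrative check shows:

```python
for d_model, h in [(512, 8), (768, 12), (1024, 16)]:
    assert d_model % h == 0, "d_model must be divisible by the number of heads"
    d_k = d_v = d_model // h
    print(d_model, h, d_k)   # 512 8 64, 768 12 64, 1024 16 64
```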
Why Multiple Heads?
Different attention heads can learn to focus on different types of relationships:
- Syntactic relationships: Subject-verb agreement
- Semantic relationships: Word meaning similarity
- Coreference: Pronoun resolution to noun
- Long-range dependencies: Clause boundaries
- Positional patterns: Sequential vs position-independent
Detailed Algorithm
Step 1: Linear Projections
Q = X · W^Q ∈ ℝ^{seq_len × d_model}
K = X · W^K ∈ ℝ^{seq_len × d_model}
V = X · W^V ∈ ℝ^{seq_len × d_model}
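Step 1 in isolation, sketched in NumPy (the random inputs are placeholders and the names are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model = 10, 512
X = rng.normal(size=(seq_len, d_model))

# One full-width projection per role; splitting into heads happens in Step 2
W_Q = rng.normal(size=(d_model, d_model))
W_K = rng.normal(size=(d_model, d_model))
W_V = rng.normal(size=(d_model, d_model))

Q = X @ W_Q   # [seq_len, d_model]
K = X @ W_K   # [seq_len, d_model]
V = X @ W_V   # [seq_len, d_model]
```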
Step 2: Reshape for Heads
Q reshaped: [seq_len, h, dₖ] → [h, seq_len, dₖ]
K reshaped: [seq_len, h, dₖ] → [h, seq_len, dₖ]
V reshaped: [seq_len, h, dᵥ] → [h, seq_len, dᵥ]
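Step 2 is a reshape followed by a transpose; a minimal sketch (exact tensor layout conventions vary between libraries):

```python
import numpy as np

seq_len, d_model, h = 10, 512, 8
d_k = d_model // h
Q = np.zeros((seq_len, d_model))   # stand-in for the projected Q from Step 1

# [seq_len, d_model] -> [seq_len, h, d_k] -> [h, seq_len, d_k]
Q_heads = Q.reshape(seq_len, h, d_k).transpose(1, 0, 2)
print(Q_heads.shape)   # (8, 10, 64)
```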
Step 3: Parallel Attention
For each head i:
headᵢ = Attention(Qᵢ, Kᵢ, Vᵢ) = softmax(QᵢKᵢᵀ/√dₖ) · Vᵢ
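Step 3 for a single head, sketched in NumPy (the softmax helper and placeholder arrays are illustrative):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
seq_len, d_k = 10, 64
Q_i = rng.normal(size=(seq_len, d_k))
K_i = rng.normal(size=(seq_len, d_k))
V_i = rng.normal(size=(seq_len, d_k))

scores = Q_i @ K_i.T / np.sqrt(d_k)   # [seq_len, seq_len]
head_i = softmax(scores) @ V_i        # [seq_len, d_k]
```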
Step 4: Concatenate and Project
heads = [head₁, head₂, ..., headₕ] ∈ ℝ^{seq_len × h·dᵥ}
output = heads · W^O ∈ ℝ^{seq_len × d_model}
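Step 4 in NumPy, assuming the per-head outputs from Step 3 are stacked as [h, seq_len, dᵥ] (placeholder arrays again):

```python
import numpy as np

rng = np.random.default_rng(0)
h, seq_len, d_v, d_model = 8, 10, 64, 512
heads = rng.normal(size=(h, seq_len, d_v))   # outputs of Step 3 for all heads
W_O = rng.normal(size=(h * d_v, d_model))

# Concatenate along the feature dimension, then project back to d_model
concat = heads.transpose(1, 0, 2).reshape(seq_len, h * d_v)   # [seq_len, h·d_v]
output = concat @ W_O                                         # [seq_len, d_model]
```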
Parameter Count
Per head: Wᵢ^Q ∈ ℝ^{d_model × dₖ}, Wᵢ^K ∈ ℝ^{d_model × dₖ}, Wᵢ^V ∈ ℝ^{d_model × dᵥ}
All heads combined: W^Q, W^K, W^V ∈ ℝ^{d_model × d_model}
Output projection: W^O ∈ ℝ^{h·dᵥ × d_model} = ℝ^{d_model × d_model}
Total: 4 × d_model² (same as single-head with same dimensions)
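A quick sanity check of that count, ignoring bias terms as above:

```python
d_model, h = 512, 8
d_k = d_v = d_model // h

per_head = d_model * d_k + d_model * d_k + d_model * d_v   # W_i^Q, W_i^K, W_i^V
projections = h * per_head                                  # = 3 · d_model²
output_proj = (h * d_v) * d_model                           # W^O = d_model²
total = projections + output_proj

print(total, 4 * d_model ** 2)   # 1048576 1048576
```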
Key Properties
- Same computational cost: Despite multiple heads, total computation is similar to single attention with d_model dimensions
- No additional learnable parameters: The parameter count matches single-head attention operating at the full d_model dimension
- Joint attention: Allows capturing different types of relationships simultaneously
Variations
- Original Transformer: d_model=512, h=8, dₖ=dᵥ=64
- BERT Base: d_model=768, h=12, dₖ=dᵥ=64
- BERT Large: d_model=1024, h=16, dₖ=dᵥ=64
- GPT-2 (small/medium/large/XL): d_model=768/1024/1280/1600, h=12/16/20/25, dₖ=64 (always)
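All of these configurations keep dₖ = d_model / h = 64, which can be checked directly (the dictionary below simply restates the list above):

```python
configs = {
    "Transformer (original)": (512, 8),
    "BERT Base": (768, 12),
    "BERT Large": (1024, 16),
    "GPT-2 small": (768, 12),
    "GPT-2 medium": (1024, 16),
    "GPT-2 large": (1280, 20),
    "GPT-2 XL": (1600, 25),
}
for name, (d_model, h) in configs.items():
    print(f"{name}: d_k = {d_model // h}")   # 64 for every entry
```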