12. Multi-Head Attention

Introduction

Multi-head attention is an extension of scaled dot-product attention that allows the model to jointly attend to information from different representation subspaces. Instead of performing a single attention function, the model learns multiple "heads" that attend to different aspects of the sequence, and their outputs are combined.

Core Concept

Instead of doing one attention computation with d_model-dimensional keys/values/queries, we project them into lower dimensions and attend h times in parallel:

MultiHead(Q, K, V) = Concat(head₁, head₂, ..., headₕ) · W^O

where headᵢ = Attention(Q·Wᵢ^Q, K·Wᵢ^K, V·Wᵢ^V)
Multi-Head Attention Architecture:

Input X (d_model)
│
├──▶ Linear W^Q ──▶ split into h heads ──▶ Q₁, Q₂, ..., Qₕ (each dₖ)
│
├──▶ Linear W^K ──▶ split into h heads ──▶ K₁, K₂, ..., Kₕ (each dₖ)
│
└──▶ Linear W^V ──▶ split into h heads ──▶ V₁, V₂, ..., Vₕ (each dᵥ)

For each head i:   headᵢ = Attention(Qᵢ, Kᵢ, Vᵢ) ∈ ℝ^{seq_len × dᵥ}

Concatenate:       [head₁, head₂, ..., headₕ] ∈ ℝ^{seq_len × h·dᵥ}
│
└──▶ Linear W^O ──▶ Output ∈ ℝ^{seq_len × d_model}
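As a sanity check on these shapes, here is a minimal sketch using PyTorch's built-in nn.MultiheadAttention module; the values d_model = 512, h = 8, and seq_len = 10 are illustrative, and since this is self-attention the same tensor is passed as query, key, and value:

```python
import torch
import torch.nn as nn

d_model, h, seq_len = 512, 8, 10  # illustrative values (d_k = d_v = 64)
mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=h, batch_first=True)

x = torch.randn(1, seq_len, d_model)   # [batch, seq_len, d_model]
out, attn_weights = mha(x, x, x)       # self-attention: Q = K = V = x

print(out.shape)           # torch.Size([1, 10, 512])
print(attn_weights.shape)  # torch.Size([1, 10, 10]), weights averaged over heads by default
```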

Hyperparameters

d_model = 512, h = 8 → dₖ = dᵥ = 64

d_model = 768, h = 12 → dₖ = dᵥ = 64

d_model = 1024, h = 16 → dₖ = dᵥ = 64
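A quick arithmetic check of the rows above (illustrative only):

```python
# d_k = d_v = d_model / h for each configuration listed above
for d_model, h in [(512, 8), (768, 12), (1024, 16)]:
    print(f"d_model={d_model}, h={h} -> d_k = d_v = {d_model // h}")  # 64 in every case
```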

Why Multiple Heads?

Different attention heads can learn to focus on different types of relationships: for example, one head may track syntactic dependencies (such as subject–verb links), another semantic similarity, and another positional patterns such as attending to the previous token.

Detailed Algorithm

Step 1: Linear Projections

Q = X · W^Q ∈ ℝ^{seq_len × d_model}

K = X · W^K ∈ ℝ^{seq_len × d_model}

V = X · W^V ∈ ℝ^{seq_len × d_model}
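A minimal PyTorch sketch of Step 1, assuming seq_len = 10 and d_model = 512 for illustration (real implementations often include bias terms in these projections):

```python
import torch
import torch.nn as nn

seq_len, d_model = 10, 512            # illustrative dimensions
X = torch.randn(seq_len, d_model)     # input embeddings

# One full-width projection per role; the split into heads happens in Step 2.
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

Q, K, V = W_q(X), W_k(X), W_v(X)
print(Q.shape, K.shape, V.shape)      # each torch.Size([10, 512])
```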

Step 2: Reshape for Heads

Q reshaped: [seq_len, h, dₖ] → [h, seq_len, dₖ]

K reshaped: [seq_len, h, dₖ] → [h, seq_len, dₖ]

V reshaped: [seq_len, h, dᵥ] → [h, seq_len, dᵥ]
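A sketch of the Step 2 reshape, assuming h = 8 heads and the same illustrative dimensions (the same operation is applied to K and V):

```python
import torch

seq_len, d_model, h = 10, 512, 8
d_k = d_model // h                                # 64

Q = torch.randn(seq_len, d_model)                 # stand-in for the projected queries

# [seq_len, d_model] -> [seq_len, h, d_k] -> [h, seq_len, d_k]
Q_heads = Q.view(seq_len, h, d_k).transpose(0, 1)
print(Q_heads.shape)                              # torch.Size([8, 10, 64])
```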

Step 3: Parallel Attention

For each head i:
headᵢ = Attention(Qᵢ, Kᵢ, Vᵢ) = softmax(QᵢKᵢᵀ/√dₖ) · Vᵢ
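A sketch of Step 3 with a batched matrix multiply, so all h heads are computed in one call (shapes assumed as above):

```python
import math
import torch

h, seq_len, d_k = 8, 10, 64
Q = torch.randn(h, seq_len, d_k)                     # per-head queries
K = torch.randn(h, seq_len, d_k)                     # per-head keys
V = torch.randn(h, seq_len, d_k)                     # per-head values

scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)    # [h, seq_len, seq_len]
weights = torch.softmax(scores, dim=-1)              # attention weights per head
heads = weights @ V                                  # [h, seq_len, d_k]
print(heads.shape)                                   # torch.Size([8, 10, 64])
```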

Step 4: Concatenate and Project

heads = Concat(head₁, head₂, ..., headₕ) ∈ ℝ^{seq_len × h·dᵥ}

output = heads · W^O ∈ ℝ^{seq_len × d_model}
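A sketch of Step 4, concatenating the per-head outputs and applying the output projection W^O (illustrative dimensions as before):

```python
import torch
import torch.nn as nn

h, seq_len, d_v, d_model = 8, 10, 64, 512
heads = torch.randn(h, seq_len, d_v)                     # stand-in for the per-head outputs

# [h, seq_len, d_v] -> [seq_len, h, d_v] -> [seq_len, h*d_v]
concat = heads.transpose(0, 1).reshape(seq_len, h * d_v)

W_o = nn.Linear(h * d_v, d_model, bias=False)            # the W^O projection
output = W_o(concat)
print(output.shape)                                      # torch.Size([10, 512])
```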

Parameter Count

Per head: Wᵢ^Q ∈ ℝ^{d_model × dₖ}, Wᵢ^K ∈ ℝ^{d_model × dₖ}, Wᵢ^V ∈ ℝ^{d_model × dᵥ}

All heads combined: W^Q, W^K, W^V ∈ ℝ^{d_model × d_model}

Output projection: W^O ∈ ℝ^{h·dᵥ × d_model} = ℝ^{d_model × d_model}

Total: 4 × d_model² weight parameters (ignoring biases), the same as single-head attention operating on the full d_model dimensions.
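A quick check of this count for d_model = 512 (weights only, biases ignored):

```python
d_model = 512
per_matrix = d_model * d_model     # each of W^Q, W^K, W^V, W^O
total = 4 * per_matrix             # 4 × d_model²
print(per_matrix, total)           # 262144 1048576
```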

Key Properties

Variations

Original Transformer

d_model=512, h=8, dₖ=dᵥ=64

BERT Base

d_model=768, h=12, dₖ=dᵥ=64

BERT Large

d_model=1024, h=16, dₖ=dᵥ=64

GPT-2

d_model=768/1024/1280/1600, h=12/16/20/25 → dₖ=64 in every size

Test Your Understanding

Question 1: What does multi-head attention allow the model to do?

  • A) Use more parameters
  • B) Attend to information from different representation subspaces
  • C) Process sequences faster
  • D) Use less memory

Question 2: If d_model = 512 and h = 8, what is dₖ (dimension per head)?

  • A) 8
  • B) 64
  • C) 512
  • D) 4096

Question 3: What operation combines the outputs of all attention heads?

  • A) Average
  • B) Concatenation followed by linear projection
  • C) Max pooling
  • D) Sum

Question 4: How does the computational cost of multi-head attention compare to single-head?

  • A) Much higher
  • B) Much lower
  • C) Approximately the same
  • D) Depends on sequence length

Question 5: Each attention head can potentially learn different:

  • A) Loss functions
  • B) Types of relationships (syntactic, semantic, etc.)
  • C) Input modalities
  • D) Model layers

Question 6: In BERT Base, if d_model=768 and h=12, what is dₖ?

  • A) 768
  • B) 12
  • C) 64
  • D) 57

Question 7: Why do we use dₖ = d_model / h?

  • A) To increase parameters
  • B) To keep total computation similar to single attention
  • C) To reduce accuracy
  • D) To speed up softmax

Question 8: What is the formula for MultiHead attention?

  • A) Attention(Q, K, V) = softmax(QKᵀ)V
  • B) MultiHead = Concat(head₁...headₕ)·W^O where headᵢ = Attention(Q·Wᵢ^Q, K·Wᵢ^K, V·Wᵢ^V)
  • C) MultiHead = Average(Attention for each head)
  • D) MultiHead = Sum(Attention for each head)