55. Grouped Query Attention (GQA)

Introduction

Grouped Query Attention (GQA) is a variant of multi-head attention where multiple query heads share a single key-value head. This reduces the KV cache size and memory bandwidth requirements while maintaining good performance.

Standard MHA vs GQA

Multi-Head Attention (MHA)

n_heads_q query heads
n_heads_kv key-value heads
In standard MHA, n_heads_q = n_heads_kv = h: each head has its own K, V projections.

Grouped Query Attention

n_heads_q query heads (more)
n_heads_kv key-value heads (fewer, e.g., h/4 or h/8)
Each K,V head is shared by multiple Q heads:
Query: Q_0, Q_1, ..., Q_{n_q-1}
Key/Value: K_0, K_1, ..., K_{n_kv-1}, where n_kv < n_q
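The difference shows up directly in the projection shapes. A minimal numpy sketch; the sizes (d_model = 512, 8 query heads, 2 KV heads, head size 64) are illustrative assumptions, not from any particular model:

```python
import numpy as np

# Illustrative sizes (assumptions for this sketch, not a real model config)
seq_len, d_model = 16, 512
n_q, n_kv, d_head = 8, 2, 64   # 8 query heads share 2 KV heads

x = np.random.randn(seq_len, d_model)

# MHA: one K projection per query head -> n_q KV heads
W_k_mha = np.random.randn(d_model, n_q * d_head)
K_mha = x @ W_k_mha            # (16, 512): 8 heads x 64 dims

# GQA: only n_kv K projections, shared across query-head groups
W_k_gqa = np.random.randn(d_model, n_kv * d_head)
K_gqa = x @ W_k_gqa            # (16, 128): 2 heads x 64 dims

print(K_mha.shape, K_gqa.shape)
```

The K (and likewise V) tensors shrink by the factor n_q/n_kv, which is exactly what shrinks the KV cache.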

Example: A Llama 2-Style Configuration

n_heads_q = 32 (query heads)
n_heads_kv = 8 (key-value heads)
Each KV head serves 32 / 8 = 4 query heads
KV cache is 4× smaller than with MHA

(Llama 2 70B itself uses GQA with 64 query heads and 8 KV heads, a sharing factor of 8; the smaller numbers here keep the arithmetic simple.)
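The 4× figure follows from simple per-token arithmetic. A sketch of the cache-size calculation; head_dim = 128, n_layers = 32, and fp16 (2 bytes/element) are assumptions chosen for illustration, not exact model specs:

```python
# KV cache per token = 2 (K and V) * n_kv_heads * head_dim * n_layers * bytes_per_elem
# head_dim=128, n_layers=32, fp16 are illustrative assumptions.
def kv_cache_bytes_per_token(n_kv_heads, head_dim=128, n_layers=32, bytes_per_elem=2):
    return 2 * n_kv_heads * head_dim * n_layers * bytes_per_elem

mha = kv_cache_bytes_per_token(32)   # MHA: 32 KV heads
gqa = kv_cache_bytes_per_token(8)    # GQA: 8 KV heads
print(mha // 1024, "KiB vs", gqa // 1024, "KiB")  # 512 KiB vs 128 KiB per token
print("reduction:", mha // gqa, "x")              # reduction: 4 x
```

The ratio depends only on the head counts, so the 4× reduction holds regardless of the assumed head size, layer count, or dtype.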

Memory and Speed Benefits

Aspect              MHA        GQA (8 KV heads)
KV heads            32         8
KV cache size       32 heads   8 heads (4× smaller)
Memory bandwidth    High       Reduced

GQA Formula

Q ∈ ℝ^{seq_len × n_q·d}
K, V ∈ ℝ^{seq_len × n_kv·d}

Each group of n_q/n_kv query heads shares one KV projection. Query head i uses KV head g(i) = ⌊i / (n_q/n_kv)⌋:

head_i = softmax(Q_i K_{g(i)}ᵀ / √d) V_{g(i)}
output = concat(head_0, ..., head_{n_q-1}) W_O

Training vs Inference

GQA can be trained from scratch with the grouped head structure, or an existing MHA checkpoint can be converted ("uptrained") by mean-pooling each group's K and V heads into one shared head. The benefits show up mainly at inference: during autoregressive decoding, the entire KV cache is read for every generated token, so fewer KV heads directly means less memory traffic.
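To make the inference-time picture concrete, here is a toy decoding-loop sketch: per layer, the KV cache grows by one (n_kv, d_head) slice for each generated token, so per-token cache size and traffic scale with n_kv rather than n_q. All sizes are illustrative assumptions:

```python
import numpy as np

# Illustrative sizes: 8 KV heads, head size 64
n_kv, d_head = 8, 64
cache_k = np.empty((0, n_kv, d_head))    # K cache for one layer, initially empty
for step in range(4):                    # decode 4 tokens
    k_new = np.random.randn(1, n_kv, d_head)  # K for the newly generated token
    cache_k = np.concatenate([cache_k, k_new], axis=0)
print(cache_k.shape)  # (4, 8, 64): grows with n_kv, not with the query-head count
```

With MHA the middle dimension would be the full query-head count (e.g., 32), so every decode step would read and write 4× more cache data.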

Test Your Understanding

Question 1: In GQA, multiple query heads:

  • A) Each have unique K,V
  • B) Share a single key-value head
  • C) No query heads
  • D) Random sharing

Question 2: With 32 query heads and 8 KV heads, each KV head serves:

  • A) 1 query head
  • B) 4 query heads
  • C) 32 query heads
  • D) 8 query heads

Question 3: GQA reduces:

  • A) Model accuracy
  • B) KV cache size and memory bandwidth
  • C) Number of parameters
  • D) Nothing

Question 4: Compared to MHA with 32 heads, GQA with 8 KV heads has KV cache:

  • A) Same size
  • B) 4× smaller
  • C) 4× larger
  • D) 8× smaller

Question 5: GQA is used in:

  • A) BERT
  • B) Llama 2
  • C) Original Transformer
  • D) ResNet

Question 6: GQA provides benefits mainly during:

  • A) Training
  • B) Inference (smaller KV cache, less memory traffic)
  • C) Neither
  • D) Both equally

Question 7: If n_heads_q = 48 and n_heads_kv = 12, the sharing factor is:

  • A) 48
  • B) 12
  • C) 4
  • D) 1

Question 8: GQA reduces memory bandwidth because:

  • A) Fewer KV heads need to be loaded per attention step
  • B) More computation
  • C) KV cache is larger
  • D) No reason