Introduction
Grouped Query Attention (GQA) is a variant of multi-head attention where multiple query heads share a single key-value head. This reduces the KV cache size and memory bandwidth requirements while maintaining good performance.
Standard MHA vs GQA
Multi-Head Attention (MHA)
h query heads and h key-value heads (n_heads_q = n_heads_kv = h)
Each head has its own K, V projections
Grouped Query Attention
n_heads_q query heads (more)
n_heads_kv key-value heads (fewer, e.g., h/4 or h/8)
Each K,V head is shared by multiple Q heads
Query: Q_0, Q_1, ..., Q_{n_q-1}
Key/Value: K_0, K_1, ..., K_{n_kv-1} where n_kv < n_q
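The head-sharing pattern above can be sketched in a few lines. This is an illustrative mapping (the variable names are assumptions, not any library's API): query head i reads from KV head i // group_size.

```python
# Hypothetical sketch: which KV head each query head uses in GQA.
n_q, n_kv = 8, 2           # 8 query heads share 2 KV heads
group_size = n_q // n_kv   # 4 query heads per KV head

# Query head i is served by KV head i // group_size.
mapping = {q: q // group_size for q in range(n_q)}
print(mapping)  # {0: 0, 1: 0, 2: 0, 3: 0, 4: 1, 5: 1, 6: 1, 7: 1}
```

With n_kv = n_q this degenerates to standard MHA (group size 1); with n_kv = 1 it becomes Multi-Query Attention (MQA).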
Example: A Typical GQA Configuration
n_heads_q = 32 (query heads)
n_heads_kv = 8 (key-value heads)
Each KV head serves 32 / 8 = 4 query heads
KV cache reduced by 4× compared to MHA
(Note: Llama 2 applies GQA only in its 70B model, which uses 64 query heads and 8 KV heads.)
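The 4× figure follows directly from the head counts. A quick sanity check of the arithmetic (head_dim = 128 and fp16 storage are assumed values, typical for a 4096-dim model with 32 heads):

```python
# KV cache per token per layer: 2 tensors (K and V) x n_kv_heads x head_dim x bytes/elem.
n_heads_q, n_heads_kv, head_dim = 32, 8, 128  # example config from above
bytes_per_elem = 2                            # fp16 (assumed)

mha_cache = 2 * n_heads_q * head_dim * bytes_per_elem   # MHA: one K,V pair per query head
gqa_cache = 2 * n_heads_kv * head_dim * bytes_per_elem  # GQA: only n_kv pairs are cached
print(mha_cache, gqa_cache, mha_cache // gqa_cache)  # 16384 4096 4
```

The reduction factor is simply n_heads_q / n_heads_kv, independent of head_dim and precision.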
Memory and Speed Benefits
| Aspect | MHA (32 heads) | GQA (8 KV heads) |
|---|---|---|
| KV heads | 32 | 8 |
| KV cache size per token | 2 · 32 · d_head values | 2 · 8 · d_head values (4× smaller) |
| Memory bandwidth during decode | High | ~4× less KV traffic |
GQA Formula
Q ∈ ℝ^{seq_len × n_q·d}
K, V ∈ ℝ^{seq_len × n_kv·d}
Each query head i attends using only the KV head of its group, g(i) = ⌊i / (n_q / n_kv)⌋:
head_i = softmax(Q_i K_{g(i)}ᵀ / √d) V_{g(i)}
Output = concat(head_0, …, head_{n_q−1}) W_O
Each group of n_q / n_kv consecutive query heads shares one K, V projection.
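The formula can be implemented by repeating each KV head across its group so the per-head matmuls line up. A minimal NumPy sketch (shapes and function name are assumptions for illustration, not a specific library's API):

```python
import numpy as np

def gqa_attention(Q, K, V):
    """Q: (n_q, seq, d); K, V: (n_kv, seq, d), with n_q divisible by n_kv."""
    n_q, seq, d = Q.shape
    n_kv = K.shape[0]
    group = n_q // n_kv
    # Repeat each KV head 'group' times so query head i aligns with KV head i // group.
    K = np.repeat(K, group, axis=0)                    # (n_q, seq, d)
    V = np.repeat(V, group, axis=0)                    # (n_q, seq, d)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d)     # (n_q, seq, seq)
    weights = np.exp(scores - scores.max(-1, keepdims=True))
    weights /= weights.sum(-1, keepdims=True)          # softmax over key positions
    return weights @ V                                 # (n_q, seq, d)

out = gqa_attention(np.random.randn(8, 4, 16),
                    np.random.randn(2, 4, 16),
                    np.random.randn(2, 4, 16))
print(out.shape)  # (8, 4, 16)
```

In production kernels the repeat is avoided (the KV heads are read once and broadcast), which is exactly where the memory-bandwidth savings come from.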
Training vs Inference
- Training: all n_q query heads are still computed; only the K and V projections (and their activations) are reduced
- Inference: the main benefit is the n_q / n_kv-times smaller KV cache and the correspondingly lower memory traffic per decoded token
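The inference point can be made concrete with a decode-loop sketch. Only n_kv heads are ever cached, no matter how many query heads the model has (the cache layout and names below are assumptions for illustration):

```python
import numpy as np

n_kv, d = 8, 128                    # 8 KV heads cached, even if there are 32 query heads
k_cache = np.empty((n_kv, 0, d))    # grows by one position per generated token
v_cache = np.empty((n_kv, 0, d))

for _ in range(3):                  # generate 3 tokens
    k_new = np.random.randn(n_kv, 1, d)   # K, V projections for the new token only
    v_new = np.random.randn(n_kv, 1, d)
    k_cache = np.concatenate([k_cache, k_new], axis=1)
    v_cache = np.concatenate([v_cache, v_new], axis=1)

print(k_cache.shape)  # (8, 3, 128)
```

Every decode step re-reads the whole cache, so shrinking it by n_q / n_kv cuts the dominant memory traffic by the same factor.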