Introduction
Memory-efficient attention refers to techniques that reduce the memory footprint of attention computation. Standard attention materializes an O(n²) attention matrix, so a range of methods has been developed to compute the same (or approximately the same) result within much tighter memory budgets.
Memory Bottleneck
Attention matrix: n × n entries
For n=4096: 4096² ≈ 16.8M entries ≈ 64 MB in fp32 (per head, per layer)
For n=16384: ≈ 268M entries ≈ 1 GB just for one attention matrix!
Backward pass: the forward activations (including attention weights) must be kept for gradients
Training: this cost multiplies across heads × layers = huge memory
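The arithmetic behind these figures is straightforward; a minimal sketch (assuming fp32, one head, one layer):

```python
def attn_matrix_bytes(n: int, bytes_per_entry: int = 4) -> int:
    """Bytes for one dense n x n attention matrix (fp32 by default)."""
    return n * n * bytes_per_entry

for n in (4096, 16384):
    print(f"n={n:>6}: {n * n:>12,} entries = {attn_matrix_bytes(n) / 2**20:,.0f} MiB")
# n=  4096:   16,777,216 entries = 64 MiB
# n= 16384:  268,435,456 entries = 1,024 MiB
```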
Memory-Efficient Techniques
1. Gradient Checkpointing (Activation Recomputation)
Store only a subset of activations (checkpoints) during the forward pass
Recompute the rest during the backward pass (trades compute for memory)
Memory: O(n·d) activations per layer instead of the O(n²) attention matrix
Compute: an extra forward pass (≈2× forward FLOPs, roughly 1.3× total training FLOPs); see the sketch below
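In PyTorch, checkpointing a block is close to a one-line change via `torch.utils.checkpoint`; a minimal sketch (the layer choice and shapes are illustrative):

```python
import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)
x = torch.randn(2, 1024, 64, requires_grad=True)

y_plain = block(x)                                  # keeps all intermediate activations
y_ckpt = checkpoint(block, x, use_reentrant=False)  # discards them, recomputes in backward
y_ckpt.sum().backward()
```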
2. Block-wise Attention
Process attention in key/value blocks (tiles)
Only store O(n) running statistics per query row (row max and softmax denominator), not the full n × n matrix
FlashAttention builds on this online-softmax approach (sketched below)
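A minimal sketch of the block-wise computation (single head; the block size is an illustrative choice, and real FlashAttention fuses these steps into a single GPU kernel rather than running them in PyTorch):

```python
import torch

def blockwise_attention(q, k, v, block=512):
    """Attention over key/value blocks, keeping only O(n) running state:
    per-row max m, softmax denominator l, and unnormalized output acc."""
    n, d = q.shape
    scale = d ** -0.5
    m = torch.full((n, 1), float("-inf"))   # running max per query row
    l = torch.zeros(n, 1)                   # running softmax denominator
    acc = torch.zeros(n, d)                 # running (unnormalized) output
    for s0 in range(0, n, block):
        kb, vb = k[s0:s0 + block], v[s0:s0 + block]
        s = (q @ kb.T) * scale              # n x block scores: only this tile exists
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - m_new)
        corr = torch.exp(m - m_new)         # rescale old stats to the new max
        l = l * corr + p.sum(dim=-1, keepdim=True)
        acc = acc * corr + p @ vb
        m = m_new
    return acc / l

q, k, v = (torch.randn(2048, 64) for _ in range(3))
out = blockwise_attention(q, k, v)
ref = torch.softmax((q @ k.T) * 64 ** -0.5, dim=-1) @ v
assert torch.allclose(out, ref, atol=1e-4)  # matches exact attention
```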
3. Linear Attention
Instead of: Attention = softmax(QKᵀ/√d)V
Reformulate with a feature map φ to avoid the n×n matrix: φ(Q)(φ(K)ᵀV)
Maintain running d×d statistics of the KV products instead of explicit attention weights
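A minimal sketch of the kernelized form, assuming the common φ(x) = elu(x) + 1 feature map; note this approximates softmax attention rather than reproducing it exactly:

```python
import torch
import torch.nn.functional as F

def linear_attention(q, k, v):
    """O(n·d²) attention: phi(Q) @ (phi(K)^T @ V), normalized per row.
    A d x d summary of K and V replaces the n x n attention matrix."""
    phi = lambda x: F.elu(x) + 1   # positive feature map (illustrative choice)
    q, k = phi(q), phi(k)
    kv = k.T @ v                   # d x d running KV summary
    z = k.sum(dim=0)               # d-dim normalizer
    return (q @ kv) / (q @ z).unsqueeze(-1)

q, k, v = (torch.randn(2048, 64) for _ in range(3))
out = linear_attention(q, k, v)    # no 2048 x 2048 matrix is ever formed
```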
Memory Comparison
| Method | Memory | Tradeoff |
|---|---|---|
| Standard | O(n²) | Baseline compute, high memory |
| Gradient Checkpointing | O(n·d) | Extra forward pass, less memory |
| FlashAttention | O(n) | Recomputes tiles in backward, minimal memory |
| Sparse/Linear Attention | O(n) or less | Approximation, much less memory |
Implementation Strategies
Attention with KV Cache Optimization
During inference, cache K and V so past tokens are not re-projected at every generation step:
For generation step t (only K_t and V_t are newly computed):
K_cache = [K_0, K_1, ..., K_{t-1}, K_t]
V_cache = [V_0, V_1, ..., V_{t-1}, V_t]
Attention = softmax(Q_t K_cacheᵀ / √d) V_cache
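A minimal sketch of decoding with a growing cache (single head; the projection matrices and shapes are illustrative):

```python
import torch

d = 64
Wq, Wk, Wv = (torch.randn(d, d) for _ in range(3))  # illustrative projections
K_cache, V_cache = [], []

def decode_step(x_t):
    """One generation step: project only the new token, reuse cached K/V."""
    q_t = x_t @ Wq
    K_cache.append(x_t @ Wk)       # K_t is the only new key computed
    V_cache.append(x_t @ Wv)
    K, V = torch.stack(K_cache), torch.stack(V_cache)   # t x d each
    attn = torch.softmax((q_t @ K.T) * d ** -0.5, dim=-1)
    return attn @ V

for _ in range(5):
    out = decode_step(torch.randn(d))
```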
PagedAttention (vLLM)
Manages the KV cache like virtual-memory pages: fixed-size blocks are allocated on demand and indexed through a block table, which reduces fragmentation and enables efficient batching and sharing across sequences.
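A toy sketch of the paging idea only (not vLLM's actual implementation): a block table maps each sequence's logical token positions to fixed-size physical blocks drawn from a shared pool:

```python
import torch

BLOCK = 16                          # tokens per physical block (illustrative)
pool = torch.zeros(256, BLOCK, 64)  # shared pool of physical KV blocks
free = list(range(256))             # free physical block indices
block_table = {}                    # seq_id -> list of physical block indices

def append_kv(seq_id, t, kv):
    """Write token t's KV vector, allocating a block only when one fills."""
    table = block_table.setdefault(seq_id, [])
    if t % BLOCK == 0:
        table.append(free.pop())    # allocate a physical block on demand
    pool[table[t // BLOCK], t % BLOCK] = kv

for t in range(40):                 # 40 tokens -> 3 blocks, no large realloc
    append_kv("seq0", t, torch.randn(64))
```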