40. Memory-Efficient Attention

Introduction

Memory-efficient attention refers to techniques that reduce the memory footprint of attention computation. Since standard attention requires O(n²) memory for the attention matrix, various methods have been developed to work within limited memory constraints.

Memory Bottleneck

Attention matrix: n × n entries per head
For n=4096: ~16.8M entries ≈ 64MB in fp32
For n=16384: ~268M entries ≈ 1GB just for attention!

Training: forward-pass activations must be kept around for the backward pass
Multiply by heads × layers and attention quickly dominates GPU memory
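
As a sanity check, the arithmetic can be scripted. A minimal sketch, assuming fp32 (4 bytes per entry) and counting only the attention matrix itself:

def attn_matrix_bytes(n, bytes_per_entry=4):
    # Bytes needed to materialize one n × n attention matrix.
    return n * n * bytes_per_entry

for n in (4096, 16384):
    mib = attn_matrix_bytes(n) / 2**20
    print(f"n={n}: {n * n / 1e6:.1f}M entries ≈ {mib:,.0f} MiB per head")

This prints ~16.8M entries (64 MiB) and ~268.4M entries (1,024 MiB), matching the figures above.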

Memory-Efficient Techniques

1. Gradient Checkpointing (Activation Recomputation)

Discard most activations during the forward pass, keeping only periodic checkpoints
Recompute the rest during backward (trades compute for memory)

Memory: O(n·d) instead of O(n²) per layer (only layer inputs are kept)
Compute: roughly one extra forward pass (~2× forward FLOPs)
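
In PyTorch this is a one-line change per layer via torch.utils.checkpoint. A minimal sketch; the TransformerBlock below is a stand-in for whatever layer you actually use:

import torch
from torch.utils.checkpoint import checkpoint

class TransformerBlock(torch.nn.Module):
    # Illustrative layer; any nn.Module works the same way.
    def __init__(self, d):
        super().__init__()
        self.attn = torch.nn.MultiheadAttention(d, num_heads=4, batch_first=True)
        self.ff = torch.nn.Linear(d, d)

    def forward(self, x):
        a, _ = self.attn(x, x, x)
        return x + self.ff(a)

block = TransformerBlock(64)
x = torch.randn(2, 1024, 64, requires_grad=True)

# Activations inside `block` are not stored; they are recomputed
# during backward, trading one extra forward pass for memory.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()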

2. Block-wise Attention

Process attention in blocks
Only store block statistics, not full matrix

FlashAttention uses this approach
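
The key trick is the online softmax: keep a running row-wise max and denominator so earlier blocks can be rescaled as new ones arrive. A NumPy sketch of the single-head case (block size and names are illustrative; a real kernel such as FlashAttention also tiles the queries and fuses everything on-chip):

import numpy as np

def blockwise_attention(Q, K, V, block=128):
    # softmax(QKᵀ/√d)V without materializing the n × n matrix.
    n, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros_like(Q)
    m = np.full(n, -np.inf)   # running row-wise max (numerical stability)
    l = np.zeros(n)           # running softmax denominator

    for start in range(0, n, block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        s = (Q @ Kb.T) * scale                  # scores vs this key block only
        m_new = np.maximum(m, s.max(axis=1))
        corr = np.exp(m - m_new)                # rescale earlier statistics
        p = np.exp(s - m_new[:, None])
        l = l * corr + p.sum(axis=1)
        out = out * corr[:, None] + p @ Vb
        m = m_new
    return out / l[:, None]

Only n × block scores exist at any time; the running max and denominator are exactly the per-row statistics carried between tiles.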

3. Linear (Kernelized) Attention

Instead of: Attention = softmax(QKᵀ/√d)V

Reformulate with a feature map φ to avoid the n×n matrix:
Attention ≈ φ(Q)(φ(K)ᵀV) / (φ(Q)·Σφ(K))
Maintain running sums of the φ(K)ᵀV products instead of pairwise scores
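
A minimal sketch of the non-causal case, using the elu(x)+1 feature map from Katharopoulos et al. (2020); names are illustrative:

import numpy as np

def linear_attention(Q, K, V):
    # φ(Q)(φ(K)ᵀV): an n×n matrix never appears, only a d×d summary.
    phi = lambda x: np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))
    Qf, Kf = phi(Q), phi(K)
    kv = Kf.T @ V          # (d, d) running summary of all key-value pairs
    z = Kf.sum(axis=0)     # (d,) normalizer statistics
    return (Qf @ kv) / (Qf @ z)[:, None]

For causal attention the same sums are accumulated left to right, one token at a time, which is what makes linear attention attractive for streaming generation.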

Memory Comparison

Method                  | Memory       | Tradeoff
Standard                | O(n²)        | High memory, low compute
Gradient Checkpointing  | O(n·d)       | 2× compute, less memory
FlashAttention          | O(n)         | More compute, minimal memory
Sparse/Linear Attention | O(n) or less | Approximation, much less memory

Implementation Strategies

Attention with KV Cache Optimization

During inference, cache K and V so earlier tokens' projections are never recomputed:

For generation step t (the new K_t, V_t are appended first):
K_cache = [K_0, K_1, ..., K_t]
V_cache = [V_0, V_1, ..., V_t]

Attention = softmax(Q_t K_cacheᵀ / √d) V_cache
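
A minimal sketch of one decode step with a growing cache (single head; shapes and names are illustrative):

import numpy as np

def decode_step(q_t, k_t, v_t, K_cache, V_cache):
    # Append the new key/value, then attend with the single query q_t.
    K_cache = np.vstack([K_cache, k_t[None]])   # cache grows one row per step
    V_cache = np.vstack([V_cache, v_t[None]])
    s = K_cache @ q_t / np.sqrt(q_t.shape[-1])  # (t+1,) scores, no n×n matrix
    w = np.exp(s - s.max())
    w /= w.sum()                                # softmax over cached positions
    return w @ V_cache, K_cache, V_cache

d = 64
rng = np.random.default_rng(0)
K_cache = V_cache = np.empty((0, d))
for _ in range(5):                              # five decode steps
    q, k, v = rng.standard_normal((3, d))
    out, K_cache, V_cache = decode_step(q, k, v, K_cache, V_cache)

Each step costs O(t·d) instead of re-running attention over the whole prefix from scratch.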

PagedAttention (vLLM)

Manage the KV cache like virtual-memory pages: fixed-size blocks allocated on demand, which reduces fragmentation and enables efficient batching.
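
The core bookkeeping can be sketched as a block table mapping a sequence's logical token positions to fixed-size physical pages in a shared pool (a toy illustration; vLLM's real implementation lives in custom CUDA kernels):

import numpy as np

PAGE = 16  # tokens per physical KV page

class PagedKVCache:
    # Toy block table: (seq_id, position) -> (page, slot) in a shared pool.
    def __init__(self, num_pages, d):
        self.pool = np.zeros((num_pages, PAGE, d))  # shared physical pages
        self.free = list(range(num_pages))
        self.table = {}                             # seq_id -> list of page ids

    def append(self, seq_id, pos, vec):
        pages = self.table.setdefault(seq_id, [])
        if pos % PAGE == 0:                         # sequence needs a new page
            pages.append(self.free.pop())
        self.pool[pages[pos // PAGE], pos % PAGE] = vec

    def gather(self, seq_id, length):
        # Reassemble the logical K (or V) tensor for attention.
        pages = self.table[seq_id]
        return self.pool[pages].reshape(-1, self.pool.shape[-1])[:length]

Because pages are fixed-size and allocated on demand, sequences of very different lengths share one pool without reserving worst-case contiguous buffers, which is what makes large-batch serving efficient.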

Test Your Understanding

Question 1: Standard attention matrix for n=4096 requires memory for:

  • A) 4096 entries
  • B) 16M entries
  • C) 64 entries
  • D) 4096² entries only

Question 2: Gradient checkpointing trades:

  • A) Memory for compute
  • B) Compute for memory
  • C) Accuracy for speed
  • D) Speed for accuracy

Question 3: FlashAttention achieves O(n) memory by:

  • A) Approximating the result
  • B) Not storing full matrix, using block-wise computation
  • C) Reducing accuracy
  • D) Using different activation

Question 4: During inference, KV cache stores:

  • A) Q matrices
  • B) K and V matrices for generated tokens
  • C) No cache
  • D) Only Q and K

Question 5: PagedAttention optimizes:

  • A) Training memory
  • B) KV cache management like virtual memory pages
  • C) Attention accuracy
  • D) Model architecture

Question 6: With gradient checkpointing, memory is reduced to approximately:

  • A) O(n²)
  • B) O(n·d)
  • C) O(d²)
  • D) O(1)

Question 7: Memory-efficient attention is critical for:

  • A) Short sequences only
  • B) Long sequence processing and training
  • C) Only inference
  • D) Only vision transformers

Question 8: Block-wise attention computes:

  • A) Full matrix at once
  • B) One block at a time, accumulating statistics
  • C) Random blocks
  • D) Only diagonal