Introduction
FlashAttention is an efficient implementation of exact attention that computes the attention matrix in blocks, so the full n×n matrix is never materialized in GPU memory. It produces the same output as standard attention, but the extra memory drops from O(n²) to O(n), thanks to tiling and a careful, IO-aware schedule of reads and writes between HBM and on-chip SRAM.
The Problem
Standard attention materializes the full n×n attention matrix, so its memory cost grows as O(n²) (a sketch follows this list):
- For n=4096: ~16.8M entries ≈ 32 MB per attention head in fp16
- For n=16384: ~268M entries ≈ 0.5 GB per attention head in fp16
- This memory bottleneck limits the maximum sequence length
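As a point of reference, here is a minimal NumPy sketch of standard attention; the (n, n) score matrix it builds is exactly the buffer FlashAttention avoids. Sizes and variable names are illustrative assumptions, not taken from the FlashAttention code:

```python
import numpy as np

def standard_attention(Q, K, V):
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)          # the full (n, n) matrix lives in memory
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

n, d = 4096, 64                            # assumed toy sizes
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((n, d), dtype=np.float32) for _ in range(3))
out = standard_attention(Q, K, V)
# Footprint of the intermediate (n, n) score matrix alone, if stored in fp16:
print(f"{n * n * 2 / 2**20:.0f} MiB")      # ~32 MiB for n=4096
```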
Solution: Tiled Computation
FlashAttention tiles the computation into blocks that fit in SRAM (fast on-chip memory):
GPU HBM (slow, large) ←→ SRAM (fast, limited)
Standard: load all of Q, K, V from HBM, compute, and write the full n×n attention matrix back to HBM
FlashAttention: load tiles into SRAM, compute there, and accumulate the result, writing back only the O(n)-sized output
Algorithm Steps
1. Split Q, K, and V into blocks sized to fit in SRAM
2. For each block of Q:
- Load the block of Q once, then stream blocks of K and V from HBM
- Compute the partial attention scores for that tile
- Update the running softmax statistics (row max and normalizer) and the running output
3. Use the online softmax statistics to produce the correctly normalized output (a NumPy sketch of this loop follows the steps)
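The steps above can be made concrete with a short NumPy sketch of the tiled forward pass. Block sizes, variable names, and the pure-Python loop are illustrative assumptions; the real kernel runs this logic in CUDA with the tiles held in SRAM:

```python
import numpy as np

def flash_attention_forward(Q, K, V, block_q=128, block_k=128):
    n, d = Q.shape
    O = np.zeros((n, d), dtype=Q.dtype)
    for qs in range(0, n, block_q):                  # step 2: loop over Q blocks
        Qi = Q[qs:qs + block_q]
        m = np.full(Qi.shape[0], -np.inf)            # running row max
        l = np.zeros(Qi.shape[0])                    # running softmax normalizer
        acc = np.zeros((Qi.shape[0], d))             # running weighted sum of V
        for ks in range(0, n, block_k):              # stream K/V blocks from "HBM"
            Kj, Vj = K[ks:ks + block_k], V[ks:ks + block_k]
            S = Qi @ Kj.T / np.sqrt(d)               # partial attention scores
            m_new = np.maximum(m, S.max(axis=-1))    # online softmax update
            P = np.exp(S - m_new[:, None])
            scale = np.exp(m - m_new)                # rescale old statistics
            l = l * scale + P.sum(axis=-1)
            acc = acc * scale[:, None] + P @ Vj
            m = m_new
        O[qs:qs + block_q] = acc / l[:, None]        # step 3: final normalization
    return O
```

On small inputs this matches the standard implementation up to floating-point rounding, while never allocating an (n, n) array.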
Online Softmax Trick
To compute the correct softmax while only ever seeing one block of scores at a time, FlashAttention keeps a running maximum m and a running normalizer l:
Standard: softmax(x)_i = exp(x_i) / Σ_j exp(x_j)
Online (FlashAttention), as each new block of scores arrives:
m_new = max(m_prev, max(x_block))
l_new = l_prev · exp(m_prev - m_new) + Σ exp(x_block - m_new)
The partially accumulated output is rescaled by the same exp(m_prev - m_new) factor, so dividing by l_new at the end gives exactly the full softmax-weighted result.
This allows incremental computation without ever materializing a whole row of the attention matrix.
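A tiny NumPy check (the toy numbers are an assumption) that merging per-block (max, sum) statistics reproduces the full softmax normalizer:

```python
import numpy as np

x = np.array([1.0, 3.0, 2.0, 5.0, 4.0, 0.0])
block1, block2 = x[:3], x[3:]

# Per-block statistics.
m1, m2 = block1.max(), block2.max()
l1, l2 = np.exp(block1 - m1).sum(), np.exp(block2 - m2).sum()

# Merge: rescale each partial sum to the combined maximum.
m = max(m1, m2)
l = l1 * np.exp(m1 - m) + l2 * np.exp(m2 - m)

full = np.exp(x - x.max()).sum()
print(np.isclose(l, full))   # True: same normalizer, computed incrementally
```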
Speedup and Memory Savings
| Aspect | Standard | FlashAttention |
|---|---|---|
| Memory | O(n²) | O(n) [no n² matrix stored] |
| Speed | 1× baseline | 2-4× faster |
| Output | Exact | Exact (no approximation; differs only by floating-point rounding) |
Versions
- FlashAttention-1: the original tiled algorithm with online softmax
- FlashAttention-2: better parallelism across the sequence dimension and improved work partitioning between warps
- FlashAttention-3: further optimizations targeting NVIDIA Hopper (H100) GPUs, including overlapping computation with asynchronous data movement and FP8 support