39. FlashAttention

Introduction

FlashAttention is an efficient implementation of exact attention that computes the attention matrix in blocks to avoid materializing the full n×n matrix in GPU memory. It achieves the same result as standard attention but with O(n) memory instead of O(n²) by using tiling and careful memory management.

The Problem

Standard attention computes the full attention matrix, requiring O(n²) memory:

For n=4096: ~16.8M entries ≈ 64 MB (fp32)
For n=16384: ~268M entries ≈ 1 GB (fp32)

This memory bottleneck limits the maximum sequence length that fits on a GPU.
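
As a quick check of these figures, here is a short Python snippet (it assumes a single attention head and 4 bytes per fp32 entry; batch size and head count would multiply the totals):

# Back-of-the-envelope size of the n x n attention score matrix (one head, fp32).
for n in (4096, 16384):
    entries = n * n
    mib = entries * 4 / 2**20   # 4 bytes per entry
    print(f"n={n}: {entries / 1e6:.0f}M entries, ~{mib:.0f} MiB")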

Solution: Tiled Computation

FlashAttention tiles the computation into blocks that fit in SRAM (fast on-chip memory):

GPU HBM (slow) ←→ SRAM (fast, limited)

Standard: load all of Q, K, V from HBM, compute, store the full attention matrix
FlashAttention: load tiles, compute in SRAM, accumulate the result

Algorithm Steps

1. Split Q, K, V into blocks sized to fit in SRAM

2. For each block of Q:
- Loop over the blocks of K and V, loading each from HBM
- Compute the partial attention scores for that tile
- Update the running softmax statistics and the output accumulator

3. Apply the final normalization using the accumulated online-softmax statistics (a sketch of this loop follows below)
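
The following is a minimal single-head NumPy sketch of this loop structure. The function name, block sizes, and variable names are illustrative only; the real FlashAttention kernel is fused CUDA code, but the arithmetic is the same. It relies on the online softmax update explained in the next section.

import numpy as np

def tiled_attention(Q, K, V, block_q=64, block_k=64):
    # Computes softmax(Q K^T / sqrt(d)) V without materializing the n x n matrix.
    n, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros_like(V)
    for qs in range(0, n, block_q):                 # outer loop over Q blocks
        q = Q[qs:qs + block_q]
        m = np.full(q.shape[0], -np.inf)            # running row-wise max
        l = np.zeros(q.shape[0])                    # running softmax denominator
        acc = np.zeros((q.shape[0], d))             # unnormalized output accumulator
        for ks in range(0, n, block_k):             # inner loop over K/V blocks
            k, v = K[ks:ks + block_k], V[ks:ks + block_k]
            s = (q @ k.T) * scale                   # partial scores for this tile
            m_new = np.maximum(m, s.max(axis=1))
            correction = np.exp(m - m_new)          # rescale old statistics to the new max
            p = np.exp(s - m_new[:, None])
            l = l * correction + p.sum(axis=1)
            acc = acc * correction[:, None] + p @ v
            m = m_new
        O[qs:qs + block_q] = acc / l[:, None]       # final per-row normalization
    return O

# Sanity check against standard full-matrix attention on random inputs.
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((256, 32)) for _ in range(3))
S = (Q @ K.T) / np.sqrt(32)
P = np.exp(S - S.max(axis=1, keepdims=True))
standard = (P / P.sum(axis=1, keepdims=True)) @ V
assert np.allclose(tiled_attention(Q, K, V), standard)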

Online Softmax Trick

To compute the correct softmax while only ever seeing one block of the scores at a time:

Standard: softmax(x_i) = exp(x_i) / Σ_j exp(x_j)

Online (FlashAttention), updated block by block with a running max m and running sum l:
m_new = max(m_old, max(x_block))
l_new = l_old · exp(m_old - m_new) + Σ exp(x_block - m_new)
softmax(x_i) = exp(x_i - m_final) / l_final

This allows the normalization to be computed incrementally, one block at a time, without ever holding the full row of scores; a small numerical check is shown below.
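
To make the recurrence concrete, here is a tiny NumPy check (the two-block split and the values are arbitrary) showing that the running (m, l) statistics reproduce the full softmax:

import numpy as np

x = np.array([1.0, 3.0, 0.5, 2.0, 4.0, 1.5])
blocks = (x[:3], x[3:])                  # process the scores in two chunks

m, l = -np.inf, 0.0                      # running max and running denominator
for b in blocks:
    m_new = max(m, b.max())
    l = l * np.exp(m - m_new) + np.exp(b - m_new).sum()   # rescale old sum, add new block
    m = m_new

online = np.exp(x - m) / l               # softmax built from accumulated statistics
full = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
assert np.allclose(online, full)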

Speedup and Memory Savings

Aspect  | Standard    | FlashAttention
Memory  | O(n²)       | O(n) (no n×n matrix stored)
Speed   | 1× baseline | 2-4× faster
Output  | Exact       | Exact (no approximation)

Versions

FlashAttention (v1): the original tiled, exact attention kernel
FlashAttention-2: improved parallelism and work partitioning, roughly 2× faster than v1
FlashAttention-3: tuned for Hopper-generation GPUs, adding asynchrony and low-precision (FP8) support

Test Your Understanding

Question 1: FlashAttention achieves O(n) memory by:

  • A) Approximating attention
  • B) Not storing full n×n matrix, using tiled computation
  • C) Reducing accuracy
  • D) Using less compute

Question 2: FlashAttention tiles computation to fit in:

  • A) HBM (slow GPU memory)
  • B) SRAM (fast on-chip memory)
  • C) CPU cache
  • D) Disk

Question 3: FlashAttention output is:

  • A) Approximation
  • B) Exact, equivalent to standard attention (no approximation)
  • C) Random
  • D) Lower quality

Question 4: For n=4096, standard attention stores how many entries?

  • A) 4096
  • B) 16M
  • C) 64
  • D) 1M

Question 5: The "online softmax" in FlashAttention allows:

  • A) Approximation
  • B) Correct softmax normalization with partial information
  • C) Slower computation
  • D) No benefit

Question 6: FlashAttention speedup over standard is approximately:

  • A) 1× (same speed)
  • B) 2-4× faster
  • C) 10× slower
  • D) 100× faster

Question 7: FlashAttention does some extra computation compared to standard attention but is still faster because:

  • A) It approximates
  • B) It avoids loading full matrices from slow HBM, computing in fast SRAM
  • C) It uses less parameters
  • D) It uses less accurate math

Question 8: Which is NOT a version of FlashAttention?

  • A) FlashAttention-1
  • B) FlashAttention-2
  • C) FlashAttention-3
  • D) FlashAttention-Sparse