Introduction
Ring attention is a distributed attention mechanism in which the attention computation is split across multiple devices arranged in a ring topology. Each device handles a slice of the sequence, and key/value blocks circulate around the ring until every query has attended to the full sequence.
Motivation
Even with memory-efficient attention variants that avoid materializing the full attention matrix, the activations for very long sequences exceed the memory of a single device. Ring attention distributes the sequence and its attention computation across multiple GPUs/NPUs.
Ring Topology
Devices arranged in ring:
Device 0 → Device 1 → Device 2 → ... → Device C-1 → Device 0
Each device holds one contiguous chunk of Q, K, V (a slice of the sequence)
Only the K, V chunks are communicated around the ring; each Q chunk stays resident on its device
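A minimal sketch of the sharding, assuming a toy NumPy setup where each "device" is just a list entry and seq_len, d, and num_devices are illustrative choices:

```python
import numpy as np

# Toy sharding: split Q, K, V along the sequence axis, one chunk per "device".
# seq_len, d, and num_devices are illustrative; seq_len must divide evenly here.
num_devices = 4
seq_len, d = 16, 8
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((seq_len, d)) for _ in range(3))

chunk = seq_len // num_devices
# Device i owns rows [i*chunk, (i+1)*chunk) of each tensor.
Q_shards = [Q[i * chunk:(i + 1) * chunk] for i in range(num_devices)]
K_shards = [K[i * chunk:(i + 1) * chunk] for i in range(num_devices)]
V_shards = [V[i * chunk:(i + 1) * chunk] for i in range(num_devices)]
```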
Algorithm
At each of the C ring steps, every device i (in parallel):
1. Keeps its own query chunk Q_i resident throughout
2. Computes partial attention of Q_i against the K, V chunk it currently holds (at step 0 this is its own K_i, V_i)
3. Sends that K, V chunk to device i-1 (mod C) and receives the next chunk from device i+1 (mod C); in practice the transfer is overlapped with the compute
4. Rescales and accumulates the partial result (online softmax), then repeats
After C steps (C = number of devices), each device holds the full attention output for its own query chunk; see the sketch after this list
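A minimal single-process sketch of this loop, assuming non-causal attention and toy NumPy tensors. The ring send/receive is modeled by rotating Python lists rather than real device collectives, and an online-softmax accumulator (running row max m and running row sum l) lets each device combine partial results from successive K, V blocks exactly; every name and shape here is illustrative, not the reference implementation.

```python
import numpy as np

def ring_attention_sim(Q, K, V, num_devices):
    """Simulate ring attention: each 'device' i owns one Q chunk and streams
    every K, V chunk past it, one ring step at a time."""
    seq_len, d = Q.shape
    chunk = seq_len // num_devices
    split = lambda X: [X[i * chunk:(i + 1) * chunk] for i in range(num_devices)]
    Qs, Ks, Vs = split(Q), split(K), split(V)

    # Per-device online-softmax accumulators.
    out = [np.zeros((chunk, d)) for _ in range(num_devices)]   # unnormalized output
    m = [np.full(chunk, -np.inf) for _ in range(num_devices)]  # running row max
    l = [np.zeros(chunk) for _ in range(num_devices)]          # running row sum

    for _ in range(num_devices):                 # C ring steps
        for i in range(num_devices):             # all devices work in parallel
            scores = Qs[i] @ Ks[i].T / np.sqrt(d)
            new_m = np.maximum(m[i], scores.max(axis=1))
            p = np.exp(scores - new_m[:, None])
            corr = np.exp(m[i] - new_m)          # rescale previous accumulators
            l[i] = l[i] * corr + p.sum(axis=1)
            out[i] = out[i] * corr[:, None] + p @ Vs[i]
            m[i] = new_m
        # Ring step: device i sends its current K, V block to device i-1
        # and receives the block from device i+1 (modeled as a list rotation).
        Ks = Ks[1:] + Ks[:1]
        Vs = Vs[1:] + Vs[:1]

    # Normalize and stitch the per-device outputs back into sequence order.
    return np.concatenate([o / s[:, None] for o, s in zip(out, l)])

# Quick consistency check against full attention on toy inputs (sizes illustrative):
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((16, 8)) for _ in range(3))
S = Q @ K.T / np.sqrt(Q.shape[1])
P = np.exp(S - S.max(axis=1, keepdims=True))
reference = (P / P.sum(axis=1, keepdims=True)) @ V
assert np.allclose(ring_attention_sim(Q, K, V, num_devices=4), reference)
```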
Communication Pattern
Step 0: Device i holds its own K_i, V_i (Device 0 holds K_0, V_0)
Step 1: K_0, V_0 moves to Device C-1; Device 0 now holds K_1, V_1
Step 2: K_1, V_1 moves to Device C-2, and so on
After C steps every device has seen every K, V chunk; each transfer is a point-to-point send to the ring neighbor (see the schedule sketch below)
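The schedule can be made concrete with a tiny pure-Python sketch (4 devices assumed for illustration) that prints which K, V block each device holds at each step:

```python
# Print which K, V block index each device holds at every ring step.
# num_devices is an illustrative choice; a block moves from device i to device (i - 1) mod C.
num_devices = 4
holds = list(range(num_devices))       # step 0: device i holds block i
for step in range(num_devices):
    print(f"step {step}: " + ", ".join(f"dev{i} holds KV_{b}" for i, b in enumerate(holds)))
    holds = holds[1:] + holds[:1]      # neighbor exchange: receive from device i+1
```

After C steps each K, V block has visited every device exactly once, so over a full pass each device receives roughly one full copy of K and V in total.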
Properties
| Aspect | Ring Attention |
|---|---|
| Communication | K, V blocks passed point-to-point to the ring neighbor each step (not all-to-all) |
| Memory per device | O(n/C), where n = sequence length and C = number of devices |
| Steps | C (one ring step per device) |