Introduction
Causal masking (also called autoregressive masking or future masking) is a technique used in autoregressive models to prevent each position from attending to future positions. It ensures that, during both training and generation, the model only sees context from the past and present, never from the future.
Why Causal Masking?
When generating a sequence autoregressively:
- Position 0 can only see position 0 (nothing before)
- Position 1 can see positions 0 and 1
- Position 2 can see positions 0, 1, and 2
- And so on...
The Causal Mask Matrix
For sequence length 4, causal mask M:
              To position (j)
             0    1    2    3
          ┌────┬────┬────┬────┐
From    0 │ 1  │ 0  │ 0  │ 0  │
position  ├────┼────┼────┼────┤
(i)     1 │ 1  │ 1  │ 0  │ 0  │
          ├────┼────┼────┼────┤
        2 │ 1  │ 1  │ 1  │ 0  │
          ├────┼────┼────┼────┤
        3 │ 1  │ 1  │ 1  │ 1  │
          └────┴────┴────┴────┘
1 = can attend, 0 = cannot attend
In practice, we use -∞ for forbidden positions
so softmax produces 0 attention weight.
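A minimal PyTorch sketch of both forms, the 0/1 matrix above and the additive -∞ version used in practice:

```python
import torch

seq_len = 4

# Binary form: 1 = can attend, 0 = cannot (matches the table above).
binary_mask = torch.tril(torch.ones(seq_len, seq_len))

# Additive form: 0 where attention is allowed, -inf where it is forbidden,
# so that softmax assigns zero weight to future positions.
additive_mask = torch.zeros(seq_len, seq_len)
additive_mask = additive_mask.masked_fill(binary_mask == 0, float('-inf'))

print(binary_mask)
print(additive_mask)
```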
Mathematical Formulation
Let M be the causal mask matrix
M[i,j] = 0 if j ≤ i (can attend)
M[i,j] = -∞ if j > i (cannot attend)
Attention with mask:
A = softmax(QKᵀ/√d + M)
Then A[i,j] = 0 for all j > i, because exp(-∞) = 0 inside the row-wise softmax.
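To make this concrete, here is a small PyTorch check (a sketch; Q, K, and d are random stand-ins) that the masked attention weights vanish above the diagonal:

```python
import torch

seq_len, d = 4, 8
Q, K = torch.randn(seq_len, d), torch.randn(seq_len, d)

# Build M: 0 on and below the diagonal, -inf strictly above it.
M = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)

A = torch.softmax(Q @ K.T / d**0.5 + M, dim=-1)

# Every entry above the diagonal is exactly 0, since exp(-inf) = 0.
assert torch.all(A.triu(diagonal=1) == 0)
# Each row still sums to 1 over the allowed positions.
assert torch.allclose(A.sum(dim=-1), torch.ones(seq_len))
```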
Implementation Methods
Method 1: Additive Mask
scores = Q @ K.transpose(-2, -1) / math.sqrt(d)  # raw attention scores
scores_masked = scores + M                       # M has -inf above the diagonal
A = torch.softmax(scores_masked, dim=-1)         # masked positions get weight 0
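Packaged as a self-contained function (a sketch under the same additive-mask convention; the name causal_attention is illustrative, not a standard API):

```python
import math
import torch

def causal_attention(Q, K, V):
    """Scaled dot-product attention with an additive causal mask.
    Q, K, V: (seq_len, d) tensors."""
    seq_len, d = Q.shape
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d)
    M = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
    A = torch.softmax(scores + M, dim=-1)   # zero weight on future positions
    return A @ V
```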
Method 2: Lower Triangular
mask = torch.tril(torch.ones(seq_len, seq_len))        # 1s on and below the diagonal
scores = scores.masked_fill(mask == 0, float('-inf'))  # block strictly-future positions
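For reference, newer PyTorch releases (2.0+) can build this mask internally via torch.nn.functional.scaled_dot_product_attention:

```python
import torch
import torch.nn.functional as F

Q = K = V = torch.randn(1, 4, 8)   # toy (batch, seq_len, d) tensors

# is_causal=True applies the lower-triangular mask internally,
# equivalent to the explicit masked_fill above.
out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
```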
Method 3: Bias Addition (ALiBi style)
In ALiBi-style implementations, the linear distance bias is typically built as a single matrix with the causal -∞ entries already fused in, so no separate mask tensor needs to be added to the scores.
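A sketch of such a fused bias (the slope schedule follows the ALiBi paper for power-of-two head counts; alibi_causal_bias is an illustrative name, not a library function):

```python
import torch

def alibi_causal_bias(seq_len: int, num_heads: int) -> torch.Tensor:
    """Per-head linear distance bias with the causal mask fused in.
    Returned shape: (num_heads, seq_len, seq_len); add it to the scores."""
    slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / num_heads)
                           for h in range(num_heads)])
    i = torch.arange(seq_len).view(-1, 1)          # query positions
    j = torch.arange(seq_len).view(1, -1)          # key positions
    bias = slopes.view(-1, 1, 1) * (j - i)         # 0 on the diagonal, more negative with distance
    return bias.masked_fill(j > i, float('-inf'))  # forbid future positions outright
```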
Causal in Training vs Generation
Training (Parallel)
All positions computed in parallel with causal masking applied to attention scores.
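A sketch of a training step, assuming a decoder-only model whose attention layers apply the causal mask internally (model, tokens, and vocab_size are placeholders, not defined here):

```python
import torch.nn.functional as F

# Teacher forcing: one forward pass scores every position in parallel;
# the causal mask inside attention keeps each prediction honest.
logits = model(tokens[:, :-1])                        # (batch, seq-1, vocab_size)
loss = F.cross_entropy(logits.reshape(-1, vocab_size),
                       tokens[:, 1:].reshape(-1))     # next-token targets
```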
Generation (Autoregressive)
Generate one token at a time. As the context grows, the causal mask naturally lets the newest position attend to every previous token (see the sketch after the steps below).
Step 1: tokens[0:1] → predict token[1]
Step 2: tokens[0:2] → predict token[2]
Step 3: tokens[0:3] → predict token[3]
...
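A toy greedy decoding loop (a sketch: model is assumed to map a (1, t) token tensor to (1, t, vocab_size) logits, and no KV cache is shown):

```python
import torch

def greedy_generate(model, prompt_ids, num_new_tokens):
    tokens = prompt_ids                                  # (1, t) tensor of token ids
    for _ in range(num_new_tokens):
        logits = model(tokens)                           # (1, t, vocab_size)
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_id], dim=1)     # extend the context
    return tokens
```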
Causal Mask vs Padding Mask
| Aspect | Causal Mask | Padding Mask |
|---|---|---|
| Purpose | Block future positions | Block padding tokens |
| Shape | Lower triangular | Depends on padding pattern |
| Used in | Decoder (autoregressive) | Both encoder/decoder |
| Combined with | Often combined with a padding mask | Often combined with a causal mask |
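A sketch of combining the two into one additive tensor (combined_mask is an illustrative helper; key_is_real marks non-padding tokens):

```python
import torch

def combined_mask(seq_len: int, key_is_real: torch.Tensor) -> torch.Tensor:
    """Additive mask blocking both future positions and padding keys.
    key_is_real: (batch, seq_len) bool, True for real (non-padding) tokens.
    Returns (batch, seq_len, seq_len) with 0 = allowed, -inf = blocked."""
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    allowed = causal.unsqueeze(0) & key_is_real.unsqueeze(1)   # (batch, q, k)
    mask = torch.zeros(allowed.shape)
    return mask.masked_fill(~allowed, float('-inf'))
```

Note that query rows consisting only of padding end up entirely -∞, so their softmax output is undefined; such positions are normally excluded from the loss instead.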