Introduction
Attention masks are binary or continuous masks that control which positions in a sequence can attend to which other positions. They are essential for various attention scenarios including causal generation, ignoring padding, and implementing specialized attention patterns.
Types of Attention Masks
1. Causal Mask (Future Blocking)
Prevents attending to future positions
Lower triangular additive matrix: 0 on and below the diagonal (allowed), -∞ above it (blocked)
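A minimal PyTorch sketch of building this mask; `causal_mask` is an illustrative name, not a library function:

```python
import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    """Additive causal mask: 0 where attention is allowed, -inf for future positions."""
    full = torch.full((seq_len, seq_len), float("-inf"))
    # triu with diagonal=1 keeps only the strictly-upper entries (the future),
    # zeroing the diagonal and below (the allowed past)
    return torch.triu(full, diagonal=1)

# Row i can attend to columns 0..i only
print(causal_mask(4))
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
```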
2. Padding Mask
Blocks attention to padding tokens
True/1 for padding tokens (blocked), False/0 for real tokens; note that some libraries invert this convention (see the table below)
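A sketch deriving this mask from per-sequence lengths, following the convention above (True = padding, as in PyTorch's `key_padding_mask` argument); `padding_mask` is an illustrative name:

```python
import torch

def padding_mask(lengths: list[int], max_len: int) -> torch.Tensor:
    """Boolean padding mask: True marks padding (blocked), False marks real tokens."""
    positions = torch.arange(max_len)               # (max_len,)
    lengths_t = torch.tensor(lengths).unsqueeze(1)  # (batch, 1)
    return positions.unsqueeze(0) >= lengths_t      # (batch, max_len)

# Two sequences of lengths 3 and 5, padded to 5
print(padding_mask([3, 5], max_len=5))
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])
```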
3. Combined Mask
mask = causal_mask OR padding_mask (boolean blocked-positions convention; additive -∞ masks are simply summed)
For a decoder: block future positions AND block padding, as in the sketch below
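A sketch combining both in the additive representation, assuming batch-first tensors (`decoder_mask` is an illustrative name):

```python
import torch

def decoder_mask(lengths: list[int], seq_len: int) -> torch.Tensor:
    """Additive (B, L, L) mask blocking both future positions and padding keys."""
    causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
    pad = torch.arange(seq_len).unsqueeze(0) >= torch.tensor(lengths).unsqueeze(1)
    # Turn the boolean padding mask into additive form: -inf at padded key columns
    additive_pad = torch.zeros(pad.shape).masked_fill(pad, float("-inf"))
    # Broadcast (1, L, L) + (B, 1, L) -> (B, L, L); -inf + -inf stays -inf
    return causal.unsqueeze(0) + additive_pad.unsqueeze(1)
```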
Mask Application in Computation
Step 1: Compute attention scores
scores = QKᵀ / √d
Step 2: Apply mask (add -∞ for positions to block)
scores = scores + mask
Step 3: Softmax
attention_weights = softmax(scores, dim=-1)
Step 4: Weighted sum
output = attention_weights · V
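Putting the four steps together; a minimal sketch, with `masked_attention` as an illustrative name:

```python
import torch
import torch.nn.functional as F

def masked_attention(Q, K, V, mask=None):
    """Scaled dot-product attention; mask is additive (0 = keep, -inf = block)
    and must broadcast against the (..., L_q, L_k) score shape."""
    d = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d ** 0.5   # Step 1: QK^T / sqrt(d)
    if mask is not None:
        scores = scores + mask                    # Step 2: blocked entries become -inf
    weights = F.softmax(scores, dim=-1)           # Step 3: blocked positions get weight 0
    return weights @ V                            # Step 4: weighted sum of values
```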
Mask Representations
| Mask Type | Value for Blocked | Value for Allowed | Used In |
|---|---|---|---|
| Causal | -∞ | 0 | Decoder self-attention |
| Padding | -∞ | 0 | Variable length sequences |
| BERT attention_mask | 0 | 1 | Bidirectional encoders (ignoring padding) |
| Key padding | -∞ | 0 | Cross-attention with padding |
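These conventions interconvert easily. A sketch converting a BERT-style 1/0 mask into the additive form (`to_additive` is a hypothetical helper, not a HuggingFace API):

```python
import torch

def to_additive(attention_mask: torch.Tensor) -> torch.Tensor:
    """Convert 1 = allowed / 0 = blocked (BERT convention) into 0 / -inf additive form."""
    additive = torch.zeros(attention_mask.shape, dtype=torch.float)
    # masked_fill avoids the NaN that (1 - mask) * -inf would produce at allowed entries
    return additive.masked_fill(attention_mask == 0, float("-inf"))
```

In practice a large finite negative (e.g. `torch.finfo(dtype).min`) is often used instead of -∞ to avoid NaNs when a row is fully blocked.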
Advanced Masks
1. Chunk Mask
For local attention within chunks:
Chunk size C, sequence length N
Within each chunk: full attention
Between chunks: no attention
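A sketch of the resulting block-diagonal mask (boolean, True = allowed; `chunk_mask` is an illustrative name):

```python
import torch

def chunk_mask(seq_len: int, chunk_size: int) -> torch.Tensor:
    """True where query and key fall in the same chunk, False elsewhere."""
    chunk_ids = torch.arange(seq_len) // chunk_size   # chunk index of each position
    return chunk_ids.unsqueeze(0) == chunk_ids.unsqueeze(1)

# N=6, C=3: two 3x3 blocks of True along the diagonal, False elsewhere
print(chunk_mask(6, 3))
```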
2. Stride Mask
For attention with fixed gaps:
Allow attention between positions whose distance is a multiple of S
Block all other pairs of positions
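A sketch under that definition (`stride_mask` is an illustrative name):

```python
import torch

def stride_mask(seq_len: int, stride: int) -> torch.Tensor:
    """True where |i - j| is a multiple of the stride (distance 0, i.e. self, included)."""
    idx = torch.arange(seq_len)
    dist = (idx.unsqueeze(0) - idx.unsqueeze(1)).abs()
    return dist % stride == 0
```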
3. Arbitrary Attention Patterns
For sparse attention patterns:
Custom binary mask M
M[i,j] = 0 for allowed, -∞ for blocked
Or M[i,j] = 1 for allowed, 0 for blocked
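A minimal sketch applying a custom 1/0 pattern directly to the attention scores (`apply_pattern` is an illustrative name):

```python
import torch

def apply_pattern(scores: torch.Tensor, allowed: torch.Tensor) -> torch.Tensor:
    """allowed: boolean or 0/1 pattern with 1 = allowed, 0 = blocked."""
    return scores.masked_fill(~allowed.bool(), float("-inf"))
```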
Efficiency Considerations
- Sparse masks: Can be stored efficiently as sparse matrices
- Dynamic masks: Can be generated at runtime for variable patterns
- Combined at the operation: multiple masks should be merged into a single mask before the softmax, so the kernel applies one tensor
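For the common causal case, PyTorch's fused kernel (available since PyTorch 2.0) skips materializing the L×L mask entirely; custom patterns can still be passed as a single pre-combined `attn_mask`:

```python
import torch
import torch.nn.functional as F

B, H, L, D = 2, 4, 16, 32
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

# is_causal=True applies the causal pattern inside the kernel; alternatively,
# pass attn_mask= (boolean with True = attend, or an additive float mask),
# merged from all individual masks before the call
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```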