20. Masked Self-Attention

Introduction

Masked self-attention is a variant of self-attention in which future positions are hidden so that no position can attend to them. This is essential for autoregressive models (such as decoders) that generate sequences token by token, since it ensures each position can only see itself and earlier positions.

Why Mask?

In autoregressive generation, we cannot use future information because it hasn't been generated yet. Masking enforces causal ordering: position i may attend only to positions 0 through i.

The Mask Matrix

For sequence length 4, the causal mask M is:

                To position (j)
               0    1    2    3
            ┌────┬────┬────┬────┐
 From     0 │  0 │ -∞ │ -∞ │ -∞ │
 position   ├────┼────┼────┼────┤
 (i)      1 │  0 │  0 │ -∞ │ -∞ │
            ├────┼────┼────┼────┤
          2 │  0 │  0 │  0 │ -∞ │
            ├────┼────┼────┼────┤
          3 │  0 │  0 │  0 │  0 │
            └────┴────┴────┴────┘

0 = can attend, -∞ = blocked
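As a quick illustration, this mask can be constructed directly in PyTorch (a minimal sketch; torch.triu with diagonal=1 is just one of several equivalent ways to build it):

import torch

seq_len = 4
# -inf strictly above the diagonal (j > i), 0 on and below it (j <= i)
M = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
print(M)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])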

Implementation

Step 1: Compute attention scores
E = QKᵀ / √d

Step 2: Apply causal mask
E_masked = E + M
where M[i,j] = 0 if j ≤ i, else -∞

Step 3: Softmax
A = softmax(E_masked, axis=-1)

Step 4: Weighted sum
Output = AV
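Putting the four steps together, here is a minimal PyTorch sketch (the function name, tensor shapes, and the absence of projection layers and multiple heads are simplifying assumptions for illustration):

import torch
import torch.nn.functional as F

def masked_self_attention(Q, K, V):
    # Q, K, V: [batch, seq_len, d]
    d = Q.size(-1)
    seq_len = Q.size(-2)

    # Step 1: scaled attention scores E = QK^T / sqrt(d)
    E = Q @ K.transpose(-2, -1) / d ** 0.5

    # Step 2: add the causal mask (0 where j <= i, -inf where j > i)
    M = torch.triu(torch.full((seq_len, seq_len), float('-inf'), device=Q.device), diagonal=1)
    E_masked = E + M

    # Step 3: softmax over the key dimension
    A = F.softmax(E_masked, dim=-1)

    # Step 4: weighted sum of the values
    return A @ V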

Efficient Implementation: Lower Triangular Mask

Instead of building an explicit additive -∞ matrix, we can construct a boolean lower-triangular mask and fill the blocked positions directly:

mask = torch.tril(torch.ones(seq_len, seq_len))

# Before the softmax, set the blocked (future) positions in the scores to -inf:
scores = scores.masked_fill(mask == 0, float('-inf'))
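A quick illustrative check (assuming random scores) that the masked rows still form valid distributions after the softmax:

import torch
import torch.nn.functional as F

seq_len = 4
scores = torch.randn(seq_len, seq_len)            # dummy attention scores
mask = torch.tril(torch.ones(seq_len, seq_len))
scores = scores.masked_fill(mask == 0, float('-inf'))

A = F.softmax(scores, dim=-1)
print(A)                # upper triangle is exactly 0
print(A.sum(dim=-1))    # every row still sums to 1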

Types of Masks

1. Causal Mask

Prevents attending to future positions: the -∞ entries fill the strictly upper triangle (j > i), while the allowed region is the lower triangle. Standard for autoregressive models.

2. Lookahead Mask

Another name for the same causal mask, often expressed as a boolean matrix that flags blocked positions: lookahead[i, j] = 1 if j > i, else 0; flagged entries are then set to -∞ before the softmax.

3. Padding Mask

Masks padding tokens in variable-length sequences:

M_padding[i, j] = -∞ if position j is a padding token, else 0
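A minimal sketch of building an additive padding mask from token IDs (the pad_token_id value and the example sequence are assumptions for illustration):

import torch

pad_token_id = 0                              # assumed padding ID
token_ids = torch.tensor([[5, 7, 9, 0, 0]])   # [batch, seq_len]; last two tokens are padding

# -inf wherever the *key* position is padding, 0 elsewhere;
# the extra middle dimension lets it broadcast over query positions.
is_pad = token_ids == pad_token_id                                           # [batch, seq_len]
padding_mask = torch.zeros(token_ids.shape).masked_fill(is_pad, float('-inf'))
padding_mask = padding_mask[:, None, :]                                       # [batch, 1, seq_len]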

Combining Masks

Combined mask = causal_mask + padding_mask (for the additive 0/-∞ masks above; with boolean "blocked" masks this is a logical OR)

Final attention weights = softmax(QKᵀ/√d + combined_mask), followed by the usual weighted sum with V.
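Putting the pieces together, a self-contained sketch of applying the combined mask (batch size, sequence length, and the padding pattern are illustrative assumptions):

import torch

batch, seq_len, d = 2, 5, 8
Q = torch.randn(batch, seq_len, d)
K = torch.randn(batch, seq_len, d)
V = torch.randn(batch, seq_len, d)

# Additive causal mask: -inf strictly above the diagonal.
causal_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)

# Additive padding mask: -inf at padded key positions (here, the trailing tokens).
is_pad = torch.tensor([[False, False, False, False, True],
                       [False, False, False, True,  True]])
padding_mask = torch.zeros(batch, seq_len).masked_fill(is_pad, float('-inf'))[:, None, :]

combined_mask = causal_mask + padding_mask        # broadcasts to [batch, seq_len, seq_len]

scores = Q @ K.transpose(-2, -1) / d ** 0.5
A = torch.softmax(scores + combined_mask, dim=-1)
output = A @ V                                    # [batch, seq_len, d]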

During Training vs Inference

Training

Use teacher forcing: the input is the full ground-truth target sequence, and all positions are computed in parallel under the causal mask, as sketched below.
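For example, one common teacher-forcing setup (an illustrative sketch; decoder is a hypothetical model and the token IDs are made up) shifts the target by one position so that every position is trained to predict the next token in a single parallel pass:

import torch
import torch.nn.functional as F

# Full ground-truth target sequence of token IDs: [batch, seq_len]
target = torch.tensor([[2, 17, 5, 9, 3]])

inputs = target[:, :-1]   # fed to the masked decoder
labels = target[:, 1:]    # each position predicts the *next* token

# logits = decoder(inputs)                                    # [batch, seq_len - 1, vocab_size]
# loss = F.cross_entropy(logits.flatten(0, 1), labels.flatten())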

Inference

Autoregressive: generate one token at a time, extending the context by one position at each step.
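A greedy decoding loop as a minimal illustrative sketch (model, start_token, and max_new_tokens are hypothetical names; real implementations usually cache keys and values rather than re-running the full context at every step):

import torch

def generate(model, start_token, max_new_tokens=10):
    tokens = torch.tensor([[start_token]])         # [1, t], the growing context
    for _ in range(max_new_tokens):
        logits = model(tokens)                     # [1, t, vocab]; masked self-attention inside
        next_token = logits[:, -1].argmax(dim=-1)  # greedy pick at the last position
        tokens = torch.cat([tokens, next_token[:, None]], dim=-1)
    return tokens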

Test Your Understanding

Question 1: What does masked self-attention prevent?

  • A) Attending to padding
  • B) Attending to future positions
  • C) Self-attention
  • D) Using queries

Question 2: What value is used to mask future positions in the attention score?

  • A) 0
  • B) 1
  • C) -∞
  • D) ∞

Question 3: For position i in masked attention, which positions can it attend to?

  • A) Only position i
  • B) All positions
  • C) Positions 0 to i only
  • D) Positions i to n

Question 4: The mask matrix for sequence length 4 has what shape?

  • A) [4]
  • B) [4, 4]
  • C) [4, d]
  • D) [d, d]

Question 5: Which pattern describes the lower triangular (causal) mask?

  • A) Upper triangle is 0, lower is -∞
  • B) Upper triangle is -∞, lower is 0
  • C) All zeros
  • D) All -∞

Question 6: During training, masked self-attention:

  • A) Is applied sequentially token by token
  • B) Computes all positions in parallel with masking
  • C) Cannot be used
  • D) Uses future information

Question 7: What is teacher forcing?

  • A) Using model predictions as input
  • B) Using actual previous tokens as input during training
  • C) Forcing the model to learn
  • D) Masking all positions

Question 8: Why do we add -∞ instead of 0 to mask positions?

  • A) To make computation faster
  • B) Because softmax(-∞) = 0, effectively blocking attention
  • C) To increase the value
  • D) Because softmax(0) = 0