Topic 68

State Space Models

Mamba, S4, and Structured State Spaces for Sequence Modeling

Overview

State Space Models (SSMs) represent a class of models that combine the benefits of recurrent networks and convolutional approaches for sequence modeling. Popularized by the Mamba and S4 architectures, SSMs offer an alternative to attention mechanisms with linear complexity in sequence length. Unlike attention mechanisms that compute pairwise interactions between all tokens, SSMs maintain a latent state that summarizes the sequence history, enabling efficient long-range dependency modeling without quadratic scaling.

The fundamental concept behind SSMs draws from classical state space models in control theory and signal processing, where a system is described by a continuous-time state equation and an observation equation. In the deep learning context, SSMs parameterize these equations and learn them from data, resulting in models that can process sequences efficiently while maintaining the ability to capture long-range dependencies.

Mathematical Foundation

At the core of SSMs is the continuous-time state space representation:

h'(t) = Ah(t) + Bx(t)   [State equation]
y(t) = Ch(t) + Dx(t)    [Output equation]

Where x(t) is the input, h(t) is the hidden state, and y(t) is the output; A, B, C, and D are learned parameter matrices. To apply the model to discrete sequences, the continuous system is discretized with a step size Δ (using, for example, the bilinear method), yielding discrete matrices A_bar and B_bar:

h_t = A_bar h_{t-1} + B_bar x_t
y_t = Ch_t

(The Dx_t term is typically carried over as a direct skip connection.)

The key innovation in structured state space models (S4) is the choice of the A matrix. By initializing A with a specific structured form (diagonal plus low-rank, derived from the HiPPO theory), S4 enables efficient computation of the associated convolution kernel. This yields near-linear O(N log N) training time (via FFT-based convolution) and O(1) cost per generated token at inference (via recurrence), compared to O(N²) for standard attention.
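As a concrete illustration, the bilinear discretization and the resulting recurrence can be computed directly. The dimensions and random matrices below are purely illustrative:

```python
import numpy as np

# Toy example: bilinear discretization of a small continuous SSM,
# followed by the resulting discrete recurrence.
N = 4                      # state dimension (arbitrary for the demo)
dt = 0.1                   # step size Delta
rng = np.random.default_rng(0)

A = -np.eye(N) + 0.1 * rng.standard_normal((N, N))  # roughly stable A
B = rng.standard_normal((N, 1))
C = rng.standard_normal((1, N))

# Bilinear method:
#   A_bar = (I - dt/2 A)^-1 (I + dt/2 A)
#   B_bar = (I - dt/2 A)^-1 dt B
I_N = np.eye(N)
inv = np.linalg.inv(I_N - dt / 2 * A)
A_bar = inv @ (I_N + dt / 2 * A)
B_bar = inv @ (dt * B)

# Discrete recurrence: h_t = A_bar h_{t-1} + B_bar x_t,  y_t = C h_t
x = rng.standard_normal(20)
h = np.zeros((N, 1))
ys = []
for x_t in x:
    h = A_bar @ h + B_bar * x_t
    ys.append((C @ h).item())
```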

Mamba: Selective State Spaces

Mamba extends the S4 architecture by introducing selection mechanisms that allow the model to decide which information to propagate or discard based on the input. This is achieved by making the B, C, and Δ parameters input-dependent rather than fixed:

B_t = Linear(x_t)
C_t = Linear(x_t)
Δ_t = τ(Linear(x_t))

where τ is typically softplus, which keeps the step size Δ positive.

This selection mechanism allows Mamba to achieve performance comparable to transformers on long sequences while maintaining linear complexity. Because the parameters now vary with the input, the fixed convolution kernel of S4 no longer applies; Mamba instead uses a hardware-aware selective scan algorithm to compute the recurrence efficiently in parallel on GPUs.
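The selection step in the equations above can be sketched in a few lines. The layer names (to_B, to_C, to_dt) are illustrative, and τ is taken to be softplus:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Sketch of the selection step: B, C, and the step size Delta are produced
# from the current input itself, rather than being fixed parameters.
d_model, d_state = 8, 4
x_t = torch.randn(2, d_model)         # one timestep for a batch of 2

to_B = nn.Linear(d_model, d_state)
to_C = nn.Linear(d_model, d_state)
to_dt = nn.Linear(d_model, d_model)

B_t = to_B(x_t)                       # (2, d_state), varies with the input
C_t = to_C(x_t)                       # (2, d_state)
dt_t = F.softplus(to_dt(x_t))         # (2, d_model); softplus keeps Delta > 0
```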

Key Properties

Linear Complexity

Unlike the quadratic scaling of attention mechanisms, SSMs scale linearly with sequence length. This makes them particularly attractive for processing very long sequences where attention becomes computationally prohibitive. The trade-off is that SSMs may not capture certain types of dependencies as effectively as full attention.

Efficient Inference

SSMs support both parallel training (using a convolution kernel derived from the recurrence) and efficient autoregressive inference. The hidden state at each step depends only on the previous state and current input, enabling streaming generation without recomputing the entire context.
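This equivalence between the two views can be checked numerically: unrolling the recurrence gives y as a causal convolution of x with the kernel K_k = C A_bar^k B_bar. A sketch with toy matrices:

```python
import numpy as np

# Verify that the recurrent and convolutional views of a (time-invariant)
# discrete SSM produce the same output.
rng = np.random.default_rng(1)
N, L = 4, 16
A_bar = 0.9 * np.eye(N)               # decaying toy transition
B_bar = rng.standard_normal((N, 1))
C = rng.standard_normal((1, N))
x = rng.standard_normal(L)

# Recurrent view: one step per token (how autoregressive inference runs)
h = np.zeros((N, 1))
y_rec = []
for x_t in x:
    h = A_bar @ h + B_bar * x_t
    y_rec.append((C @ h).item())

# Convolutional view: precompute the kernel, then a causal convolution
# (how parallel training runs; in practice via FFT)
K = np.array([(C @ np.linalg.matrix_power(A_bar, k) @ B_bar).item()
              for k in range(L)])
y_conv = [sum(K[k] * x[t - k] for k in range(t + 1)) for t in range(L)]

assert np.allclose(y_rec, y_conv)     # both views agree
```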

Interpretability

The state space formulation provides an interpretable latent representation. The A matrix controls how information decays or persists over time, and different initialization strategies can impart different inductive biases such as smoothing or persistence.
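The decay behavior is easy to see in one dimension: in the scalar recurrence h_t = a_bar * h_{t-1} + x_t, an impulse at t = 0 fades as a_bar**t, so values of a_bar near 1 retain information while small values forget quickly:

```python
# Scalar illustration of how the transition value controls memory:
# feed in a unit impulse and watch how long it persists in the state.
for a_bar in (0.99, 0.9, 0.5):
    h, trace = 0.0, []
    for t in range(50):
        x_t = 1.0 if t == 0 else 0.0   # unit impulse at t = 0
        h = a_bar * h + x_t
        trace.append(h)
    print(f"a_bar={a_bar}: h_10={trace[10]:.4f}, h_49={trace[49]:.6f}")
```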

Comparison with Attention

Attention computes explicit pairwise interactions between all tokens, at O(N²) cost and with a key-value cache that grows as O(N) during generation. SSMs instead compress the history into a fixed-size state, giving O(N) time and constant inference memory; the trade-off is that any dependency attention could retrieve exactly must instead be recoverable from the compressed state.

Architecture Variants

S4 (Structured State Spaces)

The S4 architecture introduced structured matrices for the state transition, enabling stable and efficient computation. The key innovation is the diagonal-plus-low-rank structure of the A matrix (derived from the HiPPO initialization), which allows the model's output to be computed as a linear convolution. S4 excels at long-range dependency tasks and achieved state-of-the-art results on benchmarks such as Long Range Arena.

Mamba (Selective SSM)

Mamba builds on S4 with input-dependent selection mechanisms (the resulting layer is sometimes called S6). By making the B, C matrices and the timestep Δ depend on the input, Mamba can dynamically route information. This addresses a limitation of S4, whose dynamics are time-invariant: the same transition is applied regardless of the content of the input sequence.

Hyena

Hyena replaces attention with a hierarchy of long, implicitly parameterized convolutions combined with data-controlled gating. It achieves sub-quadratic scaling while retaining the ability to model long-range dependencies through its learned filters.

RWKV

Receptance Weighted Key Value (RWKV) blends transformer-style parallel training with RNN-style inference. It uses a linear attention variant with learned per-channel time decay, achieving transformer-like quality at RNN-like inference cost.

Applications

State Space Models have been applied successfully across several domains, including language modeling, audio and speech generation, genomics (modeling long DNA sequences), and long-horizon time series, settings where long contexts make quadratic attention impractical.

Implementation Considerations

When implementing SSMs, the main architectural choices are the state size d_state, how the A matrix is parameterized and initialized, and how the continuous parameters are discretized. The sketch below, a simplified and unoptimized selective SSM with an explicit recurrence loop, illustrates the moving parts:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimplifiedSSM(nn.Module):
    """Simplified selective SSM layer (educational; not optimized)."""

    def __init__(self, d_model, d_state=16):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state

        # Input-dependent projections: dt (1) plus B and C (d_state each)
        self.x_proj = nn.Linear(d_model, 1 + 2 * d_state)
        self.dt_proj = nn.Linear(1, d_model)

        # Diagonal A, kept negative for stability; D is a skip term
        self.A_log = nn.Parameter(
            torch.log(torch.arange(1, d_state + 1).float()).repeat(d_model, 1)
        )
        self.D = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        batch, seq_len, _ = x.shape
        A = -torch.exp(self.A_log)                    # (d_model, d_state)

        # Input-dependent SSM parameters (the "selection" mechanism)
        dt_raw, B_mat, C_mat = self.x_proj(x).split(
            [1, self.d_state, self.d_state], dim=-1
        )
        dt = F.softplus(self.dt_proj(dt_raw))         # (batch, seq_len, d_model)

        # Sequential scan; real implementations use a parallel selective
        # scan (Mamba) or a convolution kernel (S4) instead of this loop
        h = x.new_zeros(batch, self.d_model, self.d_state)
        outputs = []
        for t in range(seq_len):
            dt_t = dt[:, t].unsqueeze(-1)             # (batch, d_model, 1)
            A_bar = torch.exp(dt_t * A)               # ZOH discretization
            B_bar = dt_t * B_mat[:, t].unsqueeze(1)   # (batch, d_model, d_state)
            h = A_bar * h + B_bar * x[:, t].unsqueeze(-1)
            y_t = (h * C_mat[:, t].unsqueeze(1)).sum(-1) + self.D * x[:, t]
            outputs.append(y_t)
        y = torch.stack(outputs, dim=1)               # (batch, seq_len, d_model)

        return y + x  # residual connection

# Usage: layer = SimplifiedSSM(d_model=64); y = layer(torch.randn(2, 128, 64))

Discretization Methods

The choice of discretization method affects the model's behavior. The bilinear (Tustin) method, used in S4, sets A_bar = (I - Δ/2 A)^-1 (I + Δ/2 A), while zero-order hold (ZOH), used in Mamba, sets A_bar = exp(ΔA). Both approach the identity as Δ approaches 0, so the learned step size Δ effectively controls how quickly the state responds to new input.
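Zero-order hold and the bilinear method can be compared directly on a single diagonal mode h'(t) = a h(t) + b x(t); the values of a, b, and the step sizes below are toy choices for illustration:

```python
import numpy as np

# Compare ZOH and bilinear discretization of one scalar SSM mode.
a, b = -1.0, 1.0
rows = []
for dt in (0.01, 0.1, 1.0):
    a_zoh = np.exp(dt * a)                       # zero-order hold
    b_zoh = (np.exp(dt * a) - 1.0) / a * b
    a_bil = (1 + dt * a / 2) / (1 - dt * a / 2)  # bilinear (Tustin)
    b_bil = dt * b / (1 - dt * a / 2)
    rows.append((dt, a_zoh, a_bil))
    print(f"dt={dt}: ZOH a_bar={a_zoh:.4f}, bilinear a_bar={a_bil:.4f}")

# The two methods agree closely for small dt and diverge as dt grows
```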

Test Your Understanding

Q1: What is the key computational advantage of SSMs over attention mechanisms?

Answer: SSMs achieve O(N) linear complexity in sequence length, compared to O(N²) quadratic complexity for standard attention mechanisms. This is because SSMs maintain a constant-size hidden state and compute transitions efficiently through convolution (training) or recurrence (inference), avoiding the pairwise token comparisons that make attention expensive.

Q2: How does Mamba improve upon the basic S4 architecture?

Answer: Mamba introduces selection mechanisms by making the B, C matrices and timestep Δ input-dependent rather than fixed parameters. This allows the model to dynamically choose which information to keep or discard at each step, similar to how attention can focus on relevant context. S4 uses fixed state transitions, while Mamba's selective state spaces can adapt their behavior to the specific input sequence.

Q3: Why might SSMs be preferred over attention for very long sequences?

Answer: For very long sequences, attention's O(N²) computation becomes prohibitive, while SSMs run in O(N) time. During inference, SSMs store only a fixed-size hidden state, compared to attention's O(N) key-value cache. SSMs also enable efficient streaming generation, where each new token depends only on the previous state and the current input.

Q4: What is the role of the A matrix in an SSM?

Answer: The A matrix controls the state transition dynamics, determining how information flows and decays through the sequence. It governs the rate at which information from past timesteps persists or vanishes in the hidden state. Structured SSMs like S4 initialize A with a specific diagonal-plus-low-rank form that enables efficient computation while maintaining the expressive power to capture long-range dependencies.