Mamba and State Space Model Architectures: The Alternative to Transformers?
Since 2017, Transformers have dominated the artificial intelligence landscape, from Large Language Models (LLMs) to computer vision. However, a new family of architectures is starting to seriously challenge their monopoly: State Space Models (SSMs), and particularly the Mamba model. Their cost grows linearly with sequence length instead of quadratically as in Transformers, so these models could well solve the context window bottleneck. In this article, we break down how SSMs work, explain Mamba's innovations, analyze the benchmarks, and see how to implement these models in practice.
Prerequisites
- Fundamental concepts in Deep Learning (neural networks, backpropagation)
- Basic understanding of the Transformer architecture (attention mechanism, self-attention)
- Knowledge of linear algebra (matrix multiplication, state vectors)
- Basics of Python and PyTorch
- Concepts of algorithmic complexity (Big O notation)
The Problem with Transformers: The Cost of Attention
To understand the rise of SSMs, we first need to understand the current limitations of Transformers. The core of a Transformer is the attention mechanism, which allows the model to relate each token in a sequence to all other tokens. Mathematically, for a sequence of length $L$ and a model dimension $d$, the attention matrix has a size of $L \times L$.
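This leads to a time complexity of $O(L^2 \cdot d)$. Storing the attention matrix also takes $O(L^2)$ memory per head. Optimized kernels such as FlashAttention avoid materializing that matrix, but the computation itself remains quadratic. In other words, if you double the length of your context (for example, going from 32k to 64k tokens), the cost of the attention computation is multiplied by four. This is known as quadratic complexity.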
# Conceptual illustration: quadratic vs linear complexity
import numpy as np
import matplotlib.pyplot as plt
L = np.linspace(1000, 100000, 100)
# Quadratic complexity (Transformers)
transformer_complexity = L**2
# Linear complexity (SSM / Mamba)
ssm_complexity = L
# Normalize by the shortest length for a visual comparison
transformer_norm = transformer_complexity / transformer_complexity[0]
ssm_norm = ssm_complexity / ssm_complexity[0]
plt.figure(figsize=(10, 6))
plt.plot(L, transformer_norm, label='Transformers (O(L²))', color='red')
plt.plot(L, ssm_norm, label='State Space Models (O(L))', color='blue')
plt.yscale('log')
plt.xlabel('Sequence length (L)')
plt.ylabel('Complexity (normalized)')
plt.title('Algorithmic complexity: Transformers vs SSM')
plt.legend()
plt.grid(True, which="both", ls="--")
# plt.show()  # Uncomment to display the figure (e.g. in a notebook)
The Context Window Bottleneck
This quadratic limitation forces LLM creators to cap context windows and truncate long texts. Analyzing an entire book, a complete GitHub repository, or a long video requires heavy workarounds (like context compression or hierarchical architectures), and these degrade the quality of the responses. We need a mechanism that can process arbitrarily long sequences at a fixed cost per token. This is exactly what State Space Models propose.
Introduction to State Space Models (SSM)
State Space Models draw their roots from control theory, a field of engineering and mathematics used to model continuous physical systems (like the flight of an airplane or the electrical circuit of a robot). The idea is to model a system by observing how an input signal $u(t)$ modifies a hidden internal state $x(t)$ to produce an output signal $y(t)$.
Continuous State Equations
A continuous SSM is defined by two linear equations: a differential equation for the state and an output equation.
$$x'(t) = Ax(t) + Bu(t)$$
$$y(t) = Cx(t) + Du(t)$$
Where:
- $x(t)$ is the latent state vector (the "memory" of the system)
- $u(t)$ is the input signal (for example, a continuous stream of text)
- $y(t)$ is the output signal
- $A$, $B$, $C$, $D$ are learnable parameter matrices
- $A$ is particularly important: it is the State Transition Matrix. It determines how the state naturally evolves over time.
Discretization
The problem is that in computer science we work with discrete data (text tokens, pixels). We must therefore convert these continuous equations into discrete ones using a time step $\Delta$. The foundational S4 paper (Efficiently Modeling Long Sequences with Structured State Spaces) used the bilinear transform:
$$\bar{A} = (I - \tfrac{\Delta}{2} A)^{-1}(I + \tfrac{\Delta}{2} A)$$
$$\bar{B} = (I - \tfrac{\Delta}{2} A)^{-1} \Delta B$$
Mamba instead uses the Zero-Order Hold (ZOH) rule, which assumes the input stays constant over each time step:
$$\bar{A} = \exp(\Delta A)$$
$$\bar{B} = (\Delta A)^{-1}(\exp(\Delta A) - I)\, \Delta B$$
This gives us the discrete recurrence:
$$x_k = \bar{A}x_{k-1} + \bar{B}u_k$$
$$y_k = Cx_k$$
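To make the two discretization rules concrete, here is a minimal NumPy sketch for a diagonal $A$. This is the setting used by S4D and Mamba, where matrix inverses and exponentials become element-wise operations. The helper names and toy values are illustrative assumptions. The sketch checks that for a small $\Delta$ the bilinear and ZOH rules produce nearly identical outputs.

```python
import numpy as np

def discretize_zoh(a, b, delta):
    """Zero-Order Hold for a diagonal A (a = diagonal of A, all entries nonzero).
    A_bar = exp(delta*A), B_bar = (delta*A)^-1 (exp(delta*A) - I) delta*B, element-wise."""
    a_bar = np.exp(delta * a)
    b_bar = (a_bar - 1.0) / a * b
    return a_bar, b_bar

def discretize_bilinear(a, b, delta):
    """Bilinear (Tustin) transform for a diagonal A."""
    denom = 1.0 - 0.5 * delta * a
    a_bar = (1.0 + 0.5 * delta * a) / denom
    b_bar = delta * b / denom
    return a_bar, b_bar

def run_recurrence(a_bar, b_bar, c, u):
    """x_k = A_bar x_{k-1} + B_bar u_k ; y_k = C x_k (scalar input, diagonal state)."""
    x = np.zeros_like(a_bar)
    ys = []
    for u_k in u:
        x = a_bar * x + b_bar * u_k
        ys.append(float(c @ x))
    return np.array(ys)

a = -np.arange(1, 5, dtype=float)   # stable diagonal A (negative entries), N = 4
b = np.ones(4)
c = np.ones(4) / 4
u = np.sin(np.linspace(0, 3, 50))   # toy input signal

y_zoh = run_recurrence(*discretize_zoh(a, b, 0.01), c, u)
y_bil = run_recurrence(*discretize_bilinear(a, b, 0.01), c, u)
print("max difference:", np.max(np.abs(y_zoh - y_bil)))  # tiny for a small delta
```

With a negative $A$, both rules map every eigenvalue into $(0, 1)$. The discrete recurrence therefore stays stable whatever step size the model picks.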
Why is this Powerful?
This recurrence $x_k = \bar{A}x_{k-1} + \bar{B}u_k$ has a key computational property: each step costs a constant amount of work per token. That cost is proportional to the state size but independent of the sequence length. To compute the current state, you only need the previous state and the current token, so the overall complexity drops to $O(L)$.

Furthermore, when $\bar{A}$, $\bar{B}$ and $C$ are the same for every token, the recurrence can be unrolled into a single convolution $y = \bar{K} * u$ with kernel $\bar{K} = (C\bar{B}, C\bar{A}\bar{B}, \dots, C\bar{A}^{L-1}\bar{B})$. During training, this convolution is computed in parallel on GPUs (for example with the FFT). During inference, the model switches to the fast recurrent form.
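The equivalence between the recurrent and convolutional views can be checked numerically. This sketch assumes a diagonal, time-invariant $\bar{A}$ and a single scalar input channel; the function names are hypothetical. The kernel $\bar{K}_j = C\bar{A}^j\bar{B}$ is applied with a zero-padded FFT, which is how S4-style models run the convolution in parallel.

```python
import numpy as np

def ssm_kernel(a_bar, b_bar, c, length):
    """Convolution kernel K_j = C A_bar^j B_bar for j = 0..length-1 (diagonal A_bar)."""
    powers = a_bar[None, :] ** np.arange(length)[:, None]  # (length, N): A_bar^j per state
    return powers @ (b_bar * c)                             # (length,)

def ssm_recurrent(a_bar, b_bar, c, u):
    """Inference mode: one O(N) state update per token."""
    x = np.zeros_like(a_bar)
    y = np.empty(len(u))
    for k, u_k in enumerate(u):
        x = a_bar * x + b_bar * u_k
        y[k] = c @ x
    return y

def ssm_convolutional(a_bar, b_bar, c, u):
    """Training mode: y = K * u as one causal convolution, computed with the FFT."""
    length = len(u)
    kernel = ssm_kernel(a_bar, b_bar, c, length)
    n = 2 * length  # zero-padding so the circular FFT convolution equals the linear one
    y = np.fft.irfft(np.fft.rfft(u, n) * np.fft.rfft(kernel, n), n)
    return y[:length]  # keep the causal part: y_k = sum_j K_j u_{k-j}

rng = np.random.default_rng(0)
a_bar = rng.uniform(0.5, 0.99, size=8)   # stable diagonal transition (|a_bar| < 1)
b_bar, c = rng.normal(size=8), rng.normal(size=8)
u = rng.normal(size=64)
print(np.allclose(ssm_recurrent(a_bar, b_bar, c, u), ssm_convolutional(a_bar, b_bar, c, u)))  # True
```

This equivalence only holds because the parameters are the same for every token. As soon as they depend on the input, the single kernel $\bar{K}$ no longer exists.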
The Selective Content Problem: The Arrival of Mamba
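If SSMs are so powerful, why didn't they replace Transformers earlier? The major problem with classical SSMs (like S4) is that they cannot select content: they cannot decide, based on what a token contains, whether to remember it or ignore it. The synthetic "selective copying" task tests exactly this ability.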
In a text, some words are filler ("the", "and", "a"), while others carry critical information ("API key", "password", a proper noun). The Transformers' attention mechanism excels at this: it can decide to send 100% of its attention to a specific token, even if it is located 10,000 tokens away.
Classical SSMs, on the other hand, process every token with the same parameters $\bar{A}$, $\bar{B}$, $C$: they are linear time-invariant (LTI) systems. The continuous matrix $A$ is typically initialized from the HiPPO (High-order Polynomial Projection Operators) framework. HiPPO is designed so that the state compresses the entire history of the input onto a basis of polynomials. This is excellent for summarizing, but poorly suited to ignoring noise and suddenly focusing on a specific detail.
Mamba's Innovation: Input-Dependent Parameters
This is where Albert Gu and Tri Dao (creator of FlashAttention) step in with their foundational paper Mamba: Linear-Time Sequence Modeling with Selective State Spaces (arxiv.org/abs/2312.00752). Their idea is to make the SSM parameters dynamic: they depend on the input $u_k$.
- $B_k, C_k = \text{Linear}_{BC}(u_k)$
- $\Delta_k = \text{Softplus}(\text{Linear}_{\Delta}(u_k))$ (The time step $\Delta$ also becomes dynamic!)
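If the model recognizes an important token, it can produce a large $\Delta_k$. With the ZOH rule $\bar{A} = \exp(\Delta_k A)$ and a negative $A$, $\bar{A}$ then approaches zero, so the state vector $x_k$ is largely reset and overwritten by the current token. If the token is noise, a small $\Delta_k$ keeps $\bar{A}$ close to the identity and $\bar{B}$ close to zero. The state then passes through almost unchanged and the token is effectively ignored.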
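A scalar toy example shows how $\Delta_k$ acts as a gate. It assumes a single state with $A = -1$, $B = 1$ and the ZOH rule $\bar{A} = \exp(\Delta A)$, $\bar{B} = (\bar{A} - 1)A^{-1}B$. The function name and values are illustrative and not taken from the Mamba code.

```python
import numpy as np

def selective_step(x_prev, u_k, delta, a=-1.0, b=1.0):
    """One ZOH-discretized SSM step for a single scalar state (illustrative values)."""
    a_bar = np.exp(delta * a)       # how much of the old state survives
    b_bar = (a_bar - 1.0) / a * b   # how much of the new token gets written
    return a_bar * x_prev + b_bar * u_k

memory, token = 1.0, 5.0  # hypothetical current state and incoming token value
for delta in (0.01, 0.5, 5.0):
    print(f"delta={delta:<5} -> new state {selective_step(memory, token, delta):.3f}")
```

With a small $\Delta$, the new state stays close to the old memory (about 1.04). With a large $\Delta$, it is almost entirely replaced by the current token (about 4.97).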
# PyTorch sketch illustrating Mamba's input-dependent parameters
import torch
import torch.nn as nn
class MambaDynamicParams(nn.Module):
    def __init__(self, d_model, d_state):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        # Projects each input token to B_k and C_k (d_state values each)
        self.proj_BC = nn.Linear(d_model, 2 * d_state)
        # Projects each input token to its own time step Delta_k (one per channel)
        self.proj_Delta = nn.Linear(d_model, d_model)
        self.softplus = nn.Softplus()
    def forward(self, u_k):
        # u_k: (Batch, Length, d_model)
        B_k, C_k = self.proj_BC(u_k).chunk(2, dim=-1)  # each (Batch, Length, d_state), input-dependent
        # Softplus keeps Delta_k strictly positive; it also depends on the input
        Delta_k = self.softplus(self.proj_Delta(u_k))  # (Batch, Length, d_model)
        return B_k, C_k, Delta_k
The Hardware Lock (Hardware-aware algorithm)
Making $B$, $C$, and $\Delta$ dynamic breaks the convolution property. We can no longer use the FFT (Fast Fourier Transform) to parallelize the computation on the GPU. We are forced to return to the sequential recursive loop, which is extremely slow on modern GPUs designed for massively parallel computation.
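This is Mamba's second key innovation. The authors designed a hardware-aware algorithm inspired by FlashAttention:
1. They never materialize the expanded state (of size batch × length × channels × state dimension) in the GPU's main high-bandwidth memory (HBM). Instead, they load the parameters into the GPU's small but very fast on-chip memory (SRAM), perform the discretization and the recurrence there, and write only the outputs back to HBM.
2. They compute the recurrence with a parallel scan. Because each step is an associative operation, the sequence can be processed in parallel blocks instead of one token at a time.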
This allows Mamba to maintain $O(L)$ time complexity while achieving training speeds comparable to, or even exceeding, those of Transformers.
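Why can a recurrence be parallelized at all? Each step $h_k = a_k h_{k-1} + b_k$ is an affine map (here $a_k$ plays the role of $\bar{A}_k$ and $b_k$ of $\bar{B}_k u_k$, element-wise for a diagonal $A$). Composing affine maps is associative, so a prefix scan applies. For readability, the sketch below uses the simple Hillis-Steele scan in NumPy, which needs about $\log_2 L$ vectorized rounds. It illustrates the principle only. It is not Mamba's fused CUDA kernel, which uses a work-efficient scan running in SRAM.

```python
import numpy as np

def combine(left, right):
    """Compose two affine maps h -> a*h + b: apply `left` first, then `right`."""
    a1, b1 = left
    a2, b2 = right
    return a1 * a2, a2 * b1 + b2

def sequential_scan(a, b):
    """Reference: h_k = a_k * h_{k-1} + b_k with h_{-1} = 0, one step at a time."""
    h, out = 0.0, np.empty_like(b)
    for k in range(len(b)):
        h = a[k] * h + b[k]
        out[k] = h
    return out

def parallel_scan(a, b):
    """Hillis-Steele inclusive scan: log2(L) rounds, each round fully vectorized."""
    a, b = a.copy(), b.copy()
    offset = 1
    while offset < len(b):
        # Every position i >= offset combines with position i - offset, all "in parallel".
        new_a, new_b = combine((a[:-offset], b[:-offset]), (a[offset:], b[offset:]))
        a[offset:], b[offset:] = new_a, new_b
        offset *= 2
    return b  # with h_{-1} = 0, the accumulated offset term is exactly h_k

rng = np.random.default_rng(1)
a = rng.uniform(0.1, 1.0, size=100)  # input-dependent decays (like A_bar_k)
b = rng.normal(size=100)             # input-dependent writes (like B_bar_k * u_k)
print(np.allclose(sequential_scan(a, b), parallel_scan(a, b)))  # True
```

Unlike the convolution trick, the scan does not require the parameters to be identical at every step. That is why the scan is what makes selective SSMs trainable in parallel.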
Detailed Architecture of Mamba
A complete Mamba block functions as an intelligent stacking of different layers designed to process the sequence hierarchically. Here is how a token passes through a Mamba block:
1. Dimension Expansion
The input sequence of dimension $d_{model}$ (for example 768) is first projected by a linear layer into two branches of dimension $E \times d_{model}$ (often $E=2$). The main branch goes through the SSM; the gate branch is used at the end of the block. The main branch first passes through a causal depthwise convolution with a small kernel size $K$ (generally $K=4$), followed by a SiLU activation. This convolution acts as a local pre-filter: each channel mixes the current token with the $K-1$ previous tokens before the recurrence runs.
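The causal depthwise convolution can be written in a few lines. This NumPy sketch (hypothetical helper name) left-pads the sequence with $K-1$ zeros, so position $t$ only mixes tokens $t-K+1, \dots, t$ of the same channel. The final check confirms that changing future tokens leaves earlier outputs untouched.

```python
import numpy as np

def causal_depthwise_conv1d(x, weights):
    """x: (L, D) sequence; weights: (K, D), one length-K filter per channel.
    y[t, d] = sum_{j=0}^{K-1} weights[j, d] * x[t - j, d], with x[t - j] = 0 when t - j < 0."""
    K = weights.shape[0]
    L = x.shape[0]
    padded = np.vstack([np.zeros((K - 1, x.shape[1])), x])  # left-pad only: no future leakage
    y = np.zeros_like(x)
    for j in range(K):
        # padded[K - 1 - j + t] == x[t - j]: lag-j contribution for every position t
        y += weights[j] * padded[K - 1 - j : K - 1 - j + L]
    return y

rng = np.random.default_rng(0)
x = rng.normal(size=(10, 3))   # 10 tokens, 3 channels
w = rng.normal(size=(4, 3))    # kernel size K = 4
x_future_changed = x.copy()
x_future_changed[6:] += 100.0  # perturb only tokens 6..9
y1 = causal_depthwise_conv1d(x, w)
y2 = causal_depthwise_conv1d(x_future_changed, w)
print(np.allclose(y1[:6], y2[:6]))  # True: outputs before t=6 ignore later inputs
```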
2. Dynamic Selection
This is where the linear projections we saw above come into play to generate $B_k$, $C_k$, and the time step $\Delta_k$. This is the heart of the "Selective State Space".
3. The SSM Scan
The hardware-aware scan algorithm applies the discrete recurrence using the dynamic parameters to update the hidden state $x_k$.
4. Output Projection and Residual Connection
The output of the scan is multiplied element-wise by the gate branch (after a SiLU activation), then projected from the expanded space back to the original $d_{model}$ space. In the full network, each block is wrapped in a residual connection (skip connection) combined with normalization, as in Transformers.
# Simplified Mamba block in PyTorch (educational; not the optimized implementation)
import torch
import torch.nn as nn
import torch.nn.functional as F
class MambaBlock(nn.Module):
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = self.expand * self.d_model
        # 1. Input projection into two branches of size d_inner: SSM branch and gate branch
        self.in_proj = nn.Linear(d_model, self.d_inner * 2)
        # Local causal depthwise convolution (one filter per channel)
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=True,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
        )
        # 2. SSM parameters (simplified for the demonstration)
        self.x_proj = nn.Linear(self.d_inner, d_state * 2)  # input-dependent B and C
        self.dt_proj = nn.Linear(self.d_inner, self.d_inner)  # input-dependent Delta
        # A is stored as log(-A) with the "S4D-Real" initialization A[:, n] = -(n + 1),
        # a diagonal, HiPPO-inspired choice
        A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))  # skip term D * u
        # 4. Output projection
        self.out_proj = nn.Linear(self.d_inner, d_model)
    def forward(self, x):
        """
        x : (B, L, D)
        """
        batch, seqlen, dim = x.shape
        # Projection into the SSM branch and the gate branch
        xz = self.in_proj(x)  # (B, L, 2 * d_inner)
        x, z = xz.chunk(2, dim=-1)
        # Causal convolution
        x = x.transpose(1, 2)  # (B, d_inner, L)
        x = self.conv1d(x)[:, :, :seqlen]  # drop the right padding: position t only sees t-K+1..t
        x = x.transpose(1, 2)  # (B, L, d_inner)
        # Activation
        x = F.silu(x)
        # Input-dependent dynamic parameters (x_proj computed once, then split)
        B, C = self.x_proj(x).split(self.d_state, dim=-1)  # each (B, L, d_state)
        Delta = F.softplus(self.dt_proj(x))  # (B, L, d_inner), strictly positive
        # 3. The real implementation uses the fused CUDA kernel `selective_scan_fn`
        # (from mamba_ssm) for hardware efficiency. Conceptually:
        #   A = -torch.exp(self.A_log)  # (d_inner, d_state)
        #   y = selective_scan(x, Delta, A, B, C, self.D)
        # Here we use a slow sequential reference loop instead.
        y = self.simplified_ssm_scan(x, Delta, B, C)
        # Gating and output projection
        y = y * F.silu(z)
        return self.out_proj(y)
    def simplified_ssm_scan(self, x, Delta, B, C):
        """Slow sequential simulation (for teaching purposes only)."""
        batch, seqlen, d_inner = x.shape
        A = -torch.exp(self.A_log)  # (d_inner, d_state), negative for stability
        h = torch.zeros(batch, d_inner, self.d_state, device=x.device, dtype=x.dtype)
        ys = []
        for i in range(seqlen):
            # Discretization: ZOH for A, simplified Delta * B for B
            dt = Delta[:, i, :].unsqueeze(-1)  # (B, d_inner, 1)
            dA = torch.exp(A.unsqueeze(0) * dt)  # (B, d_inner, d_state)
            dB = B[:, i, :].unsqueeze(1) * dt  # (B, d_inner, d_state)
            # State update: h_k = A_bar * h_{k-1} + B_bar * u_k
            h = dA * h + dB * x[:, i, :].unsqueeze(-1)
            # Output: y_k = C_k . h_k + D * u_k
            ys.append((h * C[:, i, :].unsqueeze(1)).sum(dim=-1) + self.D * x[:, i, :])
        return torch.stack(ys, dim=1)  # (B, L, d_inner)
Mamba-2: Structured State Space Duality
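If Mamba marked a milestone in 2023, Mamba-2, released in 2024, gave this architecture even stronger mathematical foundations. Tri Dao and Albert Gu presented it in their paper Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality. There they showed that a class of SSMs and a variant of masked (linear) attention are two ways of computing the same transformation, both expressed through structured (semiseparable) matrices.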
The shift to head-wise space
Mamba-2 restricts the structure of the state matrix $A$. Within each head, $A$ is a scalar multiple of the identity, so the whole state of a head decays by a single, input-dependent factor at each step. The channels are split into heads, much like the Multi-Head Attention of Transformers.
The recurrence for head $h$ becomes:
$$x_k^h = a_k^h\, x_{k-1}^h + \bar{B}_k^h u_k^h, \qquad a_k^h = \exp(\Delta_k^h A^h)$$
where $A^h$ is a negative scalar.
The SSD (Structured State Space Duality) algorithm
The major innovation of Mamba-2 is the SSD algorithm, which replaces Mamba-1's selective scan. Because the decay is a scalar per head, the whole sequence transformation can be written as a lower-triangular (semiseparable) matrix applied to the input, just like a causal attention matrix with a decay mask. SSD splits the sequence into chunks. Inside each chunk, it uses this quadratic, attention-like form, computed with matrix multiplications. Between chunks, it passes the compressed state along with a short recurrence.
Because most of the work becomes matrix multiplications that run on GPU tensor cores, the core SSD layer is reported to be 2 to 8 times faster than Mamba-1's selective scan. It also supports much larger state dimensions.
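The duality at the core of Mamba-2 can be checked on a toy example. Assume one head, a scalar decay $a_t$ per step and a scalar input per step (all names are hypothetical). Under those assumptions, the linear-time recurrence and the quadratic "masked attention" form $y = (L \circ CB^\top)\,u$, with $L_{ts} = \prod_{r=s+1}^{t} a_r$ for $s \le t$, produce the same outputs. Here $C$ plays the role of the queries and $B$ the keys. An SSD-style kernel mixes the two forms (quadratic inside chunks, recurrent between chunks); this sketch only shows the two extremes.

```python
import numpy as np

def scalar_decay_recurrent(a, B, C, u):
    """Linear form: h_t = a_t * h_{t-1} + B_t * u_t ; y_t = C_t . h_t (one scalar decay per step)."""
    h = np.zeros(B.shape[1])
    y = np.empty(len(u))
    for t in range(len(u)):
        h = a[t] * h + B[t] * u[t]
        y[t] = C[t] @ h
    return y

def scalar_decay_attention_form(a, B, C, u):
    """Quadratic form: y = (decay_mask * (C @ B.T)) @ u, like causal attention with a decay mask."""
    T = len(u)
    log_cum = np.cumsum(np.log(a))              # log_cum[t] = sum_{r<=t} log a_r
    diff = log_cum[:, None] - log_cum[None, :]  # diff[t, s] = sum_{s<r<=t} log a_r when s <= t
    causal = np.tril(np.ones((T, T), dtype=bool))
    # exp(min(diff, 0)) avoids overflow in the (discarded) upper triangle
    decay_mask = np.where(causal, np.exp(np.minimum(diff, 0.0)), 0.0)
    return (decay_mask * (C @ B.T)) @ u

rng = np.random.default_rng(0)
T, N = 32, 4                        # sequence length, state size (toy values)
a = rng.uniform(0.5, 1.0, size=T)   # input-dependent scalar decays a_t
B = rng.normal(size=(T, N))
C = rng.normal(size=(T, N))
u = rng.normal(size=T)
print(np.allclose(scalar_decay_recurrent(a, B, C, u), scalar_decay_attention_form(a, B, C, u)))  # True
```

The recurrent form costs $O(T)$ but runs step by step. The attention form costs $O(T^2)$ but consists only of matrix products. SSD picks the quadratic form on small chunks to get the best of both.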
Results and Benchmarks
The theory is compelling, but what about actual performance? This is where Mamba justifies all the hype.
Hardware Efficiency (Throughput)
This is the area where Mamba literally crushes Transformers. Thanks to its $O(L)$ complexity and its hardware-optimized algorithm, Mamba does not suffer from memory explosion when increasing the context window.
- In inference (token generation): the Mamba paper reports up to 5 times higher generation throughput than a Transformer of similar size. Mamba keeps a fixed-size state instead of a KV (Key-Value) cache that grows with every token. Each new token therefore costs the same, and much larger batches fit in memory.
- In training: Mamba achieves higher throughput (tokens per second per GPU) once sequences become long, roughly beyond 8,192 tokens. The exact crossover depends on the hardware and on the attention implementation.
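A back-of-the-envelope calculation shows why generation is cheaper. The sketch below compares the memory of a Transformer KV cache with the recurrent state of a Mamba model. Both configurations are hypothetical round numbers, not measurements of any released model, and real implementations may store the SSM state in higher precision (the `bytes_per_value` parameter covers this).

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, bytes_per_value=2):
    """Transformer: keys + values (factor 2) stored for every layer, head and past token."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_value

def mamba_state_bytes(n_layers, d_inner, d_state, d_conv, batch=1, bytes_per_value=2):
    """Mamba: per layer, an SSM state (d_inner x d_state) plus a small rolling buffer of
    recent inputs for the causal convolution. No dependence on the sequence length."""
    return n_layers * d_inner * (d_state + d_conv) * batch * bytes_per_value

# Hypothetical configurations, chosen only for illustration (fp16 = 2 bytes per value)
for seq_len in (4_096, 32_768, 131_072):
    kv = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=128, seq_len=seq_len)
    state = mamba_state_bytes(n_layers=64, d_inner=5120, d_state=16, d_conv=4)
    print(f"L={seq_len:>7}: KV cache {kv / 2**30:6.2f} GiB | Mamba state {state / 2**20:5.2f} MiB")
```

The KV cache grows linearly with `seq_len`, while the Mamba state is the same at every length. This is why a Mamba model can serve many long conversations at once.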
Performance on language modeling tasks
On classic NLP benchmarks (Pile, WikiText, LAMBADA), Mamba-type models (like the family of open models provided by the authors) directly rival Transformers of the same size.
- Small to medium-sized models (< 1B parameters): Mamba matches or surpasses Transformers of the same size (such as Pythia or a modern Llama-style recipe) in perplexity.
- Larger models: the Mamba paper reports that Mamba-3B outperforms Transformers of the same size and matches Transformers twice its size. Transformers (like Llama or Mistral) still tend to keep an advantage on tasks that require precisely recalling or reasoning over specific earlier parts of the context, such as in-context retrieval or multi-hop questions. Global attention remains stronger for that kind of jump across the text.
Triumph over "Selective Copying"
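Mamba was specifically designed to solve tasks that previous SSMs failed at. The synthetic selective copying task asks the model to copy a few tokens scattered at random positions among noise tokens. On it, Mamba's selective layer reaches 99.8% accuracy in the paper's experiments, whereas an S4 layer without selection stays around 18%. On the related induction-heads task, Mamba keeps solving the problem on sequences far longer than those seen during training, where the attention baselines break down.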
Hybrids: the best of both worlds?
The current trend in the industry (for example Microsoft's Samba or AI21 Labs' Jamba) is not to kill Transformers, but to hybridize them.
The Jamba architecture (from AI21 Labs) interleaves Mamba layers with Attention layers, using one Attention layer for every seven Mamba layers. The Mamba layers process the long context cheaply, and the occasional Attention layers provide precise, content-based lookups across the whole sequence. Only the Attention layers need a KV cache, so this type of hybrid offers an excellent quality/speed trade-off. Jamba supports a context window of 256,000 tokens with a much smaller memory footprint than a comparable pure Transformer.
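The layer schedule of such a hybrid is easy to express. This hypothetical helper builds a pattern with one attention layer per block of eight layers (a 1:7 ratio). The `attn_period`/`attn_offset` names and the position of the attention layer inside each block are illustrative assumptions, not Jamba's actual configuration.

```python
def hybrid_layer_pattern(n_layers, attn_period=8, attn_offset=4):
    """Layer schedule with one attention layer per block of `attn_period` layers
    (ratio 1 attention : attn_period - 1 Mamba). Values are illustrative assumptions."""
    if not 0 <= attn_offset < attn_period:
        raise ValueError("attn_offset must lie in [0, attn_period)")
    return ["attention" if i % attn_period == attn_offset else "mamba" for i in range(n_layers)]

pattern = hybrid_layer_pattern(16)
print(pattern.count("attention"), "attention layers,", pattern.count("mamba"), "Mamba layers")
# With 16 layers and the defaults: 2 attention layers, 14 Mamba layers (ratio 1:7)
```

In a real model, each entry would be mapped to a module (for example a `Mamba` layer or a self-attention layer) inside an `nn.ModuleList`.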
How to use Mamba in practice
The ecosystem around Mamba is growing rapidly. The easiest way to experiment with it is to use the official mamba-ssm library, developed largely by Tri Dao.
Installation
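Installation requires an NVIDIA GPU with CUDA, because Mamba's core is implemented as custom CUDA/C++ and Triton kernels. If no prebuilt wheel matches your PyTorch and CUDA versions, pip compiles these kernels locally, which can take a while.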
# It is recommended to use a virtual environment
conda create -n mamba_env python=3.10 -y
conda activate mamba_env
# Install PyTorch with CUDA 12.1
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# Install causal-conv1d (mandatory dependency for the convolution operation)
pip install causal-conv1d
# Install mamba-ssm
pip install mamba-ssm
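After installing, a quick sanity check can save time. This hypothetical helper only checks two things: that the modules are importable (the import names `causal_conv1d` and `mamba_ssm` correspond to the pip packages above), and whether PyTorch sees a CUDA GPU. It does not exercise the kernels themselves.

```python
import importlib.util

def check_mamba_environment():
    """Report which of the packages installed above are importable, plus CUDA availability."""
    report = {name: importlib.util.find_spec(name) is not None
              for name in ("torch", "causal_conv1d", "mamba_ssm")}
    report["cuda"] = False
    if report["torch"]:
        import torch
        report["cuda"] = bool(torch.cuda.is_available())
    return report

if __name__ == "__main__":
    for name, ok in check_mamba_environment().items():
        print(f"{name:<14} {'OK' if ok else 'missing'}")
```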
Basic usage with Hugging Face
Many pre-trained Mamba models are now available on the Hugging Face Hub. Recent versions of the transformers library support Mamba natively, so you can use them like a classic Transformer model. Use the checkpoints converted for transformers, whose names end in `-hf`.
from transformers import AutoModelForCausalLM, AutoTokenizer
def generate_with_mamba(prompt):
    # Load a pre-trained Mamba model; the "-hf" checkpoints are the ones converted for transformers
    model_id = "state-spaces/mamba-2.8b-hf"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    # Tokenization
    inputs = tokenizer(prompt, return_tensors="pt")
    # Generation
    outputs = model.generate(
        inputs["input_ids"],
        max_new_tokens=200,
        temperature=0.7,
        do_sample=True,
    )
    # Decoding
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response
# Example of use
texte = "Artificial intelligence is revolutionizing"
print(generate_with_mamba(texte))
Implementing a custom Mamba layer
If you are building your own architecture from scratch (for example, a classification model for time series or genomic sequences, two domains with very long inputs), you can use the `Mamba` layer from mamba-ssm directly.
import torch
import torch.nn as nn
from mamba_ssm import Mamba
class MonClassifieurSequence(nn.Module):
    def __init__(self, d_model=256, n_layers=4, n_classes=5):
        super().__init__()
        self.layers = nn.ModuleList([
            Mamba(
                d_model=d_model,  # Model dimension
                d_state=16,  # SSM state dimension (N in the equations)
                d_conv=4,  # Local convolution kernel size
                expand=2,  # Expansion factor of Mamba's inner projection
            ) for _ in range(n_layers)
        ])
        # One pre-normalization per block (nn.RMSNorm requires PyTorch >= 2.4)
        self.norms = nn.ModuleList([nn.RMSNorm(d_model) for _ in range(n_layers)])
        self.norm_f = nn.RMSNorm(d_model)
        self.classifier = nn.Linear(d_model, n_classes)
    def forward(self, x):
        """
        x : Tensor of shape (Batch, Length, d_model)
        """
        # Pass through the Mamba blocks (pre-norm residual connections)
        for norm, layer in zip(self.norms, self.layers):
            x = layer(norm(x)) + x
        x = self.norm_f(x)
        # Pooling strategy: take the last position. In a causal model, it is the only
        # position whose hidden state has seen the whole sequence.
        last_token = x[:, -1, :]
        logits = self.classifier(last_token)
        return logits
# Model initialization (the mamba_ssm kernels require a CUDA GPU)
device = "cuda"
modele = MonClassifieurSequence(d_model=512, n_layers=6).to(device)
# Simulation of a batch of sequences
batch_size = 8
seq_len = 4096  # compute and activation memory grow linearly with the sequence length
dummy_input = torch.randn(batch_size, seq_len, 512, device=device)
# Forward pass
sorties = modele(dummy_input)
print(f"Output shape: {sorties.shape}")  # Expected: torch.Size([8, 5])
Summary
- Algorithmic complexity: SSMs like Mamba run in $O(L)$ (linear) time, compared to the $O(L^2)$ (quadratic) attention of Transformers. Their inference state does not grow with the context, so very long sequences can be processed without saturating VRAM.
- Selective mechanism: Mamba solves the historical flaw of SSMs by making its parameters ($B$, $C$, $\Delta$) dynamic and dependent on the input token, enabling precise filtering of information.
- Hardware optimization: Thanks to algorithms inspired by FlashAttention (a hardware-aware parallel scan in Mamba-1, the SSD algorithm in Mamba-2), the mathematical theory translates into real speedups on GPUs.
- Inference efficiency: Mamba eliminates the need for an expensive memory KV (Key-Value) cache, making text generation much faster and more economical.
- Performance: On language modeling and selective copying benchmarks, Mamba rivals or surpasses Transformers of equal size, although Transformers keep an advantage on tasks that require precise recall of earlier context.
- Hybrid trend: The likely future is not a total replacement, but hybridization (e.g., Jamba), combining Mamba layers for long context with occasional Attention layers for precise retrieval and reasoning.
Conclusion
Mamba and State Space Model architectures represent the first truly paradigmatic alternative to Transformers since 2017. By tackling the fundamental problem of quadratic complexity while solving the selective content problem, Mamba redefines the limits of sequence processing. Very large Transformer LLMs still dominate general reasoning benchmarks. Even so, the raw efficiency, the lower inference costs, and the ability to ingest gigantic texts make SSMs a go-to tool for the coming years, particularly for multimodal architectures (where video and audio require colossal context windows).
Ready to take action and integrate these new architectures into your workflows? Join the AI-Master.dev community to access our hands-on workshops, download optimized notebooks, and master tomorrow's models.