Mamba2-JAX: Pure JAX Implementation of Mamba2
Introduction
This is an experimental JAX/Flax implementation of Mamba2 [1] inspired by vasqu’s exquisite PyTorch version [2]. The implementation provides a pure JAX alternative for researchers and practitioners who prefer the JAX ecosystem for its functional programming paradigm, automatic differentiation, and seamless integration with TPU hardware.
Current Status: Alpha (Stable) Release
This alpha version focuses on numerical correctness and stability. The implementation has been tested against the PyTorch version and shows equivalent numerical behavior see Numerical Validation below.
NOTE: This is an early-stage implementation that currently suppo…
Mamba2-JAX: Pure JAX Implementation of Mamba2
Introduction
This is an experimental JAX/Flax implementation of Mamba2 [1] inspired by vasqu’s exquisite PyTorch version [2]. The implementation provides a pure JAX alternative for researchers and practitioners who prefer the JAX ecosystem for its functional programming paradigm, automatic differentiation, and seamless integration with TPU hardware.
Current Status: Alpha (Stable) Release
This alpha version focuses on numerical correctness and stability. The implementation has been tested against the PyTorch version and shows equivalent numerical behavior see Numerical Validation below.
NOTE: This is an early-stage implementation that currently supports:
- Pure JAX/Flax implementation (no Triton kernels)
- Causal language modeling with
Mamba2ForCausalLM - Time series forecasting with
Mamba2Forecaster - Full forward and backward passes with gradient computation
- Small to medium-scale experimentation
Why JAX?
While vasqu’s excellent PyTorch implementation provides multiple optimization paths including Triton kernels, this JAX version offers several unique advantages:
- Functional Programming: JAX’s functional approach makes it easier to reason about model behavior and transformations
- Hardware Flexibility: Seamless support for TPUs alongside GPUs through XLA compilation
- Research-Friendly: JAX’s transformation system (jit, grad, vmap, pmap) enables elegant experimentation
- Ecosystem Integration: Natural fit for projects already using JAX (Flax, Optax, Haiku)
- Educational Value: Cleaner implementation for understanding Mamba2 internals without CUDA complexity
This implementation prioritizes clarity and correctness over raw performance, making it ideal for:
- Understanding Mamba2 architecture
- Rapid prototyping of variants
- Integration into JAX-based research codebases
- TPU-based training workflows
Installation
Stable Version
Automatically download from PyPI using pip
pip install mamba2-jax
Development Version
Clone the repository and install as a package:
git clone https://github.com/yourusername/mamba2-jax.git
cd mamba2-jax
pip install -e .
Requirements
For a simple CPU-only setup:
pip install jax jaxlib flax optax einops
For GPU (CUDA) or TPU support, install the appropriate JAX wheels for your hardware as described in the official JAX installation guide. Once JAX sees your device, mamba2-jax will automatically run there.
GPU (CUDA) & TPU support
Mamba2-JAX is a pure JAX/Flax library and runs on any backend supported by your JAX installation:
- CPU – default if you install standard
jax/jaxlib. - CUDA GPUs – supported today via JAX’s CUDA PJRT backend. This implementation has been smoke-tested on NVIDIA RTX 3500 Ada (laptop) and NVIDIA T4-class GPUs.
- TPUs – supported via JAX’s TPU backend. The library has been smoke-tested on Google Cloud TPU v5e-1.
For GPU and TPU setup, please follow the official JAX installation guide for device-specific wheels and instructions.
NOTE: GPU/TPU usage is still considered experimental in this alpha as it not has been extensively tested. The focus so far has been numerical correctness rather than deep performance tuning.
Usage
Basic Language Modeling Example
This complete example shows how to create a Mamba2 language model, initialize it, and run a forward pass. You can copy and paste this entire block to get started:
import jax
import jax.numpy as jnp
from mamba2_jax import Mamba2Config, Mamba2ForCausalLM
# Create a small configuration for testing
# You can scale these up for real applications
config = Mamba2Config(
vocab_size=1024, # Small vocabulary for demo
hidden_size=256, # Hidden dimension
state_size=64, # SSM state size
head_dim=32, # Dimension per head
num_hidden_layers=4, # Number of Mamba2 blocks
chunk_size=64, # Chunk size for SSD computation
)
# Initialize the model
model = Mamba2ForCausalLM(config)
# Create some random input tokens
key = jax.random.PRNGKey(42)
batch_size, seq_len = 2, 64
input_ids = jax.random.randint(
key,
(batch_size, seq_len),
minval=0,
maxval=config.vocab_size
)
# Initialize model parameters with the input shape
print("Initializing model parameters...")
variables = model.init(key, input_ids=input_ids)
params = variables["params"]
# Run forward pass with loss computation
print("Running forward pass...")
outputs = model.apply(
{"params": params},
input_ids=input_ids,
labels=input_ids, # Using same tokens as labels for demo
)
# Check outputs
print(f"Logits shape: {outputs['logits'].shape}") # Should be (2, 64, 1024)
print(f"Loss: {float(outputs['loss']):.4f}")
print("Forward pass completed successfully!")
Time Series Forecasting Example
This example shows how to use Mamba2 for time series prediction. The model takes a historical sequence and predicts future values:
import jax
import jax.numpy as jnp
import optax
from mamba2_jax import Mamba2Forecaster
# Suppose we have univariate timeseries windows of length L
batch_size = 8
input_length = 32
forecast_horizon = 12
input_dim = 1
output_dim = 1
model = Mamba2Forecaster(
input_dim=input_dim,
d_model=256,
n_layers=4,
output_dim=output_dim,
forecast_horizon=forecast_horizon,
)
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (batch_size, input_length, input_dim))
variables = model.init(key, x)
params = variables["params"]
y_pred = model.apply({"params": params}, x) # (B, H, D_out)
print("Timeseries output shape:", y_pred.shape)
Advanced Features
The core Mamba2Model exposes the same SSM hooks as the PyTorch implementation:
- Stateful / streaming inference via
initial_statesandoutput_last_ssm_states. - Layer-wise analysis via
output_hidden_states=True.
See the runnable scripts in the examples/ directory:
04_streaming_states_demo.py– carry SSM state across chunks for streaming generation or very long sequences.05_inspect_hidden_states.py– retrieve per-layer hidden states for analysis or auxiliary losses.
Examples
For more end-to-end, runnable scripts (tiny training loops, sine-wave forecasting, streaming state demos, etc.), see the examples/ directory in this repository.
Numerical Validation with PyTorch
The implementation has been validated against the reference PyTorch version [2] to ensure numerical correctness on CPU. Further tests will investigate GPU (CUDA) and TPU performance once enabled post alpha release.
Methodology
-
A small Mamba2 model is instantiated in both PyTorch and JAX with identical hyperparameters (hidden size, state size, number of layers, sequence length, etc.).
-
A simple synthetic MSE regression task is constructed and shared between the two frameworks using the same random seed.
-
On CPU only, both models are trained side-by-side on this task for a short number of optimisation steps.
-
At each training step we record:
-
the PyTorch loss,
-
the JAX loss,
-
the absolute difference
|L_torch - L_jax|, -
and the per-step wall-clock time for each framework.
-
All experiments are run in
float32with no mixed precision or framework-specific numerical tricks, to keep the comparison as fair as possible.
The whole procedure is implemented in the standalone script test-parity.py, which can be run on any CPU-only machine.
The key experiment configuration is summarised below:
| Category | Parameter | PyTorch | JAX |
|---|---|---|---|
| Backend | Framework | PyTorch | JAX + Optax + Flax |
| Device | CPU | CPU | |
| Dtype | float32 | float32 | |
| Data / task | Task | Synthetic time-series forecasting (MSE regression) | Same dataset via shared NumPy arrays |
| Batch size | 2 | 2 | |
| Context length | 32 time steps | 32 time steps | |
| Input dimension | 10 | 10 | |
| Forecast horizon | 16 time steps | 16 time steps | |
| Model | Model wrapper | Mamba2Model + linear head | Mamba2Forecaster |
d_model | 768 | 768 | |
n_layers | 1 | 1 | |
d_state | 128 (Mamba2Config default) | 128 | |
headdim | 64 (Mamba2Config default) | 64 | |
expand | 2 | 2 | |
d_conv | 4 | 4 | |
| Training | Optimiser | Manual SGD (param -= lr * grad) | optax.sgd(lr) |
| Loss function | Mean squared error | Mean squared error | |
| Learning rate | 0.001 | 0.001 | |
| Training steps | 16 | 16 | |
| Random seed | Shared seed for data + initialisation where possible | Shared seed for data + initialisation where possible |
Test Results
See the Mamba2 PyTorch vs JAX parity test appendix for full details.
The figure below summarises the comparison:
- Training loss (left panel). The PyTorch and JAX MSE losses follow almost identical learning curves. Both decrease smoothly over time, and by the final steps the two curves are visually indistinguishable.
- Loss difference (middle panel). The absolute difference
|L_torch - L_jax|starts around ~2×10⁻¹ and decays monotonically during training, reaching the ~10⁻² range after ~15–16 steps. This level of discrepancy is well within normal numerical noise between different backends and confirms that the JAX implementation closely tracks the PyTorch reference. - CPU wall-clock time (right panel). On this micro-benchmark the JAX implementation is roughly 2× faster per step on CPU, with typical step times around ~0.08–0.09 s versus ~0.18–0.20 s for PyTorch.
Overall, these experiments indicate that the JAX implementation is numerically consistent with the PyTorch model while offering competitive (and often better) CPU performance for this class of workloads.
Summary metrics
| Category | Metric | PyTorch | JAX | Notes |
|---|---|---|---|---|
| Training loss | Initial MSE (step 0) | 1.5894 | 2.1382 | Different random inits, both converge |
| Training loss | Final MSE (step 15) | 0.0249 | 0.0371 | Final diff ≈ 0.0121 |
| Training loss | Mean abs. diff (all steps) | – | – | 0.1606 (mean), 0.5487 (max) |
| Training loss | Mean rel. diff (all steps) | – | – | ≈ 47 % mean, ≈ 51 % max |
| Prediction parity | Pearson correlation | – | – | 0.992 between PyTorch and JAX predictions |
| Prediction parity | MAE / std(torch) | – | – | ≈ 0.10 (~10 %) |
| Prediction parity | RMSE / std(torch) | – | – | ≈ 0.13 (~13 %) |
| Timing (CPU) | Mean step time | 0.1935 s | 0.0879 s | JAX is ≈ 2.2× faster per step on CPU |
| Timing (CPU) | JIT compile (train_step) | – | 0.97 s | One-off JIT cost before steady-state steps |
In short, both implementations learn very similar functions: the loss curves track each other closely, the final losses differ by only ≈ 0.012, and the final predictions have a Pearson correlation of ~0.99 with discrepancies on the order of 10–13 % of the PyTorch signal scale. On CPU, the JAX version achieves roughly a 2.2× lower per-step wall-clock time once the one-off JIT compilation cost is paid, while remaining numerically consistent with the PyTorch reference.
Project Structure
Roadmap
Beta Release (Coming Soon)
- GPU Optimisation: Profile and optimize performance on modern GPUs
- Expanded Test Suite: Comprehensive unit tests and integration tests
- Model Conversion Scripts: Tools to convert pretrained PyTorch weights to JAX
- Benchmarking Suite: Systematic performance comparison across hardware
- Documentation: Detailed API documentation and architecture guide
Future Releases
- Triton Kernel Support: Custom kernels for improved performance
- Pretrained Models: Host converted models on Hugging Face Hub
- Mixed Precision Training: BF16/FP16 support with proper loss scaling
- Model Parallelism: Support for large-scale training with pmap/pjit
- Advanced Caching: Efficient KV-like caching for generation
- Hybrid Variants: Attention and MLP hybrid architectures
Known Limitations
This alpha release has several known limitations:
- No Triton Kernels: Uses naive SSD implementation, slower than optimized PyTorch version
- No Pretrained Weights: No conversion scripts yet (coming in beta)
- Limited Generation Support: Basic generation only, no advanced sampling methods
- No Hybrid Architectures: Pure Mamba2 blocks only (no attention/MLP variants)
- Memory Usage: Not optimized for very long sequences (>4096 tokens)
We’re actively working on addressing these limitations in upcoming releases.
FAQ (JAX / CUDA / TPU)
Does mamba2-jax run on GPU and TPU?
Yes. As a pure JAX/Flax implementation, mamba2-jax runs on any backend that your JAX installation supports. If you install a CUDA-enabled JAX build, it will use your NVIDIA GPU; if you install the TPU wheels and run on Cloud TPU (e.g. v5e-1), it will run there too.
Why do I see messages like GPU interconnect information not available: NVML doesn't support extracting fabric info or NVLink is not used by the device.?
These lines are printed by JAX/XLA’s CUDA runtime during startup, not by mamba2-jax. They usually mean “your GPU does not expose NVLink / fabric topology to NVML”, and they are safe to ignore for normal training and inference.
What about Delay kernel timed out: measured time has sub-optimal accuracy... from cuda_timer.cc?
This message also comes from the XLA CUDA backend. It indicates that an internal timing kernel used for profiling and autotuning was not accurate enough and XLA fell back to a different timing path. It does not indicate that your model or gradients are wrong; it only affects how XLA measures performance internally.
How do I force CPU vs GPU?
JAX picks a backend automatically, but you can override it via the JAX_PLATFORMS environment variable before importing JAX:
# CPU only
export JAX_PLATFORMS=cpu
# Prefer CUDA, fall back to CPU if something is wrong
export JAX_PLATFORMS=cuda,cpu
How can I reduce JAX/XLA log noise?
JAX/XLA uses the TF_CPP_MIN_LOG_LEVEL environment variable to control C++ backend logging:
# Show INFO, WARNING, ERROR (default)
export TF_CPP_MIN_LOG_LEVEL=0
# Hide INFO + WARNING, keep ERROR (recommended if logs feel noisy)
export TF_CPP_MIN_LOG_LEVEL=2
You can set these in your shell or at the very top of your own scripts before importing JAX. The core mamba2-jax library does not change global logging settings for you.
Contributing
Contributions are welcome! Areas where help would be particularly valuable:
- Performance optimization and profiling
- Test coverage expansion
- Documentation improvements
- Bug reports and feature requests
Please open an issue or submit a pull request on GitHub.
Acknowledgments
This implementation builds upon the excellent work of many researchers and engineers:
Original Mamba2 Authors [1] :
- Tri Dao and Albert Gu for the Mamba2 architecture and original implementation
- The entire State Spaces team for advancing SSM research
PyTorch Implementation [2] :
- vasqu for the clean PyTorch implementation that served as a reference
- The implementation structure and many design decisions were inspired by mamba2-torch
- The JAX, Flax, and Optax teams at Google for the excellent frameworks
- The broader JAX community for tools and support
NOTE: I (Cosmo Santoni) am not affiliated with the original authors of Mamba2 paper nor PyTorch, HuggingFace, Google, JAX or Flax teams. I am an independent researcher at Imperial College London.s
References
[1] Mamba2
@inproceedings{mamba2,
title={Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality},
author={Dao, Tri and Gu, Albert},
booktitle={International Conference on Machine Learning (ICML)},
year={2024}
}
[2] mamba2-torch (PyTorch Implementation)
@software{vasqu2024mamba2torch,
author = {vasqu},
title = {mamba2-torch: HuggingFace Compatible Mamba2},
year = {2024},
url = {https://github.com/vasqu/mamba2-torch}
}
[3] JAX
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/google/jax},
version = {0.3.13},
year = {2018},
}
[4] Flax
@software{flax2020github,
author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
title = {{F}lax: A neural network library and ecosystem for {JAX}},
url = {http://github.com/google/flax},
version = {0.7.0},
year = {2020},
}
License
MIT
Citation
If you use this implementation in your research, please cite both the original Mamba2 paper and acknowledge this JAX implementation:
@software{mamba2jax2024,
author = {[Cosmo Santoni]},
title = {mamba2-jax: Pure JAX Implementation of Mamba2},
year = {2024},
url = {https://github.com/CosmoNaught/mamba2-jax}
}
Questions or Issues? Please open an issue on GitHub or reach out through discussions.
Want to Contribute? PRs are welcome! See the Contributing section above.