31 changes: 31 additions & 0 deletions README.md
@@ -11,6 +11,9 @@
> Tri Dao*, Albert Gu*\
> Paper: https://arxiv.org/abs/2405.21060

> **Mamba-3: Improved Sequence Modeling with Structured State Spaces**\
> Paper: https://openreview.net/pdf?id=HwCvaJOiCj

## About

Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers.
@@ -95,6 +98,28 @@ assert y.shape == x.shape
A minimal version of the inner SSD module (Listing 1 from the Mamba-2 paper) with conversion between "discrete" and "continuous" SSM versions
is at [modules/ssd_minimal.py](mamba_ssm/modules/ssd_minimal.py).

### Mamba-3

The Mamba-3 block is implemented at [modules/mamba3.py](mamba_ssm/modules/mamba3.py).

A simpler version is at [modules/mamba3_simple.py](mamba_ssm/modules/mamba3_simple.py).

Usage:
``` python
from mamba_ssm import Mamba3
model = Mamba3(
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
# `x` is an input tensor of shape (batch, length, dim) on the GPU
y = model(x)
assert y.shape == x.shape
```

Mamba-3 adds RoPE, BCNorm, and MIMO (multi-input multi-output) support on top of the SSD framework.
To use Mamba-3 in a full language model, set `"layer": "Mamba3"` in `ssm_cfg`.
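
As a rough illustration (not taken from the repository), the sketch below builds such a configuration as a plain dictionary and round-trips it through JSON. The field names `d_model`, `n_layer`, and `vocab_size` and all numeric values are assumptions chosen for the example; only the `"layer": "Mamba3"` entry in `ssm_cfg` and the `d_state` argument come from the text above.

```python
import json

# Illustrative values only; the field names outside `ssm_cfg` are assumptions.
config = {
    "d_model": 768,
    "n_layer": 24,
    "vocab_size": 50277,
    "ssm_cfg": {
        "layer": "Mamba3",  # selects the Mamba-3 block
        "d_state": 64,      # other keys are passed on to the block constructor
    },
}

# Serialize and restore the config, as one might when saving it next to a checkpoint.
config_json = json.dumps(config, indent=2)
restored = json.loads(config_json)
print(restored["ssm_cfg"]["layer"])
```

If `"layer"` is omitted from `ssm_cfg`, block creation falls back to `"Mamba1"`.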

### Mamba Language Model

Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head.
@@ -240,4 +265,10 @@ If you use this codebase, or otherwise find our work valuable, please cite Mamba
year={2024}
}

@inproceedings{mamba3,
  title={Mamba-3: Improved Sequence Modeling with Structured State Spaces},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2026}
}

```
1 change: 1 addition & 0 deletions mamba_ssm/__init__.py
@@ -3,4 +3,5 @@
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.modules.mamba2 import Mamba2
from mamba_ssm.modules.mamba3 import Mamba3
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
8 changes: 5 additions & 3 deletions mamba_ssm/models/mixer_seq_simple.py
@@ -14,6 +14,7 @@
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.modules.mamba2 import Mamba2
from mamba_ssm.modules.mamba3 import Mamba3
from mamba_ssm.modules.mha import MHA
from mamba_ssm.modules.mlp import GatedMLP
from mamba_ssm.modules.block import Block
@@ -51,10 +52,11 @@ def create_block(
         # Create a copy of the config to modify
         ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
         ssm_layer = ssm_cfg.pop("layer", "Mamba1")
-        if ssm_layer not in ["Mamba1", "Mamba2"]:
-            raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2")
+        if ssm_layer not in ["Mamba1", "Mamba2", "Mamba3"]:
+            raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1, Mamba2, and Mamba3")
+        layer_cls = {"Mamba1": Mamba, "Mamba2": Mamba2, "Mamba3": Mamba3}[ssm_layer]
         mixer_cls = partial(
-            Mamba2 if ssm_layer == "Mamba2" else Mamba,
+            layer_cls,
             layer_idx=layer_idx,
             **ssm_cfg,
             **factory_kwargs
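
The selection logic in this hunk can be tried in isolation. The sketch below is a simplified stand-in, not the repository code: `FakeMamba`, `FakeMamba2`, `FakeMamba3`, and `select_mixer` are hypothetical names, and the classes only record their constructor arguments. It shows the dictionary lookup replacing the old two-way conditional, and why the config is copied before `"layer"` is popped.

```python
import copy
from functools import partial


class _FakeLayer:
    """Hypothetical stand-in that only records its keyword arguments."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


class FakeMamba(_FakeLayer):
    pass


class FakeMamba2(_FakeLayer):
    pass


class FakeMamba3(_FakeLayer):
    pass


# Name-to-class table; adding a new layer type needs one new entry here.
LAYERS = {"Mamba1": FakeMamba, "Mamba2": FakeMamba2, "Mamba3": FakeMamba3}


def select_mixer(ssm_cfg=None, layer_idx=None):
    # Copy first so that pop() does not mutate the caller's config.
    ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
    ssm_layer = ssm_cfg.pop("layer", "Mamba1")
    if ssm_layer not in LAYERS:
        raise ValueError(f"Invalid ssm_layer: {ssm_layer}")
    # The remaining keys become constructor keyword arguments.
    return partial(LAYERS[ssm_layer], layer_idx=layer_idx, **ssm_cfg)


cfg = {"layer": "Mamba3", "d_state": 64}
mixer = select_mixer(cfg, layer_idx=0)()
print(type(mixer).__name__, mixer.kwargs)
```

Because the config is deep-copied, the same `cfg` dictionary can be reused for every layer index without losing its `"layer"` key.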