feat: Add Mamba-3 implementation (ICLR 2026) #850
Open
payk24 wants to merge 3 commits into
Mamba-3 module with chunked parallel SSD, RoPE, BCNorm, SISO/MIMO support, seq_idx for packed multi-document training, Triton decode kernel, and gradient checkpointing. Includes full and simplified module variants, Triton ops, and comprehensive CPU/GPU test suites (74 tests).
Triton-accelerated forward and backward for Mamba-3 chunked SSD:

New Triton kernels (5 new files):
- mamba3_chunk_scan.py: Forward scan kernel with trapezoidal support
- mamba3_chunk_scan_bwd.py: Backward kernels for dx, dCB, ddA
- mamba3_chunk_state.py: Forward chunk state with gamma/beta scaling
- mamba3_chunk_state_bwd.py: Backward kernels for dB and ddA (stable)
- mamba3_combined.py: Autograd Function orchestrating Triton pipeline

Supporting kernels:
- mamba3_rope.py: RoPE application for B/C tensors
- mamba3_shift.py: Shifted B/x computation for trapezoidal lookback

Key design decisions:
- Reuses 6 Mamba-2 backward kernels (dz, dstates, state_passing, dC, cumsum, bmm); only Mamba-3-specific ops get new kernels
- SISO Triton path with PyTorch fallback for MIMO
- initial_prev_Bx backward: explicit gradient through state + output corrections (additive, so existing kernels handle other params)
- RoPE backward via PyTorch autograd recompute (general, correct)

Bug fixes:
- ssd_state_passing.py: int32/int64 loop variable mismatch with seq_idx
- mamba3_ssd.py: seq_idx propagation in chunked path
- mamba3.py/mamba3_simple.py: D parameter bf16 dtype handling

Tests: 64 Triton backward + 2 Triton forward tests (all passing)
Summary
Implements the Mamba-3 architecture as described in the ICLR 2026 paper:
https://openreview.net/pdf?id=HwCvaJOiCj
Modules
- mamba_ssm/modules/mamba3.py: Full Mamba-3 module with tensor parallelism, inference cache, SISO/MIMO support
- mamba_ssm/modules/mamba3_simple.py: Simplified variant (no TP/inference params)
- mamba_ssm/ops/triton/mamba3_ssd.py: Chunked parallel SSD kernels + Triton decode kernel (SISO & MIMO)
- mamba_ssm/__init__.py / mamba_ssm/models/mixer_seq_simple.py: Integration ("layer": "Mamba3" in ssm_cfg)
Triton training kernels (forward + backward)
- mamba_ssm/ops/triton/mamba3_combined.py: Autograd Function orchestrating Triton forward/backward pipeline
- mamba_ssm/ops/triton/mamba3_chunk_scan.py: Forward scan kernel with trapezoidal discretization
- mamba_ssm/ops/triton/mamba3_chunk_scan_bwd.py: Backward kernels: dx, dCB, ddA
- mamba_ssm/ops/triton/mamba3_chunk_state.py: Forward chunk state with gamma/beta scaling
- mamba_ssm/ops/triton/mamba3_chunk_state_bwd.py: Backward kernels: dB, ddA (numerically stable)
- mamba_ssm/ops/triton/mamba3_rope.py: RoPE application for B/C tensors
- mamba_ssm/ops/triton/mamba3_shift.py: Shifted B/x for trapezoidal lookback term
Architecture: reuses 6 Mamba-2 backward kernels (dz, dstates, state_passing, dC, cumsum, bmm); only Mamba-3-specific ops get new Triton kernels.
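To make the trapezoidal pieces listed above more concrete, here is a minimal pure-Python sketch of how a shifted B·x term can turn a two-term update into an ordinary linear scan. It assumes a recurrence of the form h_t = a_t·h_{t-1} + beta_t·B_{t-1}·x_{t-1} + gamma_t·B_t·x_t with a scalar state. That form, the function names, and passing a, beta and gamma as plain inputs are illustrative assumptions, not the PR's kernels or the paper's exact parameterization.

```python
def shift_bx(bx, initial_prev_bx=0.0):
    """Lookback sequence: element t holds bx[t-1]; element 0 holds the carry-in."""
    return [initial_prev_bx] + list(bx[:-1])


def linear_scan(a, u, h0=0.0):
    """Plain recurrence h_t = a_t * h_{t-1} + u_t, returning every state."""
    h, states = h0, []
    for a_t, u_t in zip(a, u):
        h = a_t * h + u_t
        states.append(h)
    return states


def trapezoidal_scan(a, beta, gamma, b, x, c, initial_prev_bx=0.0):
    """Assumed form: h_t = a_t*h_{t-1} + beta_t*B_{t-1}*x_{t-1} + gamma_t*B_t*x_t."""
    bx = [b_t * x_t for b_t, x_t in zip(b, x)]   # current-step term B_t * x_t
    prev = shift_bx(bx, initial_prev_bx)         # lookback term B_{t-1} * x_{t-1}
    # Fold both terms into one input so an ordinary scan can be reused.
    u = [g * cur + be * p for g, be, cur, p in zip(gamma, beta, bx, prev)]
    states = linear_scan(a, u)
    return [c_t * h_t for c_t, h_t in zip(c, states)]


# Toy inputs (hypothetical values chosen to be exact in binary floating point).
a = [0.5, 0.5, 0.5]
beta = [0.25, 0.25, 0.25]
gamma = [0.5, 0.5, 0.5]
b = [1.0, 2.0, 4.0]
x = [1.0, 1.0, 1.0]
c = [1.0, 1.0, 1.0]

print(trapezoidal_scan(a, beta, gamma, b, x, c))                       # [0.5, 1.5, 3.25]
print(trapezoidal_scan(a, beta, gamma, b, x, c, initial_prev_bx=4.0))  # [1.5, 2.0, 3.5]
```

With these inputs, changing initial_prev_bx from 0 to 4 shifts the outputs by [1.0, 0.5, 0.25]: the extra term enters additively and decays with a, which is consistent with the PR's note that initial_prev_Bx can be handled as a separate correction.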
Key features
- seq_idx support for packed multi-document training
- initial_prev_Bx backward through state + output corrections
- Gradient checkpointing (use_mem_eff_path)
Bug fixes
- ssd_state_passing.py: int32/int64 loop variable mismatch with seq_idx
- mamba3_ssd.py: seq_idx propagation in chunked path
- mamba3.py / mamba3_simple.py: D parameter bf16 dtype handling
Tests
- tests/test_mamba3_cpu.py: 45 CPU tests (no GPU required, mocks Triton)
- tests/test_mamba3_gpu.py: 29 GPU tests (requires H100/A100)
- tests/test_mamba3_triton.py: 2 Triton forward tests
- tests/test_mamba3_triton_bwd.py: 64 Triton backward gradient tests (12 test classes)
Test plan
- pytest tests/test_mamba3_cpu.py: 45 CPU tests
- pytest tests/test_mamba3_gpu.py: 29 GPU tests
- pytest tests/test_mamba3_triton.py: 2 Triton forward tests
- pytest tests/test_mamba3_triton_bwd.py: 64 Triton backward tests
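Beyond the pytest runs, the integration path could be exercised by selecting the new layer through ssm_cfg. The sketch below only builds a config as a plain dict. The "layer": "Mamba3" entry comes from this PR; every other key and value is an illustrative placeholder, and how the dict is handed to the model is left out because it depends on the surrounding code.

```python
import json

# Illustrative model config. Only the "layer": "Mamba3" entry is taken from
# this PR; the remaining keys and values are placeholder assumptions.
model_config = {
    "d_model": 256,
    "n_layer": 4,
    "vocab_size": 32000,
    "ssm_cfg": {"layer": "Mamba3"},
}

# Serialize the config, e.g. to inspect it or store it next to a checkpoint.
print(json.dumps(model_config, indent=2))
```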