feat: Add Mamba-3 implementation (ICLR 2026) #850

Open: payk24 wants to merge 3 commits into state-spaces:main from qoxi-cloud:feat/mamba3

payk24 commented Mar 8, 2026

Summary

Implements the Mamba-3 architecture as described in the ICLR 2026 paper:
https://openreview.net/pdf?id=HwCvaJOiCj

Modules

  • mamba_ssm/modules/mamba3.py — Full Mamba-3 module with tensor parallelism, inference cache, SISO/MIMO support
  • mamba_ssm/modules/mamba3_simple.py — Simplified variant (no TP/inference params)
  • mamba_ssm/ops/triton/mamba3_ssd.py — Chunked parallel SSD kernels + Triton decode kernel (SISO & MIMO)
  • mamba_ssm/__init__.py / mamba_ssm/models/mixer_seq_simple.py — Integration ("layer": "Mamba3" in ssm_cfg)
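
To make the integration point concrete, here is a minimal sketch of how a `"layer"` key in `ssm_cfg` could select the mixer. It assumes the selection works by popping `"layer"` from the config and forwarding the remaining keys as constructor arguments; `resolve_mixer` and `SUPPORTED_LAYERS` are hypothetical names for illustration, not identifiers from this PR, and the `d_state` entry is only an example value.

```python
# Hypothetical helper: illustrates the "layer" lookup, not the PR's actual code.
SUPPORTED_LAYERS = ("Mamba1", "Mamba2", "Mamba3")


def resolve_mixer(ssm_cfg):
    """Return (layer_name, remaining_kwargs) for a given ssm_cfg dict."""
    cfg = dict(ssm_cfg or {})           # copy so the caller's dict is not mutated
    layer = cfg.pop("layer", "Mamba1")  # assumed default when no layer is given
    if layer not in SUPPORTED_LAYERS:
        raise ValueError(f"Invalid ssm layer {layer!r}, expected one of {SUPPORTED_LAYERS}")
    return layer, cfg


# Selecting Mamba-3 via the key/value named in this PR; d_state is illustrative.
layer, mixer_kwargs = resolve_mixer({"layer": "Mamba3", "d_state": 128})
print(layer, mixer_kwargs)
```

Everything other than `"layer"` stays in the returned dict, so layer-specific options can travel in the same `ssm_cfg` without a separate config field.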

Triton training kernels (forward + backward)

  • mamba_ssm/ops/triton/mamba3_combined.py — Autograd Function orchestrating Triton forward/backward pipeline
  • mamba_ssm/ops/triton/mamba3_chunk_scan.py — Forward scan kernel with trapezoidal discretization
  • mamba_ssm/ops/triton/mamba3_chunk_scan_bwd.py — Backward kernels: dx, dCB, ddA
  • mamba_ssm/ops/triton/mamba3_chunk_state.py — Forward chunk state with gamma/beta scaling
  • mamba_ssm/ops/triton/mamba3_chunk_state_bwd.py — Backward kernels: dB, ddA (numerically stable)
  • mamba_ssm/ops/triton/mamba3_rope.py — RoPE application for B/C tensors
  • mamba_ssm/ops/triton/mamba3_shift.py — Shifted B/x for trapezoidal lookback term

Architecture: reuses 6 Mamba-2 backward kernels (dz, dstates, state_passing, dC, cumsum, bmm) — only Mamba-3-specific ops get new Triton kernels.
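
For reviewers less familiar with the trapezoidal lookback term, the following is a minimal sequential reference with a scalar state. It assumes the generalized trapezoidal recurrence h_t = α_t·h_{t-1} + β_t·B_{t-1}·x_{t-1} + γ_t·B_t·x_t with α_t = exp(Δ_t·A), β_t = (1 − λ_t)·Δ_t·α_t and γ_t = λ_t·Δ_t. The function name, argument names, and the scalar simplification are illustrative assumptions and are not taken from the kernels in this PR.

```python
import math


def trapezoidal_scan_ref(x, B, C, dt, A, lam, h0=0.0, prev_Bx=0.0):
    """Sequential scalar reference; returns (outputs, final_state, last_Bx)."""
    h, ys = h0, []
    for x_t, B_t, C_t, dt_t, lam_t in zip(x, B, C, dt, lam):
        alpha = math.exp(dt_t * A)            # state decay over this step
        beta = (1.0 - lam_t) * dt_t * alpha   # weight on the previous step's B*x (lookback)
        gamma = lam_t * dt_t                  # weight on the current step's B*x
        Bx = B_t * x_t
        h = alpha * h + beta * prev_Bx + gamma * Bx
        ys.append(C_t * h)
        prev_Bx = Bx                          # becomes the "shifted" input for the next step
    return ys, h, prev_Bx


x = [1.0, -0.5, 2.0, 0.25]
B = [0.5, 1.0, -1.0, 2.0]
C = [1.0, 0.5, 2.0, -1.0]
dt = [0.1, 0.2, 0.05, 0.3]
lam = [0.5, 0.75, 1.0, 0.25]
A = -1.0

# One pass over the whole sequence.
y_full, h_full, _ = trapezoidal_scan_ref(x, B, C, dt, A, lam)

# Same sequence in two pieces, carrying (state, last B*x) across the cut.
y_head, h_mid, bx_mid = trapezoidal_scan_ref(x[:2], B[:2], C[:2], dt[:2], A, lam[:2])
y_tail, h_split, _ = trapezoidal_scan_ref(
    x[2:], B[2:], C[2:], dt[2:], A, lam[2:], h0=h_mid, prev_Bx=bx_mid
)
```

In this sketch, resuming a sequence needs both the state and the last `B*x` product, which is presumably the role of `initial_prev_Bx` and of the prefill-to-decode handoff listed in this PR. A loop like this is also the kind of reference a chunked kernel can be compared against in tests.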

Key features

  • Chunked parallel SSD with trapezoidal discretization, RoPE, and BCNorm
  • SISO Triton path with automatic PyTorch fallback for MIMO
  • seq_idx support for packed multi-document training
  • initial_prev_Bx backward through state + output corrections
  • Triton decode kernel for efficient autoregressive generation (SISO & MIMO)
  • Gradient checkpointing (use_mem_eff_path)
  • Prefill → decode state handoff
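
As a rough illustration of what `seq_idx` support means for packed training, the sketch below resets the recurrent state whenever the document id changes. It assumes the Mamba-2 convention that `seq_idx` holds one integer document id per token, and it uses a plain linear recurrence as a stand-in for the SSD scan; `packed_scan_ref` is a hypothetical name, not a function from this PR.

```python
def packed_scan_ref(x, decay, seq_idx):
    """h_t = decay_t * h_{t-1} + x_t, with the state reset at document boundaries."""
    h, prev_doc, ys = 0.0, None, []
    for x_t, a_t, doc in zip(x, decay, seq_idx):
        if doc != prev_doc:
            h = 0.0          # new document: nothing carries over from the previous one
        h = a_t * h + x_t
        ys.append(h)
        prev_doc = doc
    return ys


# Two documents (lengths 3 and 2) packed into a single sequence.
x = [1.0, 2.0, 3.0, 4.0, 5.0]
decay = [0.5, 0.5, 0.5, 0.5, 0.5]
seq_idx = [0, 0, 0, 1, 1]

packed = packed_scan_ref(x, decay, seq_idx)

# Running each document on its own should give the same outputs.
doc0 = packed_scan_ref(x[:3], decay[:3], [0, 0, 0])
doc1 = packed_scan_ref(x[3:], decay[3:], [0, 0])
```

With the trapezoidal discretization, the lookback term from the previous token would presumably need the same treatment at a boundary, so that the first token of a document does not see the last token of the one before it.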

Bug fixes

  • ssd_state_passing.py: int32/int64 loop variable mismatch with seq_idx
  • mamba3_ssd.py: seq_idx propagation in chunked path
  • mamba3.py / mamba3_simple.py: D parameter bf16 dtype handling

Tests

  • 140 tests total (45 + 29 + 2 + 64), all passing
  • tests/test_mamba3_cpu.py — 45 CPU tests (no GPU required, mocks Triton)
  • tests/test_mamba3_gpu.py — 29 GPU tests (requires H100/A100)
  • tests/test_mamba3_triton.py — 2 Triton forward tests
  • tests/test_mamba3_triton_bwd.py — 64 Triton backward gradient tests (12 test classes)

Test plan

  • pytest tests/test_mamba3_cpu.py — 45 CPU tests
  • pytest tests/test_mamba3_gpu.py — 29 GPU tests
  • pytest tests/test_mamba3_triton.py — 2 Triton forward tests
  • pytest tests/test_mamba3_triton_bwd.py — 64 Triton backward tests

payk24 added 2 commits March 8, 2026 20:19
Mamba-3 module with chunked parallel SSD, RoPE, BCNorm, SISO/MIMO support,
seq_idx for packed multi-document training, Triton decode kernel, and
gradient checkpointing. Includes full and simplified module variants,
Triton ops, and comprehensive CPU/GPU test suites (74 tests).

Triton-accelerated forward and backward for Mamba-3 chunked SSD:

New Triton kernels (5 new files):
- mamba3_chunk_scan.py: Forward scan kernel with trapezoidal support
- mamba3_chunk_scan_bwd.py: Backward kernels for dx, dCB, ddA
- mamba3_chunk_state.py: Forward chunk state with gamma/beta scaling
- mamba3_chunk_state_bwd.py: Backward kernels for dB and ddA (stable)
- mamba3_combined.py: Autograd Function orchestrating Triton pipeline

Supporting kernels:
- mamba3_rope.py: RoPE application for B/C tensors
- mamba3_shift.py: Shifted B/x computation for trapezoidal lookback

Key design decisions:
- Reuses 6 Mamba-2 backward kernels (dz, dstates, state_passing, dC,
  cumsum, bmm) — only Mamba-3-specific ops get new kernels
- SISO Triton path with PyTorch fallback for MIMO
- initial_prev_Bx backward: explicit gradient through state + output
  corrections (additive, so existing kernels handle other params)
- RoPE backward via PyTorch autograd recompute (general, correct)

Bug fixes:
- ssd_state_passing.py: int32/int64 loop variable mismatch with seq_idx
- mamba3_ssd.py: seq_idx propagation in chunked path
- mamba3.py/mamba3_simple.py: D parameter bf16 dtype handling

Tests: 64 Triton backward + 2 Triton forward tests (all passing)