Summary
I have an existing MLX Mamba kernel implementation with strong performance results and want to upstream the relevant parts into mlx-lm.
This issue proposes integrating kernel-level optimizations for SSM/Mamba paths used in inference, while preserving current behavior and adding safe fallbacks.
Type: enhancement/performance proposal (not a bug report).
Motivation
The current Mamba/SSM paths are correct and portable, but they leave meaningful room for Apple Silicon-specific speedups, particularly in prefill.
Proposed Scope
- Optimize the SSM update path for decode (seq_len == 1) with a faster Metal kernel path.
- Evaluate prefill improvements for chunked scan/selective scan where applicable.
- Keep existing numerics and output contract intact.
- Keep current path as fallback when constraints are not met.
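The guard-and-fallback idea in the scope above could look roughly like the sketch below. It is only an illustration: SSMInput, SUPPORTED_DTYPES, can_use_fast_decode, and ssm_update are hypothetical names, not existing mlx-lm APIs, and the real constraints would depend on the kernel.

```python
from dataclasses import dataclass
from typing import Callable, Tuple

# Hypothetical: dtypes the fast kernel is assumed to support.
SUPPORTED_DTYPES = {"float16", "bfloat16", "float32"}


@dataclass(frozen=True)
class SSMInput:
    """Stand-in for the tensor metadata a dispatcher would inspect."""
    shape: Tuple[int, ...]  # assumed layout: (batch, seq_len, hidden)
    dtype: str


def can_use_fast_decode(x: SSMInput) -> bool:
    """Return True only when every (assumed) kernel constraint holds."""
    return (
        len(x.shape) == 3
        and x.shape[1] == 1  # decode step: seq_len == 1
        and x.dtype in SUPPORTED_DTYPES
    )


def ssm_update(x: SSMInput, fast_path: Callable, reference_path: Callable):
    """Use the fast path when the guards pass, else the existing path."""
    if can_use_fast_decode(x):
        return fast_path(x)
    return reference_path(x)
```

Keeping the guard in one small predicate means the existing path stays the default whenever a shape or dtype falls outside what the kernel was written for.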
Non-Goals
- No model architecture changes.
- No API changes for users.
- No mandatory dependency changes.
Correctness Plan
- Compare kernel outputs against current implementation with tolerance checks.
- Validate across relevant dtypes used in MLX-LM inference.
- Validate cache/state updates match existing behavior.
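A parity check along the lines of the plan above might be sketched as follows. The tolerance values are placeholders chosen for illustration, not the tolerances that will be documented in the PR, and outputs_match is a hypothetical helper operating on flat lists of floats rather than MLX arrays.

```python
import math
from typing import Sequence

# Placeholder (rtol, atol) pairs per dtype; real values would be tuned
# and documented alongside the tests.
TOLERANCES = {
    "float32": (1e-5, 1e-6),
    "float16": (1e-3, 1e-3),
    "bfloat16": (1e-2, 1e-2),
}


def outputs_match(
    reference: Sequence[float], candidate: Sequence[float], dtype: str
) -> bool:
    """Element-wise comparison of kernel output against the reference path."""
    if len(reference) != len(candidate):
        return False
    rtol, atol = TOLERANCES[dtype]
    return all(
        math.isclose(r, c, rel_tol=rtol, abs_tol=atol)
        for r, c in zip(reference, candidate)
    )
```

The same comparison would apply to the updated cache/state values, not only to the layer output.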
Benchmark Plan
I will report:
- Prefill tokens/s and latency
- Decode tokens/s and latency
- End-to-end generation latency on representative prompts
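For reference, the reported metrics reduce to simple timing arithmetic. The sketch below assumes a callable per phase; time_phase and tokens_per_second are hypothetical helper names. Because MLX evaluates lazily, the callable would need to force evaluation (for example with mx.eval) before returning, otherwise the timing would not cover the actual compute.

```python
import time
from typing import Callable, Tuple


def tokens_per_second(num_tokens: int, elapsed_s: float) -> float:
    """Throughput for a phase that processed num_tokens in elapsed_s seconds."""
    if elapsed_s <= 0:
        raise ValueError("elapsed_s must be positive")
    return num_tokens / elapsed_s


def time_phase(run_phase: Callable[[], object], num_tokens: int) -> Tuple[float, float]:
    """Run one phase (prefill or decode) and return (latency_s, tokens/s)."""
    start = time.perf_counter()
    run_phase()  # must block until the work has actually finished
    elapsed = time.perf_counter() - start
    return elapsed, tokens_per_second(num_tokens, elapsed)
```

With the workload used in this issue, prefill would be timed with num_tokens=1024 and decode with num_tokens=128.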
Risks and Mitigations
- Risk: shape/dtype-specific kernel assumptions.
- Mitigation: explicit guards + fallback to existing path.
- Risk: numerical drift.
- Mitigation: parity tests + tolerances documented in PR.
Benchmarks
Environment:
- Device: Apple M1 Max (64 GB)
- Python: 3.12.13
- mlx: 0.31.1
- mlx-metal: 0.31.1
- Baseline: clean main
- Optimized: feat/mamba-mlx-kernels
Workload settings:
- Prompt tokens: 1024
- Generation tokens: 128
- Batch size: 1
- Trials: 1
Results (before -> after), focused on confirmed improvements:
- mlx-community/mamba-370m-hf-f16 (Mamba-1)
  - prompt_tps: 293.741 -> 5497.575 (18.72x)
  - generation_tps: 159.172 -> 163.434 (1.03x)
  - peak_memory (GB): 1.027 -> 2.173
- mlx-community/mamba2-2.7b (Mamba-2)
  - baseline (main) status: model load fails (ModelArgs missing intermediate_size)
  - after model-local compatibility fix:
    - prompt_tps: 1071.840
    - generation_tps: 48.414
    - peak_memory (GB): 7.777
- mlx-community/Mamba-Codestral-7B-v0.1 (Mamba-2)
  - prompt_tps: 382.701 -> 451.243 (1.18x)
  - generation_tps: 20.192 -> 21.419 (1.06x)
  - peak_memory (GB): 22.918 -> 18.797
- mlx-community/Falcon3-Mamba-7B-Instruct (Mamba-1)
  - prompt_tps: 59.933 -> 513.720 (8.57x)
  - generation_tps: 19.325 -> 20.783 (1.08x)
  - peak_memory (GB): 15.952 -> 16.624
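The speedup factors quoted in the results are the after/before ratio of each throughput figure, rounded to two decimals. The short script below recomputes them from the numbers in this issue; the RESULTS layout and the speedup helper are just an illustration.

```python
# (before, after) throughput in tokens/s, copied from the results above.
RESULTS = {
    "mamba-370m-hf-f16": {
        "prompt_tps": (293.741, 5497.575),
        "generation_tps": (159.172, 163.434),
    },
    "Mamba-Codestral-7B-v0.1": {
        "prompt_tps": (382.701, 451.243),
        "generation_tps": (20.192, 21.419),
    },
    "Falcon3-Mamba-7B-Instruct": {
        "prompt_tps": (59.933, 513.720),
        "generation_tps": (19.325, 20.783),
    },
}


def speedup(before: float, after: float) -> float:
    """Ratio of optimized to baseline throughput."""
    return after / before


for model, metrics in RESULTS.items():
    for name, (before, after) in metrics.items():
        print(f"{model} {name}: {speedup(before, after):.2f}x")
```

mamba2-2.7b is left out because the baseline does not load on main, so there is no before value to compare against.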
PR
I will open a PR linked to this issue with:
- Code changes
- Tests
- Benchmarks