
Optimize Mamba/SSM decode and prefill kernels on Apple Silicon #1152

@Gal-bloch


Summary

I have an existing MLX Mamba kernel implementation that shows large prefill speedups on several Mamba models (see Benchmarks), and I want to upstream the relevant parts into mlx-lm.

This issue proposes integrating kernel-level optimizations for SSM/Mamba paths used in inference, while preserving current behavior and adding safe fallbacks.

Type: enhancement/performance proposal (not a bug report).

Motivation

The current Mamba/SSM paths are correct and portable, but there is still meaningful room for Apple Silicon-specific speedups, especially in prefill and to a lesser extent in decode.

Proposed Scope

  • Optimize SSM update path for decode (seq_len == 1) with a faster Metal kernel path.
  • Evaluate prefill improvements for chunked scan/selective scan where applicable.
  • Keep existing numerics and output contract intact.
  • Keep the current path as a fallback when the fast kernel's shape/dtype constraints are not met.
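To make the decode/prefill split concrete, here is a minimal NumPy sketch of a simplified single-head SSM recurrence with a scalar decay per head (Mamba-2 style; Mamba-1 uses a per-channel diagonal A instead). It is not the mlx-lm implementation or the proposed Metal kernel, and all function names are hypothetical. Decode applies one step per token (seq_len == 1). Prefill can either loop that step, or evaluate each chunk in matrix form and carry the state across chunk boundaries, which is the kind of restructuring a chunked scan relies on.

```python
import numpy as np


def ssm_decode_step(h, dt, A, x, B, C, D):
    """One recurrent step (seq_len == 1) for a single head.

    h: (d_head, d_state) state, dt: scalar > 0, A: scalar < 0,
    x: (d_head,), B and C: (d_state,), D: scalar skip weight.
    """
    h = np.exp(dt * A) * h + dt * np.outer(x, B)  # decay old state, add input
    y = h @ C + D * x                              # read out through C
    return h, y


def sequential_scan(h0, dt, A, X, Bm, Cm, D):
    """Prefill by looping the decode step over T tokens."""
    h, ys = h0, []
    for t in range(len(dt)):
        h, y = ssm_decode_step(h, dt[t], A, X[t], Bm[t], Cm[t], D)
        ys.append(y)
    return np.stack(ys), h


def chunk_scan(h0, dt, A, X, Bm, Cm, D):
    """Same result for one chunk of T tokens, computed in matrix form."""
    T = len(dt)
    cs = np.cumsum(dt * A)                           # cumulative log-decay, (T,)
    causal = np.tril(np.ones((T, T), dtype=bool))
    seg = np.where(causal, cs[:, None] - cs[None, :], -np.inf)
    L = np.exp(seg)                                  # decay from token s to t; 0 if s > t
    M = L * (Cm @ Bm.T) * dt[None, :]                # weight of x_s in y_t
    Y = M @ X + np.exp(cs)[:, None] * (Cm @ h0.T) + D * X
    w = np.exp(cs[-1] - cs) * dt                     # each token's share of the final state
    hT = np.exp(cs[-1]) * h0 + (X * w[:, None]).T @ Bm
    return Y, hT


def chunked_prefill(h0, dt, A, X, Bm, Cm, D, chunk=4):
    """Run chunk_scan chunk by chunk, carrying the state between chunks."""
    h, outs = h0, []
    for s in range(0, len(dt), chunk):
        e = s + chunk
        y, h = chunk_scan(h, dt[s:e], A, X[s:e], Bm[s:e], Cm[s:e], D)
        outs.append(y)
    return np.concatenate(outs), h


rng = np.random.default_rng(0)
T, P, N = 10, 3, 5
h0 = rng.standard_normal((P, N))
dt = rng.uniform(0.01, 0.5, T)
X = rng.standard_normal((T, P))
Bm = rng.standard_normal((T, N))
Cm = rng.standard_normal((T, N))
y_seq, h_seq = sequential_scan(h0, dt, -1.5, X, Bm, Cm, 0.3)
y_chk, h_chk = chunked_prefill(h0, dt, -1.5, X, Bm, Cm, 0.3, chunk=4)
```

Both paths produce the same outputs and final state (up to floating-point error), which is exactly the parity property an optimized kernel has to preserve.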

Non-Goals

  • No model architecture changes.
  • No API changes for users.
  • No mandatory dependency changes.

Correctness Plan

  • Compare kernel outputs against current implementation with tolerance checks.
  • Validate across the dtypes used in mlx-lm inference.
  • Validate cache/state updates match existing behavior.
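A parity check along these lines could cover both the outputs and the cached SSM state in one call. This is a sketch only: the function name and the per-dtype tolerance values are hypothetical placeholders (the real tolerances would be measured and documented in the PR), and arrays are upcast to float32 before comparison.

```python
import numpy as np

# Hypothetical per-dtype tolerances (rtol, atol); real values would be chosen
# empirically and documented in the PR.
TOLERANCES = {
    "float32": (1e-5, 1e-5),
    "float16": (1e-2, 1e-2),
    "bfloat16": (5e-2, 5e-2),
}


def check_parity(ref_y, opt_y, ref_state, opt_state, dtype_name):
    """Compare optimized outputs and final SSM state against the reference path."""
    rtol, atol = TOLERANCES[dtype_name]
    report = {}
    for name, ref, opt in (("output", ref_y, opt_y), ("state", ref_state, opt_state)):
        ref = np.asarray(ref, dtype=np.float32)  # upcast before comparing
        opt = np.asarray(opt, dtype=np.float32)
        if ref.shape != opt.shape:
            raise ValueError(f"{name} shape mismatch: {ref.shape} vs {opt.shape}")
        report[name + "_max_abs_err"] = float(np.max(np.abs(ref - opt)))
        report[name + "_ok"] = bool(np.allclose(opt, ref, rtol=rtol, atol=atol))
    report["ok"] = report["output_ok"] and report["state_ok"]
    return report


rng = np.random.default_rng(0)
ref_y = rng.standard_normal((8, 16)).astype(np.float32)
ref_h = rng.standard_normal((4, 16)).astype(np.float32)
# Simulate a half-precision kernel by rounding the reference through float16.
report = check_parity(ref_y, ref_y.astype(np.float16), ref_h, ref_h.astype(np.float16), "float16")
```

Reporting the maximum absolute error alongside the pass/fail flag makes it easier to justify the tolerances chosen for each dtype.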

Benchmark Plan

I will report:

  • Prefill tokens/s and latency
  • Decode tokens/s and latency
  • End-to-end generation latency on representative prompts
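A timing harness for these metrics could look like the sketch below. The `prefill_fn` and `decode_step_fn` callables are hypothetical stand-ins for the model calls. Because MLX evaluates lazily, they must force computation (for example with `mx.eval` on their results) before returning, or the measured times will be too low.

```python
import time


def benchmark(prefill_fn, decode_step_fn, prompt_tokens, gen_tokens):
    """Time prefill and decode separately and report throughput and latency.

    prefill_fn(prompt_tokens) and decode_step_fn() are hypothetical callables
    that must block until their work has actually been computed.
    """
    t0 = time.perf_counter()
    prefill_fn(prompt_tokens)
    t1 = time.perf_counter()
    for _ in range(gen_tokens):
        decode_step_fn()
    t2 = time.perf_counter()

    prefill_s, decode_s = t1 - t0, t2 - t1
    return {
        "prompt_tps": prompt_tokens / prefill_s,
        "generation_tps": gen_tokens / decode_s,
        "prefill_latency_s": prefill_s,
        "decode_latency_ms_per_token": 1000.0 * decode_s / gen_tokens,
        "end_to_end_s": t2 - t0,
    }


# Dummy workload standing in for a real model.
stats = benchmark(lambda n: time.sleep(0.01), lambda: time.sleep(0.0005), 1024, 128)
```

Timing prefill and decode separately matters here because the optimizations affect them very differently.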

Risks and Mitigations

  • Risk: shape/dtype-specific kernel assumptions.
    • Mitigation: explicit guards + fallback to existing path.
  • Risk: numerical drift.
    • Mitigation: parity tests + tolerances documented in PR.
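The "explicit guards + fallback" mitigation could take roughly the following shape. Everything here is hypothetical: the constraint values and function names are illustrative only, and the real kernel's limits would be defined and tested in the PR. The key idea is that every assumption the fast kernel makes is checked up front, and anything outside those assumptions goes to the existing path.

```python
# Hypothetical constraints for illustration only.
FAST_DTYPES = {"float16", "bfloat16", "float32"}
MAX_FAST_D_STATE = 256


def can_use_fast_decode(seq_len, dtype_name, d_state):
    """Explicit guards: return True only if every kernel assumption holds."""
    return (
        seq_len == 1                     # decode-only kernel
        and dtype_name in FAST_DTYPES    # dtypes the kernel was validated on
        and d_state <= MAX_FAST_D_STATE  # e.g. state must fit the kernel's tiling
    )


def ssm_update(state, inputs, *, seq_len, dtype_name, d_state,
               reference_fn, fast_kernel=None):
    """Route to the fast kernel when all guards pass, else to the existing path."""
    if fast_kernel is not None and can_use_fast_decode(seq_len, dtype_name, d_state):
        return fast_kernel(state, inputs)
    return reference_fn(state, inputs)


fast = lambda s, i: "fast"
ref = lambda s, i: "reference"
route = ssm_update(None, None, seq_len=1, dtype_name="float16", d_state=128,
                   reference_fn=ref, fast_kernel=fast)
```

Keeping the guards in one predicate makes them easy to unit-test independently of the kernel itself.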

Benchmarks

Environment:

  • Device: Apple M1 Max (64 GB)
  • Python: 3.12.13
  • mlx: 0.31.1
  • mlx-metal: 0.31.1
  • Baseline: clean main
  • Optimized: feat/mamba-mlx-kernels

Workload settings:

  • Prompt tokens: 1024
  • Generation tokens: 128
  • Batch size: 1
  • Trials: 1

Results (before -> after), listing only confirmed improvements:

  • mlx-community/mamba-370m-hf-f16 (Mamba-1)

    • prompt_tps: 293.741 -> 5497.575 (18.72x)
    • generation_tps: 159.172 -> 163.434 (1.03x)
    • peak_memory (GB): 1.027 -> 2.173
  • mlx-community/mamba2-2.7b (Mamba-2)

    • baseline (main) status: model load fails (ModelArgs missing intermediate_size)
    • after a model-local compatibility fix (no baseline available for comparison):
      • prompt_tps: 1071.840
      • generation_tps: 48.414
      • peak_memory (GB): 7.777
  • mlx-community/Mamba-Codestral-7B-v0.1 (Mamba-2)

    • prompt_tps: 382.701 -> 451.243 (1.18x)
    • generation_tps: 20.192 -> 21.419 (1.06x)
    • peak_memory (GB): 22.918 -> 18.797
  • mlx-community/Falcon3-Mamba-7B-Instruct (Mamba-1)

    • prompt_tps: 59.933 -> 513.720 (8.57x)
    • generation_tps: 19.325 -> 20.783 (1.08x)
    • peak_memory (GB): 15.952 -> 16.624
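The speedup factors in parentheses are after/before ratios of the throughput numbers. The small convenience script below (not part of any benchmark tooling) recomputes them from the reported values, which makes them easy to check or to update once more trials are run.

```python
# Before/after throughput copied from the results above (Trials: 1).
RESULTS = {
    "mamba-370m-hf-f16": {"prompt_tps": (293.741, 5497.575), "generation_tps": (159.172, 163.434)},
    "Mamba-Codestral-7B-v0.1": {"prompt_tps": (382.701, 451.243), "generation_tps": (20.192, 21.419)},
    "Falcon3-Mamba-7B-Instruct": {"prompt_tps": (59.933, 513.720), "generation_tps": (19.325, 20.783)},
}


def speedups(results):
    """Speedup = after / before for throughput metrics (higher is better)."""
    return {
        model: {metric: round(after / before, 2) for metric, (before, after) in metrics.items()}
        for model, metrics in results.items()
    }


print(speedups(RESULTS)["Falcon3-Mamba-7B-Instruct"])
# {'prompt_tps': 8.57, 'generation_tps': 1.08}
```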

PR

I will open a PR linked to this issue with:

  • Code changes
  • Tests
  • Benchmarks
