[CUDA] GatherQMM matrix-matrix sm80/naive path #3417

Open
Lyxot wants to merge 5 commits into ml-explore:main from Lyxot:cuda/gather_qmm_cutlass

Conversation

Contributor

@Lyxot Lyxot commented Apr 16, 2026

Followup to #3321

Proposed Changes

Add optional gather index parameters to `qmm_sm80_kernel` and `qmm_naive_kernel`. When provided, each kernel reads `lhs_indices`/`rhs_indices` during the batch-slice lookup to select which A/B (and scales/biases) slices to load.

Performance

RTX 4070 SUPER, fp16, 4-bit affine quantization, K=4096. Baseline is main (qmv-only path).

| M | N | B | qmv (µs) | fused qmm_sm80 (µs) | Speedup | fused qmm_naive (µs) | Speedup |
|---:|---:|---:|---:|---:|---:|---:|---:|
| 32 | 4096 | 2 | 1116 | 116 | 9.6x | 110 | 10.1x |
| 32 | 4096 | 8 | 4024 | 280 | 14.4x | 566 | 7.1x |
| 64 | 4096 | 2 | 1978 | 157 | 12.6x | 161 | 12.3x |
| 128 | 4096 | 2 | 3940 | 223 | 17.7x | 223 | 17.7x |
| 32 | 14336 | 2 | 3586 | 245 | 14.6x | 487 | 7.4x |
| 32 | 14336 | 8 | 13896 | 821 | 16.9x | 1940 | 7.2x |

Small M stays on qmv with no regression.

Copilot AI review requested due to automatic review settings April 16, 2026 05:01

Copilot AI left a comment


Pull request overview

Adds gather-index support to the CUDA quantized matrix-matrix (QMM) implementations so GatherQMM can use the faster SM80 CUTLASS-based kernel (or the naive kernel) by selecting per-output batch slices via lhs_indices / rhs_indices.

Changes:

  • Plumbs optional lhs_indices / rhs_indices pointers through qmm_sm80 and qmm_naive down to the CUDA kernels, and uses them to select the A/B (and scale/bias) batch slices.
  • Updates GatherQMM::eval_gpu to prefer qmm_sm80 / qmm_naive paths (with a small-problem fallback to gather_qmv).
  • Extends the QMM public headers/dispatch to accept the new optional gather-index parameters.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.

| File | Description |
|------|-------------|
| mlx/backend/cuda/quantized/quantized.cpp | Routes GatherQMM to SM80/naive QMM kernels and passes gather index buffers into the QMM calls. |
| mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh | Adds optional gather index parameters to the SM80 kernel and uses them for batch-slice selection of A/B/scales/biases. |
| mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh | Adds optional gather index parameters to the naive kernel and uses them for batch-slice selection of A/B/scales/biases. |
| mlx/backend/cuda/quantized/qmm/qmm.h | Extends the qmm_sm80 / qmm_naive APIs with optional gather index pointer parameters. |
| mlx/backend/cuda/quantized/qmm/qmm.cu | Wires the new optional gather index pointer parameters through the QMM dispatch into qmm_impl_sm80 / qmm_impl_naive. |


Contributor Author

Lyxot commented Apr 16, 2026

Added an untested sm90 kernel using CUTLASS kArray mode (fused gather via per-batch pointer arrays): 5787da7

Collaborator

@zcbenz zcbenz left a comment


This is very nice, thanks!

Comment on lines 47 to +58

```diff
 void qmm_sm80(
     const array& x,
     const array& w,
     const array& scales,
     const std::optional<array>& biases,
     array& out,
     int bits,
     int group_size,
     QuantizationMode mode,
-    cu::CommandEncoder& encoder);
+    cu::CommandEncoder& encoder,
+    const uint32_t* lhs_indices = nullptr,
+    const uint32_t* rhs_indices = nullptr);
```
Collaborator

Can you pass indices as `const std::optional<array>&`, like `biases`?

```cpp
void qmm_sm80(
    const array& x,
    const array& w,
    const array& scales,
    const std::optional<array>& biases,
    const std::optional<array>& lhs_indices,
    const std::optional<array>& rhs_indices,
    array& out,
    int bits,
    int group_size,
    QuantizationMode mode,
    cu::CommandEncoder& encoder);
```

Collaborator

zcbenz commented Apr 17, 2026

On sm90: we don't need to implement the sorted_indices=False version of gather_qmm, because it is not going to be used for model training or inference.

@angeloskath
Member

> On sm90: we don't need to implement the sorted_indices=False version of gather_qmm, because it is not going to be used for model training or inference.

Ideally we want full coverage; we shouldn't go out of our way to avoid implementing it. But it is indeed lower priority.

gather_mm (and gather_qmm) are actually quite useful for things other than MoEs as well. Things like block-sparse attention or generally sparse matmuls etc.


4 participants