[CUDA] GatherQMM matrix-matrix sm80/naive path #3417
Lyxot wants to merge 5 commits into ml-explore:main
Conversation
The pre-gather approach (gather_slices + qmm_sm90) is a temporary workaround that benchmarks 2-4x slower than the fused sm80 kernel.
Pull request overview
Adds gather-index support to the CUDA quantized matrix-matrix (QMM) implementations so GatherQMM can use the faster SM80 CUTLASS-based kernel (or the naive kernel) by selecting per-output batch slices via lhs_indices / rhs_indices.
Changes:
- Plumbs optional `lhs_indices`/`rhs_indices` pointers through `qmm_sm80` and `qmm_naive` down to the CUDA kernels, and uses them to select the A/B (and scale/bias) batch slices.
- Updates `GatherQMM::eval_gpu` to prefer the `qmm_sm80`/`qmm_naive` paths (with a small-problem fallback to `gather_qmv`).
- Extends the QMM public headers/dispatch to accept the new optional gather-index parameters.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.
Summary per file:
| File | Description |
|---|---|
| `mlx/backend/cuda/quantized/quantized.cpp` | Routes GatherQMM to the SM80/naive QMM kernels and passes gather index buffers into the QMM calls. |
| `mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh` | Adds optional gather index parameters to the SM80 kernel and uses them for batch-slice selection of A/B/scales/biases. |
| `mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh` | Adds optional gather index parameters to the naive kernel and uses them for batch-slice selection of A/B/scales/biases. |
| `mlx/backend/cuda/quantized/qmm/qmm.h` | Extends the `qmm_sm80`/`qmm_naive` APIs with optional gather index pointer parameters. |
| `mlx/backend/cuda/quantized/qmm/qmm.cu` | Wires the new optional gather index pointer parameters through the QMM dispatch into `qmm_impl_sm80`/`qmm_impl_naive`. |
Commit 5787da7: an untested sm90 kernel using CUTLASS kArray mode (fused gather via per-batch pointer arrays)
zcbenz left a comment:
This is very nice, thanks!
```diff
 void qmm_sm80(
     const array& x,
     const array& w,
     const array& scales,
     const std::optional<array>& biases,
     array& out,
     int bits,
     int group_size,
     QuantizationMode mode,
-    cu::CommandEncoder& encoder);
+    cu::CommandEncoder& encoder,
+    const uint32_t* lhs_indices = nullptr,
+    const uint32_t* rhs_indices = nullptr);
```
Can you pass indices as `const std::optional<array>&` like `biases`?
```cpp
void qmm_sm80(
    const array& x,
    const array& w,
    const array& scales,
    const std::optional<array>& biases,
    const std::optional<array>& lhs_indices,
    const std::optional<array>& rhs_indices,
    array& out,
    int bits,
    int group_size,
    QuantizationMode mode,
    cu::CommandEncoder& encoder);
```
On sm90: we don't need to implement the
Ideally we want to have full coverage. We shouldn't go out of our way to not implement it. But it is indeed lower priority.
Follow-up to #3321

Proposed Changes

Add optional gather index parameters to `qmm_sm80_kernel` and `qmm_naive_kernel`. The kernels read `lhs_indices`/`rhs_indices` at the batch-slice lookup.

Performance

RTX 4070 SUPER, fp16, 4-bit affine quantization, K=4096. Baseline is `main` (qmv-only path). Small M stays on qmv with no regression.