Metal backend: Add gather_qmv kernel for MoE expert-indexed quantized matmul #18877
manuelcandales merged 19 commits into main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18877
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 1 Cancelled Job, 4 Unrelated Failures as of commit 7423226 with merge base 66e4656.
NEW FAILURE: The following job has failed.
CANCELLED JOB: The following job was cancelled. Please retry.
FLAKY: The following job failed but was likely due to flakiness present on trunk.
BROKEN TRUNK: The following jobs failed but were also present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
| "description": "Expert-indexed quantized matmul for MoE (metal::gather_qmv)", | ||
| "atol_float32": 5e-2, | ||
| "rtol_float32": 5e-2, | ||
| "atol_bfloat16": 5.0, |
Why is this tolerance so large?
@claude review this code
Claude finished @metascroy's task in 3m 54s

Code Review: Metal gather_qmv kernel for MoE
Overall the implementation is clean and follows the existing patterns.

Issues

1. E is extracted but never checked:

int32_t E = static_cast<int32_t>(w_tensor->sizes()[0]); // extracted but never checked

Consider adding validation of the expert indices against E.

2. Missing validation of …; the code validates …

3. Large test tolerance (responds to @metascroy's review comment):

"atol_float32": 5e-2,
"rtol_float32": 5e-2,
"atol_bfloat16": 1e-1,
"rtol_bfloat16": 1e-1,

These tolerances are large because: (a) the test compares the quantized kernel output against the quantized reference CPU implementation (both go through quantize → dequantize → matmul), but the Metal shader uses the 4-bit bitmask/predivision trick (…). That said, these tolerances are identical to the existing ….

Minor Observations

4. Code duplication of 4-bit quantization helpers: the …

5. This file is just a license header with no imports; it exists to make ….

6. Backend integration is cleanly guarded: the …

Correctness Verification

…
Summary

The implementation is solid and follows established patterns. The main suggestions are adding input validation for ….
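As an illustration of the bounds check suggested in Issue 1, here is a minimal hedged sketch; the helper name, pointer parameters, and the ET_CHECK_MSG usage are assumptions for illustration, not the PR's actual host code:

```cpp
#include <cstdint>
#include <executorch/runtime/platform/assert.h>  // ET_CHECK_MSG (assumed include path)

// Hypothetical helper: reject out-of-range expert indices before dispatching
// the Metal kernel. idx_ptr, num_tokens, and E only loosely mirror the PR's
// host-code names; this sketches the check, not the shipped implementation.
inline void check_expert_indices(
    const int32_t* idx_ptr, int64_t num_tokens, int32_t E) {
  for (int64_t i = 0; i < num_tokens; ++i) {
    ET_CHECK_MSG(
        idx_ptr[i] >= 0 && idx_ptr[i] < E,
        "gather_qmv: expert index %d out of range [0, %d)",
        static_cast<int>(idx_ptr[i]),
        static_cast<int>(E));
  }
}
```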
This PR needs a …
Adds a gather_qmv Metal kernel for Mixture-of-Experts: performs the per-expert
quantized matrix-vector multiply y[i] = W[expert_idx[i]] @ x[i]. Extends
the existing qmv kernels in op_linear_4bit.mm with expert-index-based
pointer offsets, following the same pattern as MLX's
affine_gather_qmv_fast.
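For reference, the gather semantics on already-dequantized weights can be sketched as follows. This is an illustrative C++ sketch, not the PR's reference CPU impl; all names and shapes (W as [E][N][K], x as [T][K], y as [T][N]) are assumptions:

```cpp
#include <cstdint>
#include <vector>

// Illustrative reference for y[i] = W[expert_idx[i]] @ x[i], with the expert
// weights already dequantized to float. y is assumed pre-sized to [T][N].
void gather_mv_reference(
    const std::vector<std::vector<std::vector<float>>>& W,  // [E][N][K]
    const std::vector<std::vector<float>>& x,               // [T][K]
    const std::vector<int32_t>& expert_idx,                 // [T]
    std::vector<std::vector<float>>& y) {                   // [T][N]
  for (size_t i = 0; i < x.size(); ++i) {
    const auto& We = W[expert_idx[i]];  // select this token's expert matrix
    for (size_t n = 0; n < We.size(); ++n) {
      float acc = 0.f;
      for (size_t k = 0; k < x[i].size(); ++k) {
        acc += We[n][k] * x[i][k];
      }
      y[i][n] = acc;
    }
  }
}
```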
Two dispatch paths (matching op_linear_4bit.mm): …
Uses the same affine INT4 dequantization format as op_linear_4bit.mm
(scale * accum + sum * bias). Instantiated for 4-bit with group sizes
{32, 64, 128} and dtypes {float, bfloat16}.
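For intuition, the (scale * accum + sum * bias) form falls out of the per-group affine identity: with w = scale * q + bias, the group dot product factors into a quantized accumulator plus a sum of activations. A hedged C++ sketch follows; the nibble packing order and all names are assumptions for illustration, not taken from the shader:

```cpp
#include <cstdint>

// One quantization group: each weight is w = scale * q + bias, q a 4-bit value.
// The dot product over the group factors as
//   sum_k (scale*q_k + bias) * x_k = scale * sum_k q_k*x_k + bias * sum_k x_k,
// i.e. scale * accum + sum * bias, the form named in the PR description.
float group_dot(
    const uint8_t* packed,  // group_size/2 bytes, two 4-bit values per byte
    const float* x,
    int group_size,
    float scale,
    float bias) {
  float accum = 0.f, x_sum = 0.f;
  for (int k = 0; k < group_size; k += 2) {
    int q0 = packed[k / 2] & 0x0F;         // low nibble (order assumed)
    int q1 = (packed[k / 2] >> 4) & 0x0F;  // high nibble
    accum += q0 * x[k] + q1 * x[k + 1];
    x_sum += x[k] + x[k + 1];
  }
  return scale * accum + bias * x_sum;
}
```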
Includes: Metal shader + C++ host dispatch, Python custom op definition
(metal::gather_qmv) with reference CPU impl and Meta impl, C shim dict,
fallback kernel registration, CMakeLists entry, and test module.
Authored with Claude.