
Metal backend: Add gather_qmv kernel for MoE expert-indexed quantized matmul #18877

Merged

manuelcandales merged 19 commits into main from gh/manuelcandales/172/head on Apr 21, 2026

Conversation

Contributor

manuelcandales commented Apr 14, 2026

Adds gather_qmv Metal kernel for Mixture-of-Experts: performs per-expert
quantized matrix-vector multiply y[i] = W[expert_idx[i]] @ x[i]. Extends
the existing qmv kernels in op_linear_4bit.mm with expert
index-based pointer offsets, following the same pattern as MLX's
affine_gather_qmv_fast.
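
For intuition, a minimal sketch of the op's semantics on already-dequantized weights (the helper name and signature are illustrative, not the kernel's API):

import torch

def gather_qmv_reference(w, expert_idx, x):
    # Illustrative only: w here is already dequantized [E, N, K]; the real op
    # takes packed INT4 weights plus per-group scales and biases.
    # expert_idx: [P] integer expert id per input row; x: [P, K] activations.
    P = x.shape[0]
    y = torch.empty(P, w.shape[1], dtype=x.dtype)
    for i in range(P):
        y[i] = w[expert_idx[i]] @ x[i]  # y[i] = W[expert_idx[i]] @ x[i]
    return y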

Two dispatch paths (matching op_linear_4bit.mm; selection sketched after this list):

  • gather_qmv_fast: optimized path for K%512==0 and N%8==0
  • gather_qmv_impl: generic fallback for any K and N
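
A minimal sketch of that selection predicate, assuming the divisibility conditions above are the only criteria:

def pick_gather_qmv_kernel(K, N):
    # Fast path requires K % 512 == 0 and N % 8 == 0 per the description above;
    # anything else falls back to the generic kernel.
    if K % 512 == 0 and N % 8 == 0:
        return "gather_qmv_fast"
    return "gather_qmv_impl"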

Uses the same affine INT4 dequantization format as op_linear_4bit.mm
(scale * accum + sum * bias). Instantiated for 4-bit with group sizes
{32, 64, 128} and dtypes {float, bfloat16}.
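
The "scale * accum + sum * bias" form follows from factoring the per-group affine parameters out of each dot product: sum_k (scale*q_k + bias)*x_k = scale*(sum_k q_k*x_k) + bias*(sum_k x_k). A sketch of that identity for one output element (names illustrative):

import torch

def grouped_affine_dot(q, scales, biases, x, gs=32):
    # q: [K] int4 codes as integers, x: [K] activations,
    # scales/biases: [K // gs] per-group affine dequant parameters.
    acc = x.new_zeros(())
    for g in range(len(scales)):
        qg = q[g * gs:(g + 1) * gs].to(x.dtype)
        xg = x[g * gs:(g + 1) * gs]
        accum = torch.dot(qg, xg)  # dot product over the raw int4 codes
        xsum = xg.sum()            # per-group sum of activations
        # (scale * q + bias) . x  ==  scale * accum + xsum * bias
        acc = acc + scales[g] * accum + biases[g] * xsum
    return acc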

Includes: Metal shader + C++ host dispatch, Python custom op definition
(metal::gather_qmv) with reference CPU impl and Meta impl, C shim dict,
fallback kernel registration, CMakeLists entry, and test module.

Authored with Claude.

@pytorch-bot

pytorch-bot Bot commented Apr 14, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18877

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Cancelled Job, 4 Unrelated Failures

As of commit 7423226 with merge base 66e4656:

NEW FAILURE - The following job has failed:

CANCELLED JOB - The following job was cancelled. Please retry:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

manuelcandales marked this pull request as ready for review April 14, 2026 22:24
"description": "Expert-indexed quantized matmul for MoE (metal::gather_qmv)",
"atol_float32": 5e-2,
"rtol_float32": 5e-2,
"atol_bfloat16": 5.0,
Contributor


Why is this tolerance so large?

Contributor Author


updated

@metascroy
Contributor

@claude review this code

@claude

claude Bot commented Apr 20, 2026

Claude finished @metascroy's task in 3m 54s.


Code Review: Metal gather_qmv kernel for MoE

  • Read all changed files
  • Review Metal shader + C++ host dispatch (op_gather_qmv.mm)
  • Review Python custom op definition (gather_qmv.py)
  • Review backend integration (metal_backend.py)
  • Review tests (test_modules.py)
  • Review build integration (CMakeLists.txt, __init__.py)
  • Post review feedback

Overall the implementation is clean and follows the existing op_linear_4bit.mm patterns well. The expert-index offset approach is straightforward, and the code is well-structured. Below are my findings:

Issues

1. E extracted but never validated against expert_indices — op_gather_qmv.mm:435

E is read from the weight tensor shape but never used for validation. If any value in expert_indices is >= E, the shader will access out-of-bounds memory. While the caller is expected to provide valid indices, adding a host-side bounds check (or at minimum an assertion) would be cheap and defensive, especially since this is a new kernel that may be called with user-controlled routing decisions.

int32_t E = static_cast<int32_t>(w_tensor->sizes()[0]);  // extracted but never checked

Consider adding validation of idx_tensor dimensions (should be 1D with size P) and optionally a debug-mode bounds check on the values.
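
A sketch of the suggested checks, written as Python-level validation for illustration (validate_expert_indices is a hypothetical helper; the real check would live in the C++ host dispatch):

import torch

def validate_expert_indices(expert_indices: torch.Tensor, E: int, P: int) -> None:
    # Shape check: expert_indices should be 1-D with one expert id per input row.
    if expert_indices.dim() != 1 or expert_indices.numel() != P:
        raise ValueError(
            f"expected expert_indices of shape [{P}], got {list(expert_indices.shape)}"
        )
    # Bounds check: an out-of-range id would make the shader read out of bounds.
    if not bool(((expert_indices >= 0) & (expert_indices < E)).all()):
        raise ValueError("expert_indices contains out-of-range expert ids")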

2. Missing validation of scales/biases tensor shapes — op_gather_qmv.mm:417-421

The code validates x (2D) and w (3D) dimensions but doesn't validate that scales and biases have the expected shape [E, N, K/gs], or that expert_indices has shape [P]. op_linear_4bit.mm doesn't validate these either, so this is consistent, but it would be good to add such checks for a new kernel.

3. Large test tolerance (responds to @metascroy's review comment) — test_modules.py:738-741

"atol_float32": 5e-2,
"rtol_float32": 5e-2,
"atol_bfloat16": 1e-1,
"rtol_bfloat16": 1e-1,

These tolerances are large because: (a) the test compares the quantized kernel output against the quantized reference CPU implementation (both go through quantize → dequantize → matmul), but the Metal shader uses the 4-bit bitmask/predivision trick (x/16, x/256, x/4096) which introduces floating-point rounding differences vs. the straightforward CPU dequant; and (b) the accumulation order differs (SIMD reduction across 32 lanes vs. sequential).

That said, these tolerances are identical to the existing linear_nobias_int4 / linear_bias_int4 tests (lines 245-248, 267-270), so they're consistent with the project. If the concern is that 5e-2 atol on float32 is too loose, a focused test with known weights (e.g., identity-like patterns) that computes the exact expected output analytically could provide tighter validation.
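
One way to realize that suggestion, sketched under the assumption that every int4 code is set to one constant and the scale is a power of two, so dequantization is exact and the expected output reduces to a scaled row sum (names hypothetical):

import torch

def expected_constant_weight_output(x, N, code=4, scale=0.5, bias=0.0):
    # If every quantized code equals `code` with a single scale/bias, each
    # dequantized weight is exactly scale * code + bias, so for every row n:
    # y[i, n] = (scale * code + bias) * sum_k x[i, k].
    w_val = scale * code + bias
    return (w_val * x.sum(dim=-1, keepdim=True)).expand(-1, N)

Comparing the kernel against an analytic oracle like this would support a much tighter tolerance than comparing two independently rounded quantized paths.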

Minor Observations

4. Code duplication of 4-bit quantization helpers

The load_vector_4bit, qdot_4bit, load_vector_safe_4bit, and qdot_safe_4bit functions in op_gather_qmv.mm are copy-pasted specializations of the generic load_vector, qdot, etc. from op_linear_4bit.mm. This is understandable since the gather kernel only needs 4-bit, but it does introduce two copies of the same logic to maintain. Consider whether these could be shared (e.g., via a common Metal header included by both shader sources). Not blocking — the duplication is small and self-contained.

5. __init__.py is empty (minus license) — ops/__init__.py

This file is just a license header with no imports. It exists to make ops/ a package, which is fine. Just noting it's intentionally empty.

6. Backend integration is cleanly guarded — metal_backend.py:82-89

The try/except ImportError around the gather_qmv c_shim import is a good pattern that avoids breaking the backend if the op module isn't available.

Correctness Verification

  • Expert stride computation (op_gather_qmv.mm:516-522): w_stride = N * K_packed (bytes for uint8 weight tensor), s_stride = N * K_g (elements for T-typed scales). Both are correct for contiguous [E, N, ...] layout.
  • Grid dispatch (op_gather_qmv.mm:538-544): (P, (N+7)/8, 1) with threadgroup (32, 2, 1) — matches the results_per_simdgroup=4 × num_simdgroups=2 = 8 rows per threadgroup, so (N+7)/8 threadgroups covers all N output rows. Correct; this arithmetic, together with the stride computation above, is restated in the sketch after this list.
  • Python reference impl (gather_qmv.py:28-42): Correctly iterates over P pairs, dequantizes per-expert weights, and computes W @ x. The dequant logic matches the Metal shader's affine format.
  • Meta (fake) impl (gather_qmv.py:92-103): Returns correct shape [P, N] with matching dtype.
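
The stride and grid arithmetic from the first two bullets, restated as a sketch (constants follow the review text above; the helper name is illustrative):

def gather_qmv_launch_params(N, K, P, group_size=32):
    K_packed = K // 2        # two int4 codes packed per uint8 byte
    K_g = K // group_size    # quantization groups per row of W
    w_stride = N * K_packed  # bytes between consecutive experts in w
    s_stride = N * K_g       # elements between experts in scales/biases
    # 2 simdgroups x 4 results each = 8 output rows per threadgroup, so
    # ceil(N / 8) threadgroups along rows and one per input pair along P.
    grid = (P, (N + 7) // 8, 1)
    threadgroup = (32, 2, 1)
    return w_stride, s_stride, grid, threadgroup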

Summary

The implementation is solid and follows established patterns. The main suggestions are adding input validation for expert_indices dimensions and bounds (items 1-2 above), and potentially tightening or documenting the test tolerance rationale. The Metal shader logic, expert stride computation, and Python op registration are all correct.


Base automatically changed from gh/manuelcandales/171/head to main April 20, 2026 19:55
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

manuelcandales merged commit 6875814 into main Apr 21, 2026
189 of 195 checks passed
manuelcandales deleted the gh/manuelcandales/172/head branch April 21, 2026 17:53