Skip to content

Commit 6875814

Browse files
Metal backend: Add gather_qmv kernel for MoE expert-indexed quantized matmul (#18877)
Adds gather_qmv Metal kernel for Mixture-of-Experts: performs per-expert quantized matrix-vector multiply y[i] = W[expert_idx[i]] @ x[i]. Extends the existing qmv kernels in op_linear_4bit.mm with expert index-based pointer offsets, following the same pattern as MLX's affine_gather_qmv_fast. Two dispatch paths (matching op_linear_4bit.mm): - gather_qmv_fast: optimized path for K%512==0 and N%8==0 - gather_qmv_impl: generic fallback for any K and N Uses the same affine INT4 dequantization format as op_linear_4bit.mm (scale * accum + sum * bias). Instantiated for 4-bit with group sizes {32, 64, 128} and dtypes {float, bfloat16}. Includes: Metal shader + C++ host dispatch, Python custom op definition (metal::gather_qmv) with reference CPU impl and Meta impl, C shim dict, fallback kernel registration, CMakeLists entry, and test module.
1 parent ccaf17e commit 6875814

6 files changed

Lines changed: 774 additions & 1 deletion

File tree

backends/apple/metal/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ set(_aoti_metal_sources
4545
runtime/ops/common.mm
4646
runtime/ops/op_bmm.mm
4747
runtime/ops/op_convolution.mm
48+
runtime/ops/op_gather_qmv.mm
4849
runtime/ops/op_linear_4bit.mm
4950
runtime/ops/op_mm.mm
5051
runtime/ops/op_sdpa.mm

backends/apple/metal/metal_backend.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
3737
"at::_ops::_scaled_dot_product_attention_math_for_mps::call": None,
3838
"torchao::_linear_fp_act_4bit_weight": None,
3939
"at::_ops::topk::call": None,
40+
"metal::gather_qmv": None,
4041
}
4142

4243
@classmethod
@@ -76,6 +77,17 @@ def get_aoti_compile_options(
7677

7778
from torchao.experimental.ops.mps.cshim import torchao_op_c_shim
7879

79-
inductor_configs["aot_inductor.custom_ops_to_c_shims"] = torchao_op_c_shim
80+
custom_c_shims = {**torchao_op_c_shim}
81+
82+
try:
83+
from executorch.backends.apple.metal.ops.gather_qmv import (
84+
metal_gather_qmv_c_shim,
85+
)
86+
87+
custom_c_shims.update(metal_gather_qmv_c_shim)
88+
except ImportError:
89+
pass
90+
91+
inductor_configs["aot_inductor.custom_ops_to_c_shims"] = custom_c_shims
8092

8193
return inductor_configs
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
metal::gather_qmv custom op for MoE expert-indexed quantized matmul.
9+
10+
Performs y[i] = W[expert_idx[i]] @ x[i] with INT4 quantized expert weights.
11+
The Metal fallback kernel is in runtime/ops/op_gather_qmv.mm.
12+
"""
13+
14+
import torch
15+
from torch import Tensor
16+
17+
18+
@torch.library.custom_op("metal::gather_qmv", mutates_args=())
def gather_qmv(
    x: Tensor,  # [P, K] — activations (P = num token-expert pairs)
    w: Tensor,  # [E, N, K_packed] — packed INT4 expert weights
    scales: Tensor,  # [E, N, K/gs] — per-group scales
    biases: Tensor,  # [E, N, K/gs] — per-group biases
    expert_indices: Tensor,  # [P] — expert index per pair
    group_size: int,
) -> Tensor:
    """Reference implementation for tracing and CPU testing.

    Computes y[i] = dequant(w[expert_indices[i]]) @ x[i] for each
    token-expert pair i. Intentionally a plain Python loop: the fast
    path runs in the Metal kernel (runtime/ops/op_gather_qmv.mm), so
    this body only needs to be correct, not fast.

    Args:
        x: [P, K] activations.
        w: [E, N, K//2] packed INT4 expert weights (two nibbles per byte).
        scales: [E, N, K//group_size] per-group dequant scales.
        biases: [E, N, K//group_size] per-group dequant biases.
        expert_indices: [P] integer expert index per pair.
        group_size: number of weight elements per quantization group.

    Returns:
        y: [P, N] output tensor in x's dtype on x's device.
    """
    P, K = x.shape
    N = w.shape[1]

    y = torch.zeros(P, N, dtype=x.dtype, device=x.device)
    for i in range(P):
        eidx = expert_indices[i].item()

        # Dequantize this pair's expert (unpack INT4, apply affine dequant),
        # then do a plain matvec.
        w_unpacked = _dequantize_int4_affine(
            w[eidx], scales[eidx], biases[eidx], K, group_size
        )
        # Cast so the matmul cannot fail on a scales-vs-activation dtype
        # mismatch (e.g. float32 scales with bfloat16 activations); no-op
        # when the dtypes already agree.
        y[i] = w_unpacked.to(x.dtype) @ x[i]

    return y
43+
44+
45+
def _quantize_int4_affine(w: Tensor, group_size: int) -> tuple[Tensor, Tensor, Tensor]:
    """Quantize float weights to packed INT4 using MLX affine format.

    Args:
        w: [..., K] float weight tensor (last dim is quantized).
        group_size: Number of elements per quantization group.

    Returns:
        (packed, scales, biases) where:
            - packed: [..., K//2] uint8, two INT4 values per byte.
            - scales: [..., K//group_size] per-group scales.
            - biases: [..., K//group_size] per-group biases (zero points).

    The affine mapping is: dequantized = raw_uint4 * scale + bias,
    where raw_uint4 is in [0, 15].
    """
    *batch_dims, k = w.shape
    grouped = w.reshape(*batch_dims, k // group_size, group_size)

    # Each group's [min, max] range maps onto the 16-level integer grid.
    lo = grouped.amin(dim=-1)
    hi = grouped.amax(dim=-1)
    scales = ((hi - lo) / 15.0).clamp(min=1e-8)  # guard against flat groups
    biases = lo

    # Normalize to [0, 15], round to the nearest level, flatten groups back.
    normalized = (grouped - lo.unsqueeze(-1)) / scales.unsqueeze(-1)
    quantized = normalized.round().clamp(0, 15).to(torch.uint8)
    quantized = quantized.reshape(*batch_dims, k)

    # Pack pairs of nibbles: even index -> low nibble, odd index -> high.
    packed = quantized[..., 0::2] | (quantized[..., 1::2] << 4)
    return packed, scales, biases
76+
77+
78+
def _dequantize_int4_affine(
    w_packed: Tensor, scales: Tensor, biases: Tensor, K: int, group_size: int
) -> Tensor:
    """Dequantize packed INT4 weights using MLX affine format."""
    num_rows = w_packed.shape[0]

    # Unpack two uint4 values per byte: low nibble first, then high nibble,
    # restoring the original element order along the K dimension.
    as_int = w_packed.to(torch.int16)
    nibbles = torch.stack([as_int & 0x0F, (as_int >> 4) & 0x0F], dim=-1)
    raw = nibbles.reshape(num_rows, K).float()

    # Broadcast each per-group scale/bias across its group_size columns.
    s = scales.float().repeat_interleave(group_size, dim=-1)[:, :K]
    b = biases.float().repeat_interleave(group_size, dim=-1)[:, :K]

    return (raw * s + b).to(scales.dtype)
92+
93+
94+
@torch.library.register_fake("metal::gather_qmv")
def gather_qmv_fake(
    x: Tensor,
    w: Tensor,
    scales: Tensor,
    biases: Tensor,
    expert_indices: Tensor,
    group_size: int,
) -> Tensor:
    """Meta/fake implementation: allocate the [P, N] output shape only.

    Output inherits x's dtype and device, matching the eager reference.
    """
    num_pairs = x.shape[0]
    out_features = w.shape[1]
    return x.new_empty((num_pairs, out_features))
106+
107+
108+
# C shim mapping for AOTInductor code generation.
# Maps the torch op to the C function name that the generated wrapper calls.
# The module docstring places the matching Metal fallback kernel in
# runtime/ops/op_gather_qmv.mm; this declaration string must stay in sync
# with the C++ shim compiled there.
# NOTE(review): presumably S = scales and Z = biases (zero points), matching
# the Python op's argument order (x, w, scales, biases, expert_indices,
# group_size) — verify against the C++ host-side dispatch.
metal_gather_qmv_c_shim = {
    torch.ops.metal.gather_qmv.default: [
        "AOTITorchError aoti_torch_mps_gather_qmv("
        "AtenTensorHandle X, AtenTensorHandle W, AtenTensorHandle S, "
        "AtenTensorHandle Z, AtenTensorHandle ExpertIndices, "
        "int64_t group_size, AtenTensorHandle* ret)"
    ],
}

0 commit comments

Comments
 (0)