Skip to content

Commit 47cbe76

Browse files
Update
[ghstack-poisoned]
1 parent 1be53ab commit 47cbe76

6 files changed

Lines changed: 526 additions & 1 deletion

File tree

backends/apple/metal/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ set(_aoti_metal_sources
4545
runtime/ops/common.mm
4646
runtime/ops/op_bmm.mm
4747
runtime/ops/op_convolution.mm
48+
runtime/ops/op_gather_qmv.mm
4849
runtime/ops/op_linear_4bit.mm
4950
runtime/ops/op_mm.mm
5051
runtime/ops/op_sdpa.mm

backends/apple/metal/metal_backend.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
3737
"at::_ops::_scaled_dot_product_attention_math_for_mps::call": None,
3838
"torchao::_linear_fp_act_4bit_weight": None,
3939
"at::_ops::topk::call": None,
40+
"metal::gather_qmv": None,
4041
}
4142

4243
@classmethod
@@ -76,6 +77,17 @@ def get_aoti_compile_options(
7677

7778
from torchao.experimental.ops.mps.cshim import torchao_op_c_shim
7879

79-
inductor_configs["aot_inductor.custom_ops_to_c_shims"] = torchao_op_c_shim
80+
custom_c_shims = {**torchao_op_c_shim}
81+
82+
try:
83+
from executorch.backends.apple.metal.ops.gather_qmv import (
84+
metal_gather_qmv_c_shim,
85+
)
86+
87+
custom_c_shims.update(metal_gather_qmv_c_shim)
88+
except ImportError:
89+
pass
90+
91+
inductor_configs["aot_inductor.custom_ops_to_c_shims"] = custom_c_shims
8092

8193
return inductor_configs
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
metal::gather_qmv custom op for MoE expert-indexed quantized matmul.
9+
10+
Performs y[i] = W[expert_idx[i]] @ x[i] with INT4 quantized expert weights.
11+
The Metal fallback kernel is in runtime/ops/op_gather_qmv.mm.
12+
"""
13+
14+
import torch
15+
from torch import Tensor
16+
17+
18+
@torch.library.custom_op("metal::gather_qmv", mutates_args=())
def gather_qmv(
    x: Tensor,  # [P, K] — activations (P = num token-expert pairs)
    w: Tensor,  # [E, N, K_packed] — packed INT4 expert weights
    scales: Tensor,  # [E, N, K/gs] — per-group scales
    biases: Tensor,  # [E, N, K/gs] — per-group biases
    expert_indices: Tensor,  # [P] — expert index per pair
    group_size: int,
) -> Tensor:
    """Reference implementation for tracing and CPU testing.

    Computes ``y[i] = dequant(W[expert_indices[i]]) @ x[i]`` for every row of
    ``x``. Each distinct expert is dequantized once and applied to all rows
    that selected it, instead of being re-dequantized per row.

    Args:
        x: 2-D activations, one row per token-expert pair.
        w: 3-D packed expert weights; ``w[e]`` is handed to
            ``_dequantize_int4_affine``.
        scales: Per-group scales, indexed by expert on dim 0.
        biases: Per-group biases, indexed by expert on dim 0.
        expert_indices: One expert id per row of ``x``.
        group_size: Number of consecutive K elements sharing a scale/bias.

    Returns:
        A new ``[P, N]`` tensor with the dtype and device of ``x``.

    Raises:
        ValueError: If ``expert_indices`` does not hold exactly one entry per
            row of ``x``.
        IndexError: If any expert id is outside ``[0, E)``. Negative ids are
            rejected rather than wrapping around to the last experts.
    """
    P, K = x.shape
    E, N, _ = w.shape

    flat_indices = expert_indices.reshape(-1)
    if flat_indices.numel() != P:
        raise ValueError(
            f"expert_indices must have one entry per row of x: "
            f"got {flat_indices.numel()} indices for {P} rows"
        )

    y = torch.zeros(P, N, dtype=x.dtype, device=x.device)

    # torch.unique returns sorted values, so the first and last entries bound
    # the whole index range. An empty list (P == 0) skips both check and loop.
    experts = torch.unique(flat_indices).tolist()
    if experts and (experts[0] < 0 or experts[-1] >= E):
        raise IndexError(
            f"expert_indices must lie in [0, {E}); "
            f"got range [{experts[0]}, {experts[-1]}]"
        )

    for eidx in experts:
        # Dequantize: unpack INT4, apply affine dequant. The helper returns
        # scales.dtype, so cast to x.dtype to keep the matmul well-typed.
        w_deq = _dequantize_int4_affine(
            w[eidx], scales[eidx], biases[eidx], K, group_size
        ).to(x.dtype)
        rows = flat_indices == eidx
        # [n_rows, K] @ [K, N] is the batched form of w_deq @ x[i] per row.
        y[rows] = x[rows] @ w_deq.t()

    return y
43+
44+
45+
def _dequantize_int4_affine(
46+
w_packed: Tensor, scales: Tensor, biases: Tensor, K: int, group_size: int
47+
) -> Tensor:
48+
"""Dequantize packed INT4 weights using MLX affine format."""
49+
N = w_packed.shape[0]
50+
w_bytes = w_packed.to(torch.int16)
51+
low = w_bytes & 0x0F
52+
high = (w_bytes >> 4) & 0x0F
53+
w_int = torch.stack([low, high], dim=-1).reshape(N, K).float()
54+
55+
scales_expanded = scales.float().repeat_interleave(group_size, dim=-1)[:, :K]
56+
biases_expanded = biases.float().repeat_interleave(group_size, dim=-1)[:, :K]
57+
58+
return (w_int * scales_expanded + biases_expanded).to(scales.dtype)
59+
60+
61+
@torch.library.register_fake("metal::gather_qmv")
def gather_qmv_fake(
    x: Tensor,
    w: Tensor,
    scales: Tensor,
    biases: Tensor,
    expert_indices: Tensor,
    group_size: int,
) -> Tensor:
    """Fake (meta) kernel for ``metal::gather_qmv``.

    Only describes the output: an uninitialized tensor with one row per row
    of ``x`` and one column per entry of ``w``'s second dimension, sharing
    ``x``'s dtype and device. The remaining arguments are not inspected.
    """
    num_rows = x.shape[0]
    num_cols = w.shape[1]
    # new_empty inherits dtype and device from x.
    return x.new_empty((num_rows, num_cols))
73+
74+
75+
# AOTInductor C-shim table for the gather_qmv custom op.
# Key: the torch op overload. Value: the list of C declarations that the
# generated wrapper may call for that op. The C parameters appear in the same
# order as the Python op's arguments, followed by the output handle
# (Z presumably corresponds to `biases` — confirm against the .mm kernel).
metal_gather_qmv_c_shim = {
    torch.ops.metal.gather_qmv.default: [
        (
            "AOTITorchError aoti_torch_mps_gather_qmv("
            "AtenTensorHandle X, "
            "AtenTensorHandle W, "
            "AtenTensorHandle S, "
            "AtenTensorHandle Z, "
            "AtenTensorHandle ExpertIndices, "
            "int64_t group_size, "
            "AtenTensorHandle* ret)"
        ),
    ],
}

0 commit comments

Comments
 (0)