3 changes: 3 additions & 0 deletions backends/apple/metal/CMakeLists.txt
@@ -45,9 +45,12 @@ set(_aoti_metal_sources
     runtime/ops/common.mm
     runtime/ops/op_bmm.mm
     runtime/ops/op_convolution.mm
+    runtime/ops/op_gather_qmv.mm
+    runtime/ops/op_gated_delta_rule.mm
     runtime/ops/op_linear_4bit.mm
     runtime/ops/op_mm.mm
     runtime/ops/op_sdpa.mm
+    runtime/ops/op_topk.mm
 )

 add_library(metal_backend STATIC ${_aoti_metal_sources})
25 changes: 24 additions & 1 deletion backends/apple/metal/metal_backend.py
@@ -36,6 +36,9 @@ def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
"aoti_torch_mps_mm_out": None,
"at::_ops::_scaled_dot_product_attention_math_for_mps::call": None,
"torchao::_linear_fp_act_4bit_weight": None,
"at::_ops::topk::call": None,
"metal::gather_qmv": None,
"metal::gated_delta_rule": None,
}

     @classmethod
@@ -75,6 +78,26 @@ def get_aoti_compile_options(

         from torchao.experimental.ops.mps.cshim import torchao_op_c_shim

-        inductor_configs["aot_inductor.custom_ops_to_c_shims"] = torchao_op_c_shim
+        custom_c_shims = {**torchao_op_c_shim}
+
+        try:
+            from executorch.backends.apple.metal.ops.gather_qmv import (
+                metal_gather_qmv_c_shim,
+            )
+
+            custom_c_shims.update(metal_gather_qmv_c_shim)
+        except ImportError:
+            pass
+
+        try:
+            from executorch.backends.apple.metal.ops.gated_delta_rule import (
+                metal_gated_delta_rule_c_shim,
+            )
+
+            custom_c_shims.update(metal_gated_delta_rule_c_shim)
+        except ImportError:
+            pass
+
+        inductor_configs["aot_inductor.custom_ops_to_c_shims"] = custom_c_shims

         return inductor_configs
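
For reference, when both optional imports succeed, the merged table that get_aoti_compile_options hands to AOTInductor is equivalent to the following sketch (illustrative, outside the diff; it assumes the torchao MPS ops and the two new op modules are importable):

import torch
from torchao.experimental.ops.mps.cshim import torchao_op_c_shim
from executorch.backends.apple.metal.ops.gather_qmv import metal_gather_qmv_c_shim
from executorch.backends.apple.metal.ops.gated_delta_rule import metal_gated_delta_rule_c_shim

# Same result as the incremental update() calls above.
custom_c_shims = {
    **torchao_op_c_shim,
    **metal_gather_qmv_c_shim,        # adds torch.ops.metal.gather_qmv.default
    **metal_gated_delta_rule_c_shim,  # adds torch.ops.metal.gated_delta_rule.default
}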
5 changes: 5 additions & 0 deletions backends/apple/metal/ops/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
92 changes: 92 additions & 0 deletions backends/apple/metal/ops/gated_delta_rule.py
@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
metal::gated_delta_rule custom op for linear attention recurrence.

Performs the gated delta rule recurrence over T time steps, mutating
the recurrent state in-place. The Metal fallback kernel is in
runtime/ops/op_gated_delta_rule.mm.
"""

import torch
from torch import Tensor


@torch.library.custom_op("metal::gated_delta_rule", mutates_args=("state",))
def gated_delta_rule(
    q: Tensor,  # [B, T, Hk, Dk]
    k: Tensor,  # [B, T, Hk, Dk]
    v: Tensor,  # [B, T, Hv, Dv]
    g: Tensor,  # [B, T, Hv] — decay gate (already exp'd)
    beta: Tensor,  # [B, T, Hv] — update gate
    state: Tensor,  # [B, Hv, Dv, Dk] — recurrent state (MUTATED)
) -> Tensor:
    """Reference implementation: sequential recurrence over T."""
    B, T_len, Hk, Dk = q.shape
    Hv, Dv = v.shape[-2:]

    s = state.clone().float()
    ys = []

    assert Hv % Hk == 0, f"Hv ({Hv}) must be divisible by Hk ({Hk})"
    hk_repeat = Hv // Hk

    for t in range(T_len):
        q_t = q[:, t].float()  # [B, Hk, Dk]
        k_t = k[:, t].float()  # [B, Hk, Dk]
        v_t = v[:, t].float()  # [B, Hv, Dv]
        g_t = g[:, t].float()  # [B, Hv]
        beta_t = beta[:, t].float()  # [B, Hv]

        # Expand keys to match value heads (GQA: Hk -> Hv)
        if hk_repeat > 1:
            q_t = q_t.repeat_interleave(hk_repeat, dim=1)  # [B, Hv, Dk]
            k_t = k_t.repeat_interleave(hk_repeat, dim=1)  # [B, Hv, Dk]

        # Decay
        s = s * g_t[:, :, None, None]

        # Project state by key
        kv_mem = (s * k_t[:, :, None, :]).sum(dim=-1)  # [B, Hv, Dv]

        # Delta rule update
        delta = (v_t - kv_mem) * beta_t[:, :, None]  # [B, Hv, Dv]
        s = s + k_t[:, :, None, :] * delta[:, :, :, None]  # [B, Hv, Dv, Dk]

        # Read from state
        y_t = (s * q_t[:, :, None, :]).sum(dim=-1)  # [B, Hv, Dv]
        ys.append(y_t)

    state.copy_(s.to(state.dtype))
    return torch.stack(ys, dim=1).to(q.dtype)


@torch.library.register_fake("metal::gated_delta_rule")
def gated_delta_rule_fake(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    g: Tensor,
    beta: Tensor,
    state: Tensor,
) -> Tensor:
    B, T = q.shape[:2]
    Hv, Dv = v.shape[-2:]
    return torch.empty(B, T, Hv, Dv, dtype=q.dtype, device=q.device)


# C shim mapping for AOTInductor code generation.
# The op mutates state in-place and returns one tensor (y). AOTInductor's
# auto_functionalized wrapper passes 6 input handles + 1 output pointer.
metal_gated_delta_rule_c_shim = {
    torch.ops.metal.gated_delta_rule.default: [
        "AOTITorchError aoti_torch_mps_gated_delta_rule("
        "AtenTensorHandle Q, AtenTensorHandle K, AtenTensorHandle V, "
        "AtenTensorHandle G, AtenTensorHandle Beta, AtenTensorHandle StateIn, "
        "AtenTensorHandle* retY)"
    ],
}
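
A minimal eager smoke test for the new op (illustrative shapes, outside the diff; importing the module registers metal::gated_delta_rule):

import torch
from executorch.backends.apple.metal.ops import gated_delta_rule  # noqa: F401

B, T, Hk, Hv, Dk, Dv = 1, 4, 2, 4, 8, 16
q = torch.randn(B, T, Hk, Dk)
k = torch.randn(B, T, Hk, Dk)
v = torch.randn(B, T, Hv, Dv)
g = torch.rand(B, T, Hv)     # decay gate in (0, 1), already exp'd
beta = torch.rand(B, T, Hv)  # update gate
state = torch.zeros(B, Hv, Dv, Dk)

y = torch.ops.metal.gated_delta_rule(q, k, v, g, beta, state)
assert y.shape == (B, T, Hv, Dv)
assert state.abs().sum() > 0  # the recurrent state was mutated in place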
115 changes: 115 additions & 0 deletions backends/apple/metal/ops/gather_qmv.py
@@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
metal::gather_qmv custom op for MoE expert-indexed quantized matmul.

Performs y[i] = W[expert_idx[i]] @ x[i] with INT4 quantized expert weights.
The Metal fallback kernel is in runtime/ops/op_gather_qmv.mm.
"""

import torch
from torch import Tensor


@torch.library.custom_op("metal::gather_qmv", mutates_args=())
def gather_qmv(
    x: Tensor,  # [P, K] — activations (P = num token-expert pairs)
    w: Tensor,  # [E, N, K_packed] — packed INT4 expert weights
    scales: Tensor,  # [E, N, K/gs] — per-group scales
    biases: Tensor,  # [E, N, K/gs] — per-group biases
    expert_indices: Tensor,  # [P] — expert index per pair
    group_size: int,
) -> Tensor:
    """Reference implementation for tracing and CPU testing."""
    P, K = x.shape
    E, N, K_packed = w.shape

    y = torch.zeros(P, N, dtype=x.dtype, device=x.device)
    for i in range(P):
        eidx = expert_indices[i].item()
        w_e = w[eidx]  # [N, K_packed]
        s_e = scales[eidx]  # [N, K/gs]
        b_e = biases[eidx]  # [N, K/gs]

        # Dequantize: unpack INT4, apply affine dequant
        w_unpacked = _dequantize_int4_affine(w_e, s_e, b_e, K, group_size)
        y[i] = w_unpacked @ x[i]

    return y


def _quantize_int4_affine(
    w: Tensor, group_size: int
) -> tuple[Tensor, Tensor, Tensor]:
    """Quantize float weights to packed INT4 using MLX affine format.

    Args:
        w: [..., K] float weight tensor (last dim is quantized).
        group_size: Number of elements per quantization group.

    Returns:
        (packed, scales, biases) where:
        - packed: [..., K//2] uint8, two INT4 values per byte.
        - scales: [..., K//group_size] per-group scales.
        - biases: [..., K//group_size] per-group biases (zero points).

    The affine mapping is: dequantized = raw_uint4 * scale + bias,
    where raw_uint4 is in [0, 15].
    """
    *leading, K = w.shape
    w_groups = w.reshape(*leading, K // group_size, group_size)
    g_min = w_groups.amin(dim=-1)
    g_max = w_groups.amax(dim=-1)
    scales = ((g_max - g_min) / 15.0).clamp(min=1e-8)
    biases = g_min
    w_int = (
        (w_groups - biases.unsqueeze(-1)) / scales.unsqueeze(-1)
    ).round().clamp(0, 15).to(torch.uint8).reshape(*leading, K)
    packed = w_int[..., 0::2] | (w_int[..., 1::2] << 4)
    return packed, scales, biases


def _dequantize_int4_affine(
    w_packed: Tensor, scales: Tensor, biases: Tensor, K: int, group_size: int
) -> Tensor:
    """Dequantize packed INT4 weights using MLX affine format."""
    N = w_packed.shape[0]
    w_bytes = w_packed.to(torch.int16)
    low = w_bytes & 0x0F
    high = (w_bytes >> 4) & 0x0F
    w_int = torch.stack([low, high], dim=-1).reshape(N, K).float()

    scales_expanded = scales.float().repeat_interleave(group_size, dim=-1)[:, :K]
    biases_expanded = biases.float().repeat_interleave(group_size, dim=-1)[:, :K]

    return (w_int * scales_expanded + biases_expanded).to(scales.dtype)


@torch.library.register_fake("metal::gather_qmv")
def gather_qmv_fake(
    x: Tensor,
    w: Tensor,
    scales: Tensor,
    biases: Tensor,
    expert_indices: Tensor,
    group_size: int,
) -> Tensor:
    P = x.shape[0]
    N = w.shape[1]
    return torch.empty(P, N, dtype=x.dtype, device=x.device)


# C shim mapping for AOTInductor code generation.
# Maps the torch op to the C function name that the generated wrapper calls.
metal_gather_qmv_c_shim = {
    torch.ops.metal.gather_qmv.default: [
        "AOTITorchError aoti_torch_mps_gather_qmv("
        "AtenTensorHandle X, AtenTensorHandle W, AtenTensorHandle S, "
        "AtenTensorHandle Z, AtenTensorHandle ExpertIndices, "
        "int64_t group_size, AtenTensorHandle* ret)"
    ],
}
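
Likewise, a round-trip check for the new op (illustrative sizes, outside the diff): quantize random expert weights with the module's own helper, run the reference op, and compare against the float matmul. The gap is bounded by INT4 quantization error:

import torch
from executorch.backends.apple.metal.ops.gather_qmv import _quantize_int4_affine

P, E, N, K, gs = 3, 4, 16, 64, 32  # group_size must divide K; K must be even
x = torch.randn(P, K)
w_float = torch.randn(E, N, K)
packed, scales, biases = _quantize_int4_affine(w_float, gs)
expert_indices = torch.randint(0, E, (P,))

y = torch.ops.metal.gather_qmv(x, packed, scales, biases, expert_indices, gs)
y_ref = torch.stack([w_float[expert_indices[i]] @ x[i] for i in range(P)])
print((y - y_ref).abs().max())  # small; limited by INT4 precision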