Commit d408a10
Metal backend: Add gated delta rule kernel for linear attention (#18878)
Adds a Metal kernel for the gated delta rule recurrence used by Qwen 3.5 MoE's GatedDeltaNet linear attention layers, ported from the Metal shader in the MLX delegate PR (#18785).

The kernel processes the full sequence sequentially within a single GPU dispatch, keeping the recurrent state in per-thread registers. Grid: [32, Dv, B*Hv]; threadgroup: [32, 4, 1]. Each simdgroup of 32 threads handles Dk/32 elements of the key dimension, using SIMD reduction for the dot products. The op mutates the recurrent state buffer in-place (mutates_args). The kernel is instantiated for both the real model dimensions (Dk=128, Dv=128, Hk=32, Hv=32) and the tiny test dimensions (Dk=64, Dv=64, Hk=4, Hv=4).

Includes: the Metal shader and C++ host dispatch, the Python custom op definition (metal::gated_delta_rule) with a reference CPU implementation and a meta (fake) implementation, the C shim dict, fallback kernel registration, a CMakeLists entry, and a test module.
1 parent 2fce946 commit d408a10
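For orientation, here is a minimal per-head sketch (not part of the commit) of the recurrence step the kernel evaluates. It mirrors the reference CPU implementation in ops/gated_delta_rule.py below; the function name `gated_delta_rule_step` and the per-head shapes are introduced here only for illustration.

```python
import torch

def gated_delta_rule_step(S, q_t, k_t, v_t, g_t, beta_t):
    """One time step for a single (batch, value-head) pair.

    S: [Dv, Dk] recurrent state; q_t, k_t: [Dk]; v_t: [Dv];
    g_t, beta_t: scalar gates for this step.
    """
    S = g_t * S                      # decay the state
    kv_mem = S @ k_t                 # read the memory currently stored for this key: [Dv]
    delta = beta_t * (v_t - kv_mem)  # delta rule correction toward the new value: [Dv]
    S = S + torch.outer(delta, k_t)  # write the correction back into the state: [Dv, Dk]
    y_t = S @ q_t                    # read the output for the query: [Dv]
    return S, y_t
```

The Metal kernel evaluates this update for every (batch, value-head) pair, looping over the time dimension inside one dispatch and keeping S in registers, with each simdgroup reducing the Dk-dimensional dot products.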

5 files changed

Lines changed: 475 additions & 0 deletions


backends/apple/metal/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ set(_aoti_metal_sources
     runtime/ops/op_bmm.mm
     runtime/ops/op_convolution.mm
     runtime/ops/op_gather_qmv.mm
+    runtime/ops/op_gated_delta_rule.mm
     runtime/ops/op_linear_4bit.mm
     runtime/ops/op_mm.mm
     runtime/ops/op_sdpa.mm

backends/apple/metal/metal_backend.py

Lines changed: 10 additions & 0 deletions
@@ -38,6 +38,7 @@ def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
             "torchao::_linear_fp_act_4bit_weight": None,
             "at::_ops::topk::call": None,
             "metal::gather_qmv": None,
+            "metal::gated_delta_rule": None,
         }
 
     @classmethod
@@ -88,6 +89,15 @@ def get_aoti_compile_options(
         except ImportError:
             pass
 
+        try:
+            from executorch.backends.apple.metal.ops.gated_delta_rule import (
+                metal_gated_delta_rule_c_shim,
+            )
+
+            custom_c_shims.update(metal_gated_delta_rule_c_shim)
+        except ImportError:
+            pass
+
         inductor_configs["aot_inductor.custom_ops_to_c_shims"] = custom_c_shims
 
         return inductor_configs
backends/apple/metal/ops/gated_delta_rule.py

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+metal::gated_delta_rule custom op for linear attention recurrence.
+
+Performs the gated delta rule recurrence over T time steps, mutating
+the recurrent state in-place. The Metal fallback kernel is in
+runtime/ops/op_gated_delta_rule.mm.
+"""
+
+import torch
+from torch import Tensor
+
+
+@torch.library.custom_op("metal::gated_delta_rule", mutates_args=("state",))
+def gated_delta_rule(
+    q: Tensor,  # [B, T, Hk, Dk]
+    k: Tensor,  # [B, T, Hk, Dk]
+    v: Tensor,  # [B, T, Hv, Dv]
+    g: Tensor,  # [B, T, Hv] — decay gate (already exp'd)
+    beta: Tensor,  # [B, T, Hv] — update gate
+    state: Tensor,  # [B, Hv, Dv, Dk] — recurrent state (MUTATED)
+) -> Tensor:
+    """Reference implementation: sequential recurrence over T."""
+    B, T_len, Hk, Dk = q.shape
+    Hv, Dv = v.shape[-2:]
+
+    s = state.clone().float()
+    ys = []
+
+    assert Hv % Hk == 0, f"Hv ({Hv}) must be divisible by Hk ({Hk})"
+    hk_repeat = Hv // Hk
+
+    for t in range(T_len):
+        q_t = q[:, t].float()  # [B, Hk, Dk]
+        k_t = k[:, t].float()  # [B, Hk, Dk]
+        v_t = v[:, t].float()  # [B, Hv, Dv]
+        g_t = g[:, t].float()  # [B, Hv]
+        beta_t = beta[:, t].float()  # [B, Hv]
+
+        # Expand keys to match value heads (GQA: Hk -> Hv)
+        if hk_repeat > 1:
+            q_t = q_t.repeat_interleave(hk_repeat, dim=1)  # [B, Hv, Dk]
+            k_t = k_t.repeat_interleave(hk_repeat, dim=1)  # [B, Hv, Dk]
+
+        # Decay
+        s = s * g_t[:, :, None, None]
+
+        # Project state by key
+        kv_mem = (s * k_t[:, :, None, :]).sum(dim=-1)  # [B, Hv, Dv]
+
+        # Delta rule update
+        delta = (v_t - kv_mem) * beta_t[:, :, None]  # [B, Hv, Dv]
+        s = s + k_t[:, :, None, :] * delta[:, :, :, None]  # [B, Hv, Dv, Dk]
+
+        # Read from state
+        y_t = (s * q_t[:, :, None, :]).sum(dim=-1)  # [B, Hv, Dv]
+        ys.append(y_t)
+
+    state.copy_(s.to(state.dtype))
+    return torch.stack(ys, dim=1).to(q.dtype)
+
+
+@torch.library.register_fake("metal::gated_delta_rule")
+def gated_delta_rule_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    g: Tensor,
+    beta: Tensor,
+    state: Tensor,
+) -> Tensor:
+    B, T = q.shape[:2]
+    Hv, Dv = v.shape[-2:]
+    return torch.empty(B, T, Hv, Dv, dtype=q.dtype, device=q.device)
+
+
+# C shim mapping for AOTInductor code generation.
+# The op mutates state in-place and returns one tensor (y). AOTInductor's
+# auto_functionalized wrapper passes 6 input handles + 1 output pointer.
+metal_gated_delta_rule_c_shim = {
+    torch.ops.metal.gated_delta_rule.default: [
+        "AOTITorchError aoti_torch_mps_gated_delta_rule("
+        "AtenTensorHandle Q, AtenTensorHandle K, AtenTensorHandle V, "
+        "AtenTensorHandle G, AtenTensorHandle Beta, AtenTensorHandle StateIn, "
+        "AtenTensorHandle* retY)"
+    ],
+}
