Skip to content

Commit 9600f63

Browse files
Qwen 3.5 MoE: Add Metal source transformations (#18879)
Adds metal_source_transformations.py with module replacements for Metal:
- FusedMoEExperts -> MetalMoEExperts (two metal::gather_qmv calls with SiLU gating, replacing torch.ops.triton.fused_moe)
- GatedDeltaNet -> metal::gated_delta_rule custom op (replaces both the T=1 native path and the T>1 Triton kernel)
- FullAttention -> removes the turboquant codepath, keeps standard SDPA
- SparseMoE -> removes the .float() cast on expert_weights

Also includes quantize_experts_metal(), which quantizes expert weights to the MLX affine INT4 format (unsigned uint4 with a scale and a bias per group), compatible with the Metal gather_qmv kernel.
1 parent d408a10 commit 9600f63

1 file changed

Lines changed: 336 additions & 0 deletions

File tree

Lines changed: 336 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,336 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
"""
9+
Metal source transformations for Qwen 3.5 MoE.
10+
11+
Replaces Triton-dependent modules (FusedMoEExperts, GatedDeltaNet) with
12+
pure-PyTorch + Metal custom op equivalents that can be exported and lowered
13+
to the Metal backend via AOTInductor.
14+
"""
15+
16+
import logging
17+
import types
18+
19+
import torch
20+
import torch.nn as nn
21+
import torch.nn.functional as F
22+
23+
from executorch.examples.models.qwen3_5_moe.model import (
24+
FullAttention,
25+
FusedMoEExperts,
26+
GatedDeltaNet,
27+
SparseMoE,
28+
)
29+
30+
logger = logging.getLogger(__name__)
31+
32+
33+
# ---------------------------------------------------------------------------
34+
# MetalMoEExperts: replaces FusedMoEExperts
35+
# ---------------------------------------------------------------------------
36+
37+
38+
class MetalMoEExperts(nn.Module):
    """Mixture-of-experts feed-forward built on ``metal::gather_qmv``.

    The fused MoE computation is expressed as two expert-indexed quantized
    matmuls (a combined gate+up projection, then a down projection) with SiLU
    gating applied between them. Expert weights are in MLX affine INT4 format.

    ``forward`` reads six tensors this class does not create itself:
    ``w1``/``s1``/``b1`` for the first matmul and ``w2``/``s2``/``b2`` for the
    second. They have to be attached to the instance before it is called.
    """

    def __init__(self, num_experts, intermediate_size, hidden_size, group_size=32):
        super().__init__()
        # Plain configuration only; no parameters or buffers are allocated here.
        self.num_experts = num_experts
        self.intermediate_size = intermediate_size
        self.hidden_size = hidden_size
        self.group_size = group_size

    def forward(self, x, expert_weights, expert_indices, top_k):
        """Run the selected experts on every token and mix their outputs.

        Args:
            x: token activations; the first dimension is the token count P.
            expert_weights: per-token mixing weight for each selected expert.
            expert_indices: per-token ids of the selected experts, [P, top_k].
            top_k: number of experts selected per token.

        Returns:
            The sum over the ``top_k`` expert outputs of each token, weighted
            by ``expert_weights``.
        """
        num_tokens = x.shape[0]
        inter = self.intermediate_size

        # One expert id per (token, slot) pair: [P, top_k] -> [P*top_k]
        flat_expert_ids = expert_indices.reshape(-1).to(torch.int32)
        # Repeat each token once per selected expert so rows line up with ids.
        tokens_per_pair = (
            x.unsqueeze(1).expand(-1, top_k, -1).reshape(num_tokens * top_k, -1)
        )

        # GEMM1: gate+up projection
        # [P*top_k, K] @ [E, 2*inter, K].T -> [P*top_k, 2*inter]
        fused_gate_up = torch.ops.metal.gather_qmv(
            tokens_per_pair,
            self.w1,
            self.s1,
            self.b1,
            flat_expert_ids,
            self.group_size,
        )
        gate_half = fused_gate_up[..., :inter]
        up_half = fused_gate_up[..., inter:]
        hidden = F.silu(gate_half) * up_half

        # GEMM2: down projection [P*top_k, inter] @ [E, K, inter].T -> [P*top_k, K]
        projected = torch.ops.metal.gather_qmv(
            hidden,
            self.w2,
            self.s2,
            self.b2,
            flat_expert_ids,
            self.group_size,
        )

        # Regroup rows by token, then take the weighted sum over its experts.
        per_token = projected.view(num_tokens, top_k, -1)
        weighted = per_token * expert_weights.unsqueeze(-1)
        return weighted.sum(dim=1)
74+
75+
76+
# ---------------------------------------------------------------------------
77+
# GatedDeltaNet replacement forward
78+
# ---------------------------------------------------------------------------
79+
80+
81+
def _metal_gated_delta_net_forward(self, x, input_pos):
    """GatedDeltaNet forward routed through ``metal::gated_delta_rule``.

    Pre- and post-processing follow the original module; the recurrence
    itself (both the T=1 native path and the T>1 Triton kernel) is replaced
    by one custom-op call that handles every sequence length.

    Args:
        self: the GatedDeltaNet instance this function is bound to.
        x: input activations; unpacked as (B, T, _).
        input_pos: positions of the tokens in ``x``; only ``input_pos[0]`` is
            inspected here, to decide whether the cached state is cleared.

    Returns:
        ``self.out_proj`` applied to the gated, normalized output.

    Side effects:
        ``self.conv_state[:B]`` is updated in place. ``self.recurrent_state[:B]``
        is scaled in place and then handed to the custom op.
    """
    batch, seq_len, _ = x.size()

    # Reset state at position 0: multiplying by 0 clears it, by 1 keeps it.
    is_start = (input_pos[0] == 0).to(self.conv_state.dtype)
    keep_factor = 1.0 - is_start
    self.conv_state[:batch].mul_(keep_factor)
    self.recurrent_state[:batch].mul_(keep_factor)

    # Fused projection laid out as [qkv | z | b | a] along the last axis.
    fused = self.in_proj(x)
    qkv_end = self.conv_dim
    z_end = qkv_end + self.value_dim
    b_end = z_end + self.num_v_heads
    mixed_qkv = fused[..., :qkv_end]
    z = fused[..., qkv_end:z_end].reshape(
        batch, seq_len, self.num_v_heads, self.head_v_dim
    )
    b = fused[..., z_end:b_end]
    a = fused[..., b_end:]

    # Causal depthwise conv1d with state: prepend the cached tail, then keep
    # the newest ``conv_kernel_size`` columns as the next cached tail.
    kernel = self.conv_kernel_size
    history = torch.cat([self.conv_state[:batch], mixed_qkv.transpose(1, 2)], dim=-1)
    total_len = history.shape[-1]
    self.conv_state[:batch].copy_(history[:, :, total_len - kernel :])

    # Manual depthwise conv1d (avoids conv1d->conv2d decomposition),
    # accumulated in float32 one kernel tap at a time.
    taps = self.conv1d.weight.squeeze(1).float()
    out_len = history.shape[-1] - kernel + 1
    conv_acc = torch.zeros(
        batch, history.shape[1], out_len, dtype=torch.float32, device=history.device
    )
    for tap in range(kernel):
        window = history[:, :, tap : tap + out_len].float()
        conv_acc = conv_acc + window * taps[:, tap : tap + 1]
    # Only the last T outputs belong to the new tokens.
    qkv_conv = F.silu(conv_acc[:, :, -seq_len:]).to(history.dtype).transpose(1, 2)

    # Split into Q, K, V
    key_dim = self.key_dim
    q = qkv_conv[..., :key_dim].reshape(
        batch, seq_len, self.num_k_heads, self.head_k_dim
    )
    k = qkv_conv[..., key_dim : 2 * key_dim].reshape(
        batch, seq_len, self.num_k_heads, self.head_k_dim
    )
    v = qkv_conv[..., 2 * key_dim :].reshape(
        batch, seq_len, self.num_v_heads, self.head_v_dim
    )

    # L2-normalize Q and K
    q = F.normalize(q, p=2, dim=-1)
    k = F.normalize(k, p=2, dim=-1)

    # Repeat Q/K heads when there are fewer key heads than value heads.
    if self.head_repeat > 1:
        q = q.repeat_interleave(self.head_repeat, dim=2)
        k = k.repeat_interleave(self.head_repeat, dim=2)

    # Mamba-style gating: g = exp(-A * softplus(a + dt_bias))
    beta = b.sigmoid()
    decay_rate = self.A_log.float().exp()
    step = F.softplus(a.float() + self.dt_bias)
    g = (-decay_rate * step).exp()

    # Imported for its side effect only (hence the discarded name).
    import executorch.backends.apple.metal.ops.gated_delta_rule as _  # noqa: F401

    # Metal custom op: handles both T=1 and T>1
    output = torch.ops.metal.gated_delta_rule(
        q, k, v, g, beta, self.recurrent_state[:batch]
    )

    # Output: RMSNorm(output) * silu(z), applied per head on flattened rows.
    flat_out = output.reshape(-1, self.head_v_dim)
    flat_z = z.reshape(-1, self.head_v_dim)
    normed = self.norm(flat_out, flat_z)
    return self.out_proj(normed.reshape(batch, seq_len, -1))
155+
156+
157+
# ---------------------------------------------------------------------------
158+
# FullAttention: remove turboquant
159+
# ---------------------------------------------------------------------------
160+
161+
162+
def _metal_full_attention_forward(self, x, input_pos):
163+
"""FullAttention forward without turboquant (CUDA-only)."""
164+
B, T, _ = x.size()
165+
dtype = x.dtype
166+
167+
qkv = self.qkv_proj(x)
168+
q_and_gate = qkv[..., : self.q_dim].view(B, T, self.n_heads, self.head_dim * 2)
169+
q = q_and_gate[..., : self.head_dim]
170+
gate = q_and_gate[..., self.head_dim :]
171+
172+
k = qkv[..., self.q_dim : self.q_dim + self.k_dim].view(
173+
B, T, self.n_kv_heads, self.head_dim
174+
)
175+
v = qkv[..., self.q_dim + self.k_dim :].view(B, T, self.n_kv_heads, self.head_dim)
176+
177+
q = self.q_norm(q)
178+
k = self.k_norm(k)
179+
180+
q, k = self.rotary_emb(input_pos, q, k)
181+
182+
q = q.to(dtype).transpose(1, 2)
183+
k = k.to(dtype).transpose(1, 2)
184+
v = v.transpose(1, 2)
185+
186+
attn_mask = (
187+
(self.cache_positions[None, :] <= input_pos[:, None]).unsqueeze(0).unsqueeze(0)
188+
)
189+
190+
# Always use standard SDPA (no turboquant on Metal)
191+
k, v = self.kv_cache.update(input_pos, k, v)
192+
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=True)
193+
194+
y = y.transpose(1, 2).contiguous().view(B, T, -1)
195+
196+
gate = gate.reshape(B, T, -1)
197+
y = y * torch.sigmoid(gate)
198+
199+
return self.o_proj(y)
200+
201+
202+
# ---------------------------------------------------------------------------
203+
# Expert weight quantization (MLX affine INT4 format)
204+
# ---------------------------------------------------------------------------
205+
206+
207+
def quantize_experts_metal(model, config, group_size=32):
    """Quantize expert weights to MLX affine INT4 format for metal::gather_qmv.

    Every ``layer.mlp.experts`` in ``model.layers`` that is a
    ``FusedMoEExperts`` is replaced in place by a ``MetalMoEExperts`` carrying
    the quantized tensors as buffers: ``w1``/``s1``/``b1`` derived from
    ``w1_weight`` and ``w2``/``s2``/``b2`` derived from ``w2_weight``.

    Produces unsigned INT4 with scale + bias (zero-point) per group:
        dequant(w) = w_uint4 * scale + bias

    Output layout per expert:
        w:      [N, K//2] uint8 (two 4-bit values packed per byte)
        scales: [N, K//group_size] same dtype as model
        biases: [N, K//group_size] same dtype as model

    Args:
        model: module whose ``layers`` are scanned and modified in place.
        config: only ``config.num_hidden_layers`` is read, for progress output.
        group_size: number of consecutive values along K sharing one
            scale/bias pair.

    Raises:
        ValueError: if K is odd or not a multiple of ``group_size`` (such a
            weight can be neither grouped nor nibble-packed).
    """
    from torchao.quantization.quant_primitives import (
        choose_qparams_affine,
        MappingType,
        quantize_affine,
    )

    # One quantization group = ``group_size`` values along K of a single row.
    block_size = (1, 1, group_size)

    for i, layer in enumerate(model.layers):
        experts = layer.mlp.experts
        if not isinstance(experts, FusedMoEExperts):
            continue

        metal_experts = MetalMoEExperts(
            experts.num_experts,
            experts.intermediate_size,
            experts.hidden_size,
            group_size,
        )

        for src_name, suffix in (("w1_weight", "1"), ("w2_weight", "2")):
            param = getattr(experts, src_name)
            # Remember the model's dtype *before* up-casting: scales and
            # biases are stored in it, while the math below runs in float32.
            out_dtype = param.dtype if param.dtype.is_floating_point else torch.float32
            w = param.data.float()
            E, N, K = w.shape
            if K % group_size != 0 or K % 2 != 0:
                raise ValueError(
                    f"{src_name}: K={K} must be even and a multiple of "
                    f"group_size={group_size}"
                )

            scale, zero_point = choose_qparams_affine(
                w,
                MappingType.ASYMMETRIC,
                block_size,
                target_dtype=torch.uint8,
                quant_min=0,
                quant_max=15,
            )

            int_data = quantize_affine(
                w,
                block_size,
                scale,
                zero_point,
                output_dtype=torch.uint8,
                quant_min=0,
                quant_max=15,
            )

            # Pack two uint4 values per byte: even -> low nibble, odd -> high nibble
            low = int_data[:, :, 0::2]
            high = int_data[:, :, 1::2]
            packed = (low | (high << 4)).to(torch.uint8)

            scale = scale.reshape(E, N, -1).float()
            # (q - zp) * scale == q * scale + bias  with  bias = -zp * scale
            bias = -zero_point.reshape(E, N, -1).float() * scale

            metal_experts.register_buffer(f"w{suffix}", packed)
            metal_experts.register_buffer(f"s{suffix}", scale.to(out_dtype))
            metal_experts.register_buffer(f"b{suffix}", bias.to(out_dtype))

        # Replace in model
        layer.mlp.experts = metal_experts
        print(
            f"  Quantized experts (Metal INT4) layer {i + 1}/{config.num_hidden_layers}",
            end="\r",
        )
    # Terminate the carriage-return progress line.
    print()
285+
286+
287+
# ---------------------------------------------------------------------------
288+
# Top-level transformation
289+
# ---------------------------------------------------------------------------
290+
291+
292+
def metal_source_transformations(model, config=None):
    """Replace all Triton-dependent modules with Metal-compatible equivalents.

    Each matching submodule of ``model`` has its ``forward`` rebound in place:

    1. GatedDeltaNet → forward built on the metal::gated_delta_rule custom op
    2. FullAttention → ``turboquant`` set to False, standard SDPA forward
    3. SparseMoE → forward without the .float() cast on expert_weights

    The expert modules themselves are not touched here; SparseMoE.experts is
    expected to have been replaced already by quantize_experts_metal().

    Args:
        model: module tree to patch in place.
        config: unused; kept so existing callers can keep passing it.

    Returns:
        None. The number of patched modules of each kind is logged.
    """

    def _sparse_moe_forward(self, x):
        """SparseMoE forward that passes expert_weights through un-cast."""
        B, T, C = x.size()
        x_flat = x.view(-1, C)
        scores = self.gate(x_flat)
        expert_weights, expert_indices = torch.topk(scores, self.top_k, dim=-1)
        expert_weights = expert_weights.softmax(dim=-1)
        routed_out = self.experts(x_flat, expert_weights, expert_indices, self.top_k)
        shared_out = self.shared_expert(x_flat)
        shared_gate = torch.sigmoid(self.shared_expert_gate(x_flat))
        return (routed_out + shared_gate * shared_out).view(B, T, C)

    count_gdn = 0
    count_attn = 0
    count_moe = 0

    # Single walk over the tree. The checks are independent (not elif) and in
    # the original order, so a module matching several types ends up with the
    # same final forward as before.
    for _name, module in model.named_modules():
        if isinstance(module, GatedDeltaNet):
            module.forward = types.MethodType(_metal_gated_delta_net_forward, module)
            count_gdn += 1

        if isinstance(module, FullAttention):
            module.turboquant = False
            module.forward = types.MethodType(_metal_full_attention_forward, module)
            count_attn += 1

        if isinstance(module, SparseMoE):
            module.forward = types.MethodType(_sparse_moe_forward, module)
            count_moe += 1

    logger.info("Replaced %d GatedDeltaNet → metal::gated_delta_rule", count_gdn)
    logger.info(
        "Replaced %d FullAttention → standard SDPA (no turboquant)", count_attn
    )
    logger.info("Replaced %d SparseMoE → no .float() cast", count_moe)

0 commit comments

Comments
 (0)