Skip to content

Commit 226ee36

Browse files
authored
Merge branch 'main' into mlx-docs
2 parents 8b54a32 + 57887ec commit 226ee36

21 files changed

Lines changed: 4265 additions & 85 deletions

File tree

.github/workflows/mlx.yml

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ on:
1313
- extension/audio/**
1414
- examples/models/parakeet/**
1515
- examples/models/voxtral_realtime/**
16+
- examples/models/qwen3_5_moe/**
1617
workflow_dispatch:
1718

1819
permissions: {}
@@ -63,6 +64,61 @@ jobs:
6364
./cmake-out/backends/mlx/test/multi_thread_test_runner
6465
echo "::endgroup::"
6566
67+
echo "::group::Run gated_delta_rule op tests"
68+
${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_gated_delta_rule run -v
69+
echo "::endgroup::"
70+
71+
test-mlx-qwen35-moe:
72+
uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
73+
with:
74+
job-name: test-mlx-qwen35-moe
75+
runner: macos-14-xlarge
76+
python-version: "3.12"
77+
submodules: recursive
78+
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
79+
timeout: 90
80+
script: |
81+
set -eux
82+
83+
echo "::group::Install ExecuTorch"
84+
${CONDA_RUN} python install_executorch.py > /dev/null
85+
echo "::endgroup::"
86+
87+
${CONDA_RUN} pip list
88+
89+
echo "::group::Export Qwen 3.5 MoE (tiny model)"
90+
${CONDA_RUN} python -m executorch.examples.models.qwen3_5_moe.export \
91+
--tiny-test \
92+
--backend mlx \
93+
--qlinear 4w \
94+
--qlinear-group-size 32 \
95+
--output-dir /tmp/qwen35_moe_mlx_tiny
96+
echo "::endgroup::"
97+
98+
echo "::group::Check AsType node count"
99+
ASTYPE_COUNT=$(${CONDA_RUN} python -m executorch.backends.mlx.pte_inspector \
100+
/tmp/qwen35_moe_mlx_tiny/model.pte --mlx-instructions 2>&1 | grep -c "AsTypeNode" || true)
101+
echo "AsType nodes: ${ASTYPE_COUNT}"
102+
if [ "$ASTYPE_COUNT" -gt 23 ]; then
103+
echo "Failed: expected no more than 23 AsType nodes, got ${ASTYPE_COUNT}"
104+
exit 1
105+
fi
106+
echo "::endgroup::"
107+
108+
echo "::group::Run Qwen 3.5 MoE inference"
109+
OUTPUT=$(${CONDA_RUN} python -m executorch.examples.models.qwen3_5_moe.run \
110+
--pte /tmp/qwen35_moe_mlx_tiny/model.pte \
111+
--prompt-len 4 \
112+
--max-new-tokens 5 2>&1)
113+
echo "$OUTPUT"
114+
if echo "$OUTPUT" | grep -q "Generated token ids: \[167, 167, 81, 167, 81\]"; then
115+
echo "Success: Qwen 3.5 MoE MLX export + inference completed with expected output"
116+
else
117+
echo "Failed: unexpected output (expected [167, 167, 81, 167, 81])"
118+
exit 1
119+
fi
120+
echo "::endgroup::"
121+
66122
backend-tester:
67123
strategy:
68124
fail-fast: false

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 92 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,22 @@ def get_meandim_decomposition(op) -> tuple:
3535
raise RuntimeError(f"Can't get meandim decomposition for op {op}")
3636

3737

38+
def get_dynamic_meandim_decomposition(op) -> tuple:
39+
if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
40+
return (
41+
exir_ops.edge.aten.sum.dim_IntList,
42+
exir_ops.edge.aten.mul.Tensor,
43+
exir_ops.edge.aten.full.default,
44+
exir_ops.edge.aten.reciprocal.default,
45+
exir_ops.edge.aten.expand_copy.default,
46+
)
47+
if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default):
48+
raise NotImplementedError(
49+
"Dynamic mean.dim decomposition is not supported for torch.aten.mean."
50+
)
51+
raise RuntimeError(f"Can't get meandim decomposition for op {op}")
52+
53+
3854
def get_avgpool(op):
3955
if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
4056
return exir_ops.edge.aten.avg_pool2d.default
@@ -103,26 +119,39 @@ def __init__(self, graph_module, tosa_spec, *args, **kwargs):
103119
self._tosa_spec, WhyNoPartitionReporter()
104120
)
105121

106-
def call_operator(self, op, args, kwargs, meta):
122+
def call_operator(self, op, args, kwargs, meta, updated=False):
107123
if op not in (
108124
exir_ops.edge.aten.mean.dim,
109125
torch.ops.aten.mean.dim,
110126
exir_ops.edge.aten.mean.default,
111127
torch.ops.aten.mean.default,
112128
) or not self.allowed_to_transform(meta):
113-
return super().call_operator(op, args, kwargs, meta)
129+
return super().call_operator(op, args, kwargs, meta, updated)
114130

115131
x = get_node_arg(args, 0)
116132
input_shape = list(x.data.shape)
117133
output_shape = list(meta["val"].shape)
134+
118135
dims_to_reduce = get_node_arg(args, 1, range(len(input_shape)))
119136
if dims_to_reduce is None:
120137
dims_to_reduce = range(len(input_shape))
138+
121139
dims_to_reduce = [dim % len(input_shape) for dim in dims_to_reduce]
122-
dims_to_reduce = [dim for dim in dims_to_reduce if input_shape[dim] != 1]
140+
141+
has_symbolic_reduce_dim = any(
142+
isinstance(input_shape[dim], torch.SymInt) for dim in dims_to_reduce
143+
)
144+
if has_symbolic_reduce_dim and get_quantization(x.node.target) is not None:
145+
raise NotImplementedError(
146+
"Quantized mean.dim with symbolic reduced dimensions is not supported"
147+
)
123148

124149
view_op = get_view(op)
125150

151+
if not has_symbolic_reduce_dim:
152+
# For static shapes, keep only the reduce dims whose size is not 1.
153+
dims_to_reduce = [dim for dim in dims_to_reduce if input_shape[dim] != 1]
154+
126155
# Reshape to 4D
127156
if len(input_shape) != 4:
128157
new_shape = copy(input_shape)
@@ -140,26 +169,66 @@ def call_operator(self, op, args, kwargs, meta):
140169
x = self._maybe_insert_q_dq_after(x, meta)
141170

142171
# Reduce (h,w) dims by avg pool if possible
143-
x, dims_to_reduce = self._reduce_by_average_pool(op, x, dims_to_reduce, meta)
172+
if not has_symbolic_reduce_dim:
173+
x, dims_to_reduce = self._reduce_by_average_pool(
174+
op, x, dims_to_reduce, meta
175+
)
144176

145177
# Reshape back to 5D if necessary
146178
if len(input_shape) > 4:
147-
original_dims = input_shape[0:-3]
179+
original_dims = input_shape[:-3]
148180
temp_shape = list(x.data.shape)[1:]
149181
temp_shape = original_dims + temp_shape
150182
dims_to_reduce = [dim + len(original_dims) - 1 for dim in dims_to_reduce]
151183

152184
x = super().call_operator(view_op, (x, temp_shape), {}, meta, True)
153185
x = self._maybe_insert_q_dq_after(x, meta)
154-
# Reduce remaining dims by sum
155-
x = self._reduce_by_sum(op, x, dims_to_reduce, meta)
186+
187+
if has_symbolic_reduce_dim:
188+
x = self._reduce_by_sum_symbolic(op, x, dims_to_reduce, meta)
189+
else:
190+
x = self._reduce_by_sum(op, x, dims_to_reduce, meta)
156191

157192
# Reshape to correct output shape if necessary
158193
if list(x.data.shape) != output_shape:
159194
x = super().call_operator(view_op, (x, output_shape), {}, meta, True)
160195

161196
return x
162197

198+
def _reduce_by_sum_symbolic(self, op, input_node, dims, meta):
199+
input_shape = input_node.data.size()
200+
reduced_shape = [input_shape[dim] for dim in dims]
201+
202+
sum_op, mul_op, full_op, recip_op, expand_op = (
203+
get_dynamic_meandim_decomposition(op)
204+
)
205+
206+
sum = super().call_operator(sum_op, (input_node, dims, True), {}, meta, True)
207+
208+
ones = super().call_operator(
209+
full_op,
210+
([1], 1.0),
211+
{"dtype": meta.data["val"].dtype, "device": input_node.data.device},
212+
meta,
213+
True,
214+
)
215+
expanded_ones = super().call_operator(
216+
expand_op,
217+
(ones, reduced_shape),
218+
{},
219+
meta,
220+
True,
221+
)
222+
counts = super().call_operator(
223+
sum_op,
224+
(expanded_ones, list(range(len(reduced_shape))), True),
225+
{},
226+
meta,
227+
True,
228+
)
229+
recip = super().call_operator(recip_op, (counts,), {}, meta, True)
230+
return super().call_operator(mul_op, (sum, recip), {}, meta, True)
231+
163232
def _reduce_by_sum(self, op, input_node, dims, meta):
164233
if len(dims) == 0:
165234
return input_node
@@ -224,13 +293,9 @@ def _reduce_by_average_pool(self, op, input_node, dims, meta):
224293
if is_supported:
225294
out = super().call_operator(avgpool_op, args, {}, meta, True)
226295
out = self._maybe_insert_q_dq_after(out, meta)
227-
return (
228-
out,
229-
dims_to_reduce_by_sum,
230-
)
296+
return out, dims_to_reduce_by_sum
231297

232-
else:
233-
return input_node, dims
298+
return input_node, dims
234299

235300
def _maybe_insert_q_dq_after(self, op, meta):
236301
"""If the input node of op is a dequant node, insert a q-dq pair after
@@ -242,20 +307,18 @@ def _maybe_insert_q_dq_after(self, op, meta):
242307
f"Expected one input to {op.node}, got inputs {op.node.all_input_nodes}"
243308
)
244309
input_node = op.node.all_input_nodes[0]
245-
if (quant_ops := get_quantization(input_node.target)) is not None:
246-
q_op, dq_op = quant_ops
247-
quant_args = list(input_node.args[1:])
248-
q_args = (op, *quant_args)
249-
out = super().call_operator(
250-
q_op,
251-
q_args,
252-
kwargs={},
253-
meta=meta,
254-
updated=True,
255-
)
256-
dq_args = (out, *quant_args)
257-
return super().call_operator(
258-
dq_op, dq_args, kwargs={}, meta=meta, updated=True
259-
)
260-
else:
310+
if (quant_ops := get_quantization(input_node.target)) is None:
261311
return op
312+
313+
q_op, dq_op = quant_ops
314+
quant_args = list(input_node.args[1:])
315+
q_args = (op, *quant_args)
316+
out = super().call_operator(
317+
q_op,
318+
q_args,
319+
kwargs={},
320+
meta=meta,
321+
updated=True,
322+
)
323+
dq_args = (out, *quant_args)
324+
return super().call_operator(dq_op, dq_args, kwargs={}, meta=meta, updated=True)

backends/mlx/CMakeLists.txt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,14 @@ add_subdirectory(${MLX_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR}/mlx)
247247
# Op logging option (for debugging) - OFF by default for performance
248248
option(ET_MLX_ENABLE_OP_LOGGING "Enable per-op logging in MLX delegate" OFF)
249249

250+
# Custom kernel execution - ON by default. When enabled,
251+
# MetalKernelNode can execute arbitrary Metal shader code embedded in .pte
252+
# files, so set it to OFF unless every .pte source is trusted.
253+
option(ET_MLX_ALLOW_CUSTOM_KERNEL_EXECUTION
254+
"Allow MetalKernelNode to execute custom Metal shaders from .pte files"
255+
ON
256+
)
257+
250258
set(_mlx_backend__srcs ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp
251259
${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXBackend.cpp
252260
)
@@ -262,6 +270,13 @@ if(ET_MLX_ENABLE_OP_LOGGING)
262270
message(STATUS "MLX delegate op logging ENABLED")
263271
endif()
264272

273+
if(ET_MLX_ALLOW_CUSTOM_KERNEL_EXECUTION)
274+
target_compile_definitions(
275+
mlxdelegate PRIVATE ET_MLX_ALLOW_CUSTOM_KERNEL_EXECUTION
276+
)
277+
message(STATUS "MLX delegate custom kernel execution ENABLED")
278+
endif()
279+
265280
target_include_directories(
266281
mlxdelegate PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/runtime
267282
)

0 commit comments

Comments
 (0)