Skip to content

Commit ce9a7f5

Browse files
authored
[MLX] Add MLX support to the Qwen3.5 C++ runner (pytorch#20364)
Adds MLX support to the C++ Qwen3.5 runner. See updated README.md for instructions.
1 parent 60b1351 commit ce9a7f5

8 files changed

Lines changed: 320 additions & 30 deletions

File tree

.github/workflows/mlx.yml

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,29 @@ jobs:
161161
fi
162162
echo "::endgroup::"
163163
164+
echo "::group::Verify chunked == unchunked prefill"
165+
QWEN_TINY_PTE=/tmp/qwen35_moe_mlx_tiny/model.pte \
166+
${CONDA_RUN} python -m pytest \
167+
examples/models/qwen3_5_moe/test_chunked_prefill.py -v
168+
echo "::endgroup::"
169+
170+
echo "::group::Build Qwen 3.5 MoE MLX C++ runner"
171+
# Validates the MLX C++ runner build wiring (compile + link + metallib).
172+
# The tiny model has no compatible tokenizer (vocab 256, random weights),
173+
# so we don't run C++ inference here — only confirm it builds.
174+
${CONDA_RUN} make qwen3_5_moe-mlx
175+
RUNNER=cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner
176+
if [ ! -x "$RUNNER" ]; then
177+
echo "Failed: runner not found at $RUNNER"
178+
exit 1
179+
fi
180+
if [ ! -f "$(dirname "$RUNNER")/mlx.metallib" ]; then
181+
echo "Failed: mlx.metallib not copied next to runner"
182+
exit 1
183+
fi
184+
echo "Success: built $RUNNER"
185+
echo "::endgroup::"
186+
164187
backend-tester:
165188
needs: run-decision
166189
if: |

Makefile

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@
9191
#
9292
# ==============================================================================
9393

94-
.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu lfm_2_5-mlx llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda gemma4_31b-mlx qwen3_5_moe-cuda qwen3_5_moe-metal clean help
94+
.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu lfm_2_5-mlx llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda gemma4_31b-mlx qwen3_5_moe-cuda qwen3_5_moe-metal qwen3_5_moe-mlx clean help
9595

9696
help:
9797
@echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make <target>\`. Available targets:"
@@ -131,6 +131,7 @@ help:
131131
@echo " gemma4_31b-mlx - Build Gemma 4 31B runner with MLX backend"
132132
@echo " qwen3_5_moe-cuda - Build Qwen3.5 MoE runner with CUDA backend"
133133
@echo " qwen3_5_moe-metal - Build Qwen3.5 MoE runner with Metal backend"
134+
@echo " qwen3_5_moe-mlx - Build Qwen3.5 MoE runner with MLX backend"
134135
@echo " clean - Clean build artifacts"
135136

136137
voxtral-cuda:
@@ -467,6 +468,15 @@ qwen3_5_moe-metal:
467468
@echo "✓ Build complete!"
468469
@echo " Binary: cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner"
469470

471+
qwen3_5_moe-mlx:
472+
@echo "==> Building and installing ExecuTorch with MLX..."
473+
cmake --workflow --preset mlx-release
474+
@echo "==> Building Qwen3.5 MoE runner with MLX..."
475+
cd examples/models/qwen3_5_moe && cmake --workflow --preset qwen3-5-moe-mlx
476+
@echo ""
477+
@echo "✓ Build complete!"
478+
@echo " Binary: cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner"
479+
470480
clean:
471481
rm -rf cmake-out \
472482
extension/llm/tokenizers/build \

examples/models/qwen3_5_moe/CMakeLists.txt

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,14 @@ elseif(EXECUTORCH_BUILD_CUDA)
5454
list(APPEND link_libraries aoti_cuda_backend)
5555
executorch_target_link_options_shared_lib(aoti_cuda_backend)
5656
add_compile_definitions(EXECUTORCH_BUILD_CUDA)
57+
elseif(TARGET mlxdelegate)
58+
list(APPEND link_libraries mlxdelegate mlx)
59+
executorch_target_link_options_shared_lib(mlxdelegate)
60+
add_compile_definitions(EXECUTORCH_BUILD_MLX)
5761
else()
5862
message(
59-
FATAL_ERROR "Set EXECUTORCH_BUILD_CUDA=ON or EXECUTORCH_BUILD_METAL=ON"
63+
FATAL_ERROR
64+
"Set EXECUTORCH_BUILD_CUDA=ON, EXECUTORCH_BUILD_METAL=ON, or EXECUTORCH_BUILD_MLX=ON"
6065
)
6166
endif()
6267

@@ -82,6 +87,10 @@ if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
8287
target_link_options(qwen3_5_moe_worker PRIVATE "LINKER:-s")
8388
endif()
8489

90+
if(TARGET mlxdelegate)
91+
executorch_target_copy_mlx_metallib(qwen3_5_moe_runner)
92+
endif()
93+
8594
if(EXECUTORCH_BUILD_CUDA)
8695
enable_testing()
8796
add_executable(

examples/models/qwen3_5_moe/CMakePresets.json

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,19 @@
3636
"type": "equals",
3737
"rhs": "Darwin"
3838
}
39+
},
40+
{
41+
"name": "qwen3-5-moe-mlx",
42+
"displayName": "Qwen3.5 MoE runner (MLX)",
43+
"inherits": ["qwen3-5-moe-base"],
44+
"cacheVariables": {
45+
"EXECUTORCH_BUILD_MLX": "ON"
46+
},
47+
"condition": {
48+
"type": "equals",
49+
"lhs": "${hostSystemName}",
50+
"rhs": "Darwin"
51+
}
3952
}
4053
],
4154
"buildPresets": [
@@ -54,6 +67,12 @@
5467
"displayName": "Build Qwen3.5 MoE runner and worker (Metal)",
5568
"configurePreset": "qwen3-5-moe-metal",
5669
"targets": ["qwen3_5_moe_runner", "qwen3_5_moe_worker"]
70+
},
71+
{
72+
"name": "qwen3-5-moe-mlx",
73+
"displayName": "Build Qwen3.5 MoE runner (MLX)",
74+
"configurePreset": "qwen3-5-moe-mlx",
75+
"targets": ["qwen3_5_moe_runner"]
5776
}
5877
],
5978
"workflowPresets": [
@@ -84,6 +103,20 @@
84103
"name": "qwen3-5-moe-metal"
85104
}
86105
]
106+
},
107+
{
108+
"name": "qwen3-5-moe-mlx",
109+
"displayName": "Configure and build Qwen3.5 MoE runner (MLX)",
110+
"steps": [
111+
{
112+
"type": "configure",
113+
"name": "qwen3-5-moe-mlx"
114+
},
115+
{
116+
"type": "build",
117+
"name": "qwen3-5-moe-mlx"
118+
}
119+
]
87120
}
88121
]
89122
}

examples/models/qwen3_5_moe/README.md

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,38 @@ python export.py \
261261
| `--qembedding` | (none) | Embedding quantization: `8w` |
262262
| `--tiny-test` | off | Build tiny model with random weights for CI testing |
263263

264-
### Run (MLX)
264+
### Build (MLX)
265+
266+
Like the CUDA/Metal builds, the `make` target builds ExecuTorch core with the
267+
MLX backend and the runner binary. Requires Apple Silicon (Darwin).
268+
269+
```bash
270+
make qwen3_5_moe-mlx
271+
```
272+
273+
This builds ExecuTorch with MLX support, then the runner binary at
274+
`cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner` (with `mlx.metallib`
275+
copied next to it). Unlike CUDA, the MLX `.pte` is self-contained — no `.ptd`
276+
data file is produced or needed.
277+
278+
### Run (MLX, C++ runner)
279+
280+
The C++ runner takes the MLX `.pte` and a local HuggingFace `tokenizer.json`;
281+
no `--data_path` is needed:
282+
283+
```bash
284+
cmake-out/examples/models/qwen3_5_moe/qwen3_5_moe_runner \
285+
--model_path ./qwen35_moe_mlx/model.pte \
286+
--tokenizer_path ~/models/Qwen3.5-35B-A3B/tokenizer.json \
287+
--prompt "What is the capital of France?" \
288+
--max_new_tokens 50
289+
```
290+
291+
The MLX export emits a single dynamic-seq `forward` method; the runner loads and
292+
calls it for both prefill and decode (sampling on host), matching the Python
293+
runner. See the [Run](#run) section above for the full flag list.
294+
295+
### Run (MLX, Python)
265296

266297
```bash
267298
python -m executorch.examples.models.qwen3_5_moe.run \

examples/models/qwen3_5_moe/export.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -768,10 +768,16 @@ def _export_mlx(model, config, args):
768768
gc.collect()
769769

770770
print("Lowering to ExecuTorch with MLX backend...")
771+
# Largest prefill chunk the runner may submit in one forward call. The MLX
772+
# runner chunks long prompts to cap peak memory; bound it by the compiled
773+
# dynamic max (max_seq_len - 1) so a chunk can never exceed what `forward`
774+
# was compiled for.
775+
max_prefill_chunk = min(1024, config.max_seq_len - 1)
771776
metadata = {
772777
"get_max_seq_len": config.max_seq_len,
773778
"get_vocab_size": config.vocab_size,
774779
"get_n_layers": config.num_hidden_layers,
780+
"get_max_prefill_chunk": max_prefill_chunk,
775781
"use_kv_cache": True,
776782
"use_sdpa_with_kv_cache": False,
777783
"enable_dynamic_shape": True,

examples/models/qwen3_5_moe/qwen35_moe_engine.cpp

Lines changed: 84 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
#include <cmath>
2020
#include <cstring>
2121

22+
#include <algorithm>
23+
2224
#ifdef EXECUTORCH_BUILD_CUDA
2325
#include <cuda_runtime.h>
2426
#include <executorch/backends/cuda/runtime/cuda_mutable_state.h>
@@ -39,6 +41,22 @@ using SizesType = executorch::aten::SizesType;
3941

4042
namespace {
4143

44+
#ifdef EXECUTORCH_BUILD_MLX
45+
// The MLX export emits a single dynamic-seq `forward` method that handles both
46+
// prefill (T>=2) and decode (T=1). Mirror gemma4_31b's MLX runner, which loads
47+
// and calls `forward` for both phases.
48+
constexpr const char* kPrefillMethod = "forward";
49+
constexpr const char* kDecodeMethod = "forward";
50+
#else
51+
// CUDA/Metal exports emit two separate methods.
52+
constexpr const char* kPrefillMethod = "prefill";
53+
constexpr const char* kDecodeMethod = "decode";
54+
#endif
55+
56+
// Constant method exported by the MLX .pte giving the largest prefill chunk the
57+
// runner may submit per `forward` call. Read into the metadata map in create().
58+
constexpr const char* kMaxPrefillChunk = "get_max_prefill_chunk";
59+
4260
Result<uint64_t> read_sampled_token(
4361
const executorch::aten::Tensor& output,
4462
float temperature) {
@@ -98,8 +116,10 @@ Result<std::unique_ptr<Module>> build_qwen_module(
98116
}
99117
#endif
100118

101-
ET_CHECK_OK_OR_RETURN_ERROR(module->load_method("prefill"));
102-
ET_CHECK_OK_OR_RETURN_ERROR(module->load_method("decode"));
119+
ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(kPrefillMethod));
120+
if (std::string(kDecodeMethod) != std::string(kPrefillMethod)) {
121+
ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(kDecodeMethod));
122+
}
103123
return module;
104124
}
105125

@@ -240,34 +260,63 @@ class Qwen35MoESession : public LLMSession {
240260
}
241261

242262
stop_.store(false, std::memory_order_relaxed);
243-
std::vector<int64_t> token_data(tokens.begin(), tokens.end());
244-
std::vector<int64_t> pos_data(T);
245-
for (int64_t i = 0; i < T; ++i) {
246-
pos_data[i] = pos_ + i;
263+
264+
// On MLX, run prefill in fixed-size chunks (caps peak memory and the
265+
// compiled prefill shape). Other backends prefill the whole prompt in one
266+
// pass. Only the final chunk's sampled token is kept; the recurrence/KV
267+
// state from earlier chunks persists via pos_ advancement.
268+
#ifdef EXECUTORCH_BUILD_MLX
269+
// Chunk size: default to the compiled max (kMaxSeqLen - 1), overridden by
270+
// the exported get_max_prefill_chunk constant when present (mirrors
271+
// gemma4_31b). Falls back to T (single pass) if no metadata is available at
272+
// all.
273+
int64_t chunk_size = T;
274+
if (auto it = metadata_.find(kMaxSeqLen);
275+
it != metadata_.end() && it->second > 1) {
276+
chunk_size = it->second - 1;
247277
}
248-
auto tokens_tensor = from_blob(
249-
token_data.data(),
250-
{1, static_cast<SizesType>(T)},
251-
executorch::aten::ScalarType::Long);
252-
auto pos_tensor = from_blob(
253-
pos_data.data(),
254-
{static_cast<SizesType>(T)},
255-
executorch::aten::ScalarType::Long);
256-
257-
const char* method = (T >= 2) ? "prefill" : "decode";
258-
std::vector<EValue> inputs;
259-
inputs.push_back(tokens_tensor);
260-
inputs.push_back(pos_tensor);
278+
if (auto it = metadata_.find(kMaxPrefillChunk);
279+
it != metadata_.end() && it->second > 0) {
280+
chunk_size = it->second;
281+
}
282+
#else
283+
const int64_t chunk_size = T;
284+
#endif
285+
286+
uint64_t sampled_token = 0;
287+
for (int64_t off = 0; off < T; off += chunk_size) {
288+
const int64_t len = std::min(chunk_size, T - off);
289+
std::vector<int64_t> token_data(
290+
tokens.begin() + off, tokens.begin() + off + len);
291+
std::vector<int64_t> pos_data(len);
292+
for (int64_t i = 0; i < len; ++i) {
293+
pos_data[i] = pos_ + i;
294+
}
295+
auto tokens_tensor = from_blob(
296+
token_data.data(),
297+
{1, static_cast<SizesType>(len)},
298+
executorch::aten::ScalarType::Long);
299+
auto pos_tensor = from_blob(
300+
pos_data.data(),
301+
{static_cast<SizesType>(len)},
302+
executorch::aten::ScalarType::Long);
303+
304+
const char* method = (len >= 2) ? kPrefillMethod : kDecodeMethod;
305+
std::vector<EValue> inputs;
306+
inputs.push_back(tokens_tensor);
307+
inputs.push_back(pos_tensor);
261308
#ifdef EXECUTORCH_BUILD_CUDA
262-
set_temp(first_token_temp);
263-
inputs.push_back(EValue(temp_tensor_));
309+
set_temp(first_token_temp);
310+
inputs.push_back(EValue(temp_tensor_));
264311
#endif
265-
auto sampled =
266-
run_locked(method, inputs, first_token_temp, /*sync_after=*/true);
267-
ET_CHECK_OK_OR_RETURN_ERROR(sampled.error());
268-
pending_ = sampled.get();
312+
auto sampled =
313+
run_locked(method, inputs, first_token_temp, /*sync_after=*/true);
314+
ET_CHECK_OK_OR_RETURN_ERROR(sampled.error());
315+
sampled_token = sampled.get();
316+
pos_ += len;
317+
}
318+
pending_ = sampled_token;
269319
prev_decode_token_.reset();
270-
pos_ += T;
271320
return Error::Ok;
272321
}
273322

@@ -334,7 +383,7 @@ class Qwen35MoESession : public LLMSession {
334383
inputs.push_back(EValue(temp_tensor_));
335384
#endif
336385
auto sampled =
337-
run_locked("decode", inputs, temperature_, /*sync_after=*/false);
386+
run_locked(kDecodeMethod, inputs, temperature_, /*sync_after=*/false);
338387
ET_CHECK_OK_OR_RETURN_ERROR(sampled.error());
339388
pending_ = sampled.get();
340389
prev_decode_token_ = token;
@@ -457,6 +506,14 @@ Result<std::unique_ptr<Qwen35MoEEngine>> Qwen35MoEEngine::create(
457506
ET_LOG(Error, "Qwen35MoEEngine: failed to read metadata");
458507
return metadata_result.error();
459508
}
509+
#ifdef EXECUTORCH_BUILD_MLX
510+
// Surface the compiled max prefill chunk (a constant method get_llm_metadata
511+
// doesn't harvest) into the metadata map so the session can chunk long
512+
// prompts within the shape `forward` was compiled for.
513+
if (auto mpc = meta_module->get(kMaxPrefillChunk); mpc.ok()) {
514+
metadata_result.get()[kMaxPrefillChunk] = mpc->toScalar().to<int64_t>();
515+
}
516+
#endif
460517
auto eos_ids = get_eos_ids(tokenizer.get(), meta_module.get());
461518
// This export's metadata doesn't carry the chat-turn EOS (config.json has no
462519
// eos_token_id and the .pte exports no get_eos_ids method), so get_eos_ids()

0 commit comments

Comments
 (0)