1919#include < cmath>
2020#include < cstring>
2121
22+ #include < algorithm>
23+
2224#ifdef EXECUTORCH_BUILD_CUDA
2325#include < cuda_runtime.h>
2426#include < executorch/backends/cuda/runtime/cuda_mutable_state.h>
@@ -39,6 +41,22 @@ using SizesType = executorch::aten::SizesType;
3941
4042namespace {
4143
44+ #ifdef EXECUTORCH_BUILD_MLX
45+ // The MLX export emits a single dynamic-seq `forward` method that handles both
46+ // prefill (T>=2) and decode (T=1). Mirror gemma4_31b's MLX runner, which loads
47+ // and calls `forward` for both phases.
48+ constexpr const char * kPrefillMethod = " forward" ;
49+ constexpr const char * kDecodeMethod = " forward" ;
50+ #else
51+ // CUDA/Metal exports emit two separate methods.
52+ constexpr const char * kPrefillMethod = " prefill" ;
53+ constexpr const char * kDecodeMethod = " decode" ;
54+ #endif
55+
56+ // Constant method exported by the MLX .pte giving the largest prefill chunk the
57+ // `forward` method was compiled for. Read into the metadata map in create().
58+ constexpr const char * kMaxPrefillChunk = " get_max_prefill_chunk" ;
59+
4260Result<uint64_t > read_sampled_token (
4361 const executorch::aten::Tensor& output,
4462 float temperature) {
@@ -98,8 +116,10 @@ Result<std::unique_ptr<Module>> build_qwen_module(
98116 }
99117#endif
100118
101- ET_CHECK_OK_OR_RETURN_ERROR (module ->load_method (" prefill" ));
102- ET_CHECK_OK_OR_RETURN_ERROR (module ->load_method (" decode" ));
119+ ET_CHECK_OK_OR_RETURN_ERROR (module ->load_method (kPrefillMethod ));
120+ if (std::string (kDecodeMethod ) != std::string (kPrefillMethod )) {
121+ ET_CHECK_OK_OR_RETURN_ERROR (module ->load_method (kDecodeMethod ));
122+ }
103123 return module ;
104124}
105125
@@ -240,34 +260,63 @@ class Qwen35MoESession : public LLMSession {
240260 }
241261
242262 stop_.store (false , std::memory_order_relaxed);
243- std::vector<int64_t > token_data (tokens.begin (), tokens.end ());
244- std::vector<int64_t > pos_data (T);
245- for (int64_t i = 0 ; i < T; ++i) {
246- pos_data[i] = pos_ + i;
263+
264+ // On MLX, run prefill in fixed-size chunks (caps peak memory and the
265+ // compiled prefill shape). Other backends prefill the whole prompt in one
266+ // pass. Only the final chunk's sampled token is kept; the recurrence/KV
267+ // state from earlier chunks persists via pos_ advancement.
268+ #ifdef EXECUTORCH_BUILD_MLX
269+ // Chunk size: default to the compiled max (kMaxSeqLen - 1), overridden by
270+ // the exported get_max_prefill_chunk constant when present (mirrors
271+ // gemma4_31b). Falls back to T (single pass) if no metadata is available at
272+ // all.
273+ int64_t chunk_size = T;
274+ if (auto it = metadata_.find (kMaxSeqLen );
275+ it != metadata_.end () && it->second > 1 ) {
276+ chunk_size = it->second - 1 ;
247277 }
248- auto tokens_tensor = from_blob (
249- token_data.data (),
250- {1 , static_cast <SizesType>(T)},
251- executorch::aten::ScalarType::Long);
252- auto pos_tensor = from_blob (
253- pos_data.data (),
254- {static_cast <SizesType>(T)},
255- executorch::aten::ScalarType::Long);
256-
257- const char * method = (T >= 2 ) ? " prefill" : " decode" ;
258- std::vector<EValue> inputs;
259- inputs.push_back (tokens_tensor);
260- inputs.push_back (pos_tensor);
278+ if (auto it = metadata_.find (kMaxPrefillChunk );
279+ it != metadata_.end () && it->second > 0 ) {
280+ chunk_size = it->second ;
281+ }
282+ #else
283+ const int64_t chunk_size = T;
284+ #endif
285+
286+ uint64_t sampled_token = 0 ;
287+ for (int64_t off = 0 ; off < T; off += chunk_size) {
288+ const int64_t len = std::min (chunk_size, T - off);
289+ std::vector<int64_t > token_data (
290+ tokens.begin () + off, tokens.begin () + off + len);
291+ std::vector<int64_t > pos_data (len);
292+ for (int64_t i = 0 ; i < len; ++i) {
293+ pos_data[i] = pos_ + i;
294+ }
295+ auto tokens_tensor = from_blob (
296+ token_data.data (),
297+ {1 , static_cast <SizesType>(len)},
298+ executorch::aten::ScalarType::Long);
299+ auto pos_tensor = from_blob (
300+ pos_data.data (),
301+ {static_cast <SizesType>(len)},
302+ executorch::aten::ScalarType::Long);
303+
304+ const char * method = (len >= 2 ) ? kPrefillMethod : kDecodeMethod ;
305+ std::vector<EValue> inputs;
306+ inputs.push_back (tokens_tensor);
307+ inputs.push_back (pos_tensor);
261308#ifdef EXECUTORCH_BUILD_CUDA
262- set_temp (first_token_temp);
263- inputs.push_back (EValue (temp_tensor_));
309+ set_temp (first_token_temp);
310+ inputs.push_back (EValue (temp_tensor_));
264311#endif
265- auto sampled =
266- run_locked (method, inputs, first_token_temp, /* sync_after=*/ true );
267- ET_CHECK_OK_OR_RETURN_ERROR (sampled.error ());
268- pending_ = sampled.get ();
312+ auto sampled =
313+ run_locked (method, inputs, first_token_temp, /* sync_after=*/ true );
314+ ET_CHECK_OK_OR_RETURN_ERROR (sampled.error ());
315+ sampled_token = sampled.get ();
316+ pos_ += len;
317+ }
318+ pending_ = sampled_token;
269319 prev_decode_token_.reset ();
270- pos_ += T;
271320 return Error::Ok;
272321 }
273322
@@ -334,7 +383,7 @@ class Qwen35MoESession : public LLMSession {
334383 inputs.push_back (EValue (temp_tensor_));
335384#endif
336385 auto sampled =
337- run_locked (" decode " , inputs, temperature_, /* sync_after=*/ false );
386+ run_locked (kDecodeMethod , inputs, temperature_, /* sync_after=*/ false );
338387 ET_CHECK_OK_OR_RETURN_ERROR (sampled.error ());
339388 pending_ = sampled.get ();
340389 prev_decode_token_ = token;
@@ -457,6 +506,14 @@ Result<std::unique_ptr<Qwen35MoEEngine>> Qwen35MoEEngine::create(
457506 ET_LOG (Error, " Qwen35MoEEngine: failed to read metadata" );
458507 return metadata_result.error ();
459508 }
509+ #ifdef EXECUTORCH_BUILD_MLX
510+ // Surface the compiled max prefill chunk (a constant method get_llm_metadata
511+ // doesn't harvest) into the metadata map so the session can chunk long
512+ // prompts within the shape `forward` was compiled for.
513+ if (auto mpc = meta_module->get (kMaxPrefillChunk ); mpc.ok ()) {
514+ metadata_result.get ()[kMaxPrefillChunk ] = mpc->toScalar ().to <int64_t >();
515+ }
516+ #endif
460517 auto eos_ids = get_eos_ids (tokenizer.get (), meta_module.get ());
461518 // This export's metadata doesn't carry the chat-turn EOS (config.json has no
462519 // eos_token_id and the .pte exports no get_eos_ids method), so get_eos_ids()
0 commit comments