2525
2626#ifdef EXECUTORCH_BUILD_CUDA
2727#include < cuda_runtime.h>
28+ #else
29+ #include < executorch/extension/llm/sampler/util.h>
2830#endif
2931
3032DEFINE_string (model_path, " " , " Model .pte file path." );
@@ -37,7 +39,7 @@ DEFINE_string(
3739 " Path to file containing prompt text (overrides --prompt)." );
3840DEFINE_double (temperature, 0.8 , " Sampling temperature (0 = greedy)." );
3941DEFINE_int32 (max_new_tokens, 128 , " Maximum tokens to generate." );
40- DEFINE_bool (cuda_graph, false , " Enable CUDA graph for decode method." );
42+ DEFINE_bool (cuda_graph, false , " Enable CUDA graph for decode method. CUDA only. " );
4143
4244namespace llm = ::executorch::extension::llm;
4345using ::executorch::extension::from_blob;
@@ -48,10 +50,18 @@ using ::executorch::runtime::EValue;
4850
4951using SizesType = executorch::aten::SizesType;
5052
51- // Read a sampled token from the model output tensor [B, 1].
52- // The model performs Gumbel-max sampling on-device and returns a single
53- // float token ID. This function copies it from GPU and casts to uint64.
53+ // Convert a model output tensor to the next sampled token id.
54+ //
55+ // On the CUDA build, the model fuses the sampler in (see sampler.py /
56+ // Qwen35MoE.forward) and returns a single sampled token id as a [B, 1]
57+ // float tensor; we just copy that scalar back from device.
58+ //
59+ // On non-CUDA builds (Metal / MLX / CPU), the model returns raw logits
60+ // of shape [B, T, V] in the model dtype (typically bf16). We sample on
61+ // CPU via the shared `llm::logits_to_token` helper, which accepts a
62+ // temperature (0 = greedy / argmax).
5463static uint64_t read_token (const executorch::aten::Tensor& output) {
64+ #ifdef EXECUTORCH_BUILD_CUDA
5565 const void * ptr = output.const_data_ptr ();
5666
5767 cudaPointerAttributes attrs;
@@ -73,6 +83,14 @@ static uint64_t read_token(const executorch::aten::Tensor& output) {
7383 memcpy (&val, ptr, sizeof (float ));
7484 }
7585 return static_cast <uint64_t >(val);
86+ #else
87+ // logits_to_token handles 2D / 3D logits and Float / Half / BFloat16 /
88+ // UInt16 dtypes. We clamp negative temperatures to 0 (greedy) below.
89+ const float temp = FLAGS_temperature <= 0.0
90+ ? 0 .0f
91+ : static_cast <float >(FLAGS_temperature);
92+ return static_cast <uint64_t >(llm::logits_to_token (output, temp));
93+ #endif
7694}
7795
7896int main (int argc, char ** argv) {
@@ -133,16 +151,23 @@ int main(int argc, char** argv) {
133151 }
134152 auto metadata = metadata_result.get ();
135153
154+ #ifdef EXECUTORCH_BUILD_CUDA
136155 // Set CUDA graph option if requested (must be before load_method)
137156 if (FLAGS_cuda_graph) {
138157 executorch::runtime::BackendOptions<2 > cuda_opts;
139158 cuda_opts.set_option (" enable_cuda_graph_for_method" , " decode" );
140159 executorch::runtime::set_option (" CudaBackend" , cuda_opts.view ());
141160 printf (" CUDA graph enabled for decode method\n " );
142161 }
162+ #else
163+ if (FLAGS_cuda_graph) {
164+ ET_LOG (Info, " --cuda_graph ignored on non-CUDA build" );
165+ }
166+ #endif
143167
144168 printf (" Loading methods...\n " );
145169
170+ #ifdef EXECUTORCH_BUILD_CUDA
146171 // Enable cross-method per-FQN weight sharing in the CUDA backend so that
147172 // prefill and decode (which share KV cache and other mutable buffers /
148173 // weights) avoid duplicate GPU allocations. This is critical for fitting
@@ -170,6 +195,7 @@ int main(int argc, char** argv) {
170195 return 1 ;
171196 }
172197 }
198+ #endif
173199
174200 auto err = module ->load_method (" prefill" );
175201 if (err != Error::Ok) {
@@ -224,12 +250,16 @@ int main(int argc, char** argv) {
224250 // ---------------------------------------------------------------
225251 auto S = [](int64_t v) -> SizesType { return static_cast <SizesType>(v); };
226252
227- // Use a very small temperature for greedy to avoid division by zero
228- // while keeping the Gumbel noise negligible relative to logit differences.
253+ #ifdef EXECUTORCH_BUILD_CUDA
254+ // CUDA build: model fuses the sampler in. Pass a temperature tensor as
255+ // a third input. Use a very small temperature for greedy to avoid
256+ // division by zero while keeping the Gumbel noise negligible relative
257+ // to logit differences.
229258 float temp_val =
230259 FLAGS_temperature <= 0.0 ? 1e-6f : static_cast <float >(FLAGS_temperature);
231260 auto temp_tensor =
232261 from_blob (&temp_val, {1 }, executorch::aten::ScalarType::Float);
262+ #endif
233263
234264 // ---------------------------------------------------------------
235265 // Prefill
@@ -260,7 +290,9 @@ int main(int argc, char** argv) {
260290 std::vector<EValue> prefill_inputs;
261291 prefill_inputs.push_back (tokens_tensor);
262292 prefill_inputs.push_back (pos_tensor);
293+ #ifdef EXECUTORCH_BUILD_CUDA
263294 prefill_inputs.push_back (temp_tensor);
295+ #endif
264296
265297 auto prefill_result = module ->execute (run_method, prefill_inputs);
266298 if (prefill_result.error () != Error::Ok) {
@@ -308,7 +340,9 @@ int main(int argc, char** argv) {
308340 std::vector<EValue> decode_inputs;
309341 decode_inputs.push_back (EValue (decode_tokens));
310342 decode_inputs.push_back (EValue (decode_pos));
343+ #ifdef EXECUTORCH_BUILD_CUDA
311344 decode_inputs.push_back (EValue (temp_tensor));
345+ #endif
312346
313347 auto decode_result = module ->execute (" decode" , decode_inputs);
314348 if (decode_result.error () != Error::Ok) {
0 commit comments