@@ -37,6 +37,14 @@ DEFINE_string(
 DEFINE_double(temperature, 0.8, "Sampling temperature (0 = greedy).");
 DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate.");
 DEFINE_bool(cuda_graph, false, "Enable CUDA graph for decode method.");
+DEFINE_int64(
+    top_k,
+    -1,
+    "Top-k sampling cutoff (<=0 uses the no-op default of vocab_size, keeping all tokens).");
+DEFINE_double(
+    top_p,
+    1.0,
+    "Top-p (nucleus) sampling threshold. 1.0 is a no-op (the nucleus covers the full vocabulary).");

 namespace llm = ::executorch::extension::llm;
 using ::executorch::extension::from_blob;
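The two new flags are ordinary gflags definitions, so after `ParseCommandLineFlags` they surface as the `FLAGS_top_k` / `FLAGS_top_p` globals consumed in the hunks below. A minimal standalone sketch of that mechanism (not part of this patch; the binary name in the comment is hypothetical):

```cpp
// Standalone sketch only: shows how the DEFINE_* macros above turn into
// FLAGS_* globals that ParseCommandLineFlags fills in from argv.
#include <cstdio>
#include <gflags/gflags.h>

DEFINE_int64(top_k, -1, "Top-k sampling cutoff (<=0 keeps all tokens).");
DEFINE_double(top_p, 1.0, "Top-p (nucleus) sampling threshold.");

int main(int argc, char** argv) {
  // e.g. ./qwen_runner --top_k=40 --top_p=0.95   (hypothetical binary name)
  gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);
  std::printf(
      "top_k=%lld top_p=%g\n", static_cast<long long>(FLAGS_top_k), FLAGS_top_p);
  return 0;
}
```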
@@ -187,7 +195,7 @@ int main(int argc, char** argv) {
   stats.inference_start_ms = llm::time_in_ms();

   // ---------------------------------------------------------------
-  // Temperature tensor (shared between prefill and decode)
+  // Sampling tensors (shared between prefill and decode)
   // ---------------------------------------------------------------
   auto S = [](int64_t v) -> SizesType { return static_cast<SizesType>(v); };

@@ -198,6 +206,22 @@ int main(int argc, char** argv) {
   auto temp_tensor =
       from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float);

+  // top_k / top_p are 0-D scalar tensors matching the export-time signature
+  // (see examples/models/qwen3_5_moe/export.py). The resolved defaults
+  // (top_k = vocab_size, top_p = 1.0) are mathematical no-ops: the
+  // sort+scatter subgraph still runs (it was traced into the graph at export
+  // time), but it produces all-False filter masks, so the logits pass through
+  // unchanged. Override the flags at runtime to enable real filtering.
+  int64_t vocab_size = metadata.count(llm::kVocabSize)
+      ? metadata[llm::kVocabSize]
+      : static_cast<int64_t>(tokenizer->vocab_size());
+  int64_t top_k_val = (FLAGS_top_k <= 0) ? vocab_size : FLAGS_top_k;
+  float top_p_val = static_cast<float>(FLAGS_top_p);
+  auto top_k_tensor =
+      from_blob(&top_k_val, {}, executorch::aten::ScalarType::Long);
+  auto top_p_tensor =
+      from_blob(&top_p_val, {}, executorch::aten::ScalarType::Float);
+
201225 // ---------------------------------------------------------------
202226 // Prefill
203227 // ---------------------------------------------------------------
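To make the no-op comment in the hunk above concrete, the following freestanding sketch applies textbook top-k / nucleus filtering to a small logits vector. It is an illustration under those standard definitions, not the sort+scatter subgraph traced by examples/models/qwen3_5_moe/export.py: with k equal to the vocabulary size and p = 1.0 nothing is masked, while smaller values send the filtered tail to -inf.

```cpp
// Freestanding illustration only; the real filtering lives in the exported
// graph. With k == vocab_size nothing falls below the k-th largest logit, and
// with p == 1.0 the nucleus spans the whole distribution, so no logit is
// masked and the values pass through unchanged.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <vector>

std::vector<float> filter_logits(std::vector<float> logits, int64_t k, float p) {
  std::vector<float> sorted = logits;
  std::sort(sorted.begin(), sorted.end(), std::greater<float>());
  const int64_t n = static_cast<int64_t>(sorted.size());

  // Top-k: anything strictly below the k-th largest logit is filtered.
  const int64_t kk = std::max<int64_t>(1, std::min<int64_t>(k, n));
  const float kth = sorted[static_cast<size_t>(kk) - 1];

  // Top-p: walk the sorted softmax until cumulative mass reaches p; logits
  // below the one that crosses the threshold are filtered.
  float denom = 0.f;
  for (float l : sorted) denom += std::exp(l - sorted[0]);
  float cum = 0.f;
  float cutoff = sorted.back();
  for (float l : sorted) {
    cum += std::exp(l - sorted[0]) / denom;
    if (cum >= p) { cutoff = l; break; }
  }

  for (float& l : logits) {
    if (l < kth || l < cutoff) l = -INFINITY;  // filtered out
  }
  return logits;
}

int main() {
  std::vector<float> logits = {2.0f, 1.0f, 0.5f, -1.0f};
  // Runner defaults resolve to k == vocab_size, p == 1.0: nothing is masked.
  for (float l : filter_logits(logits, /*k=*/4, /*p=*/1.0f)) std::printf("%g ", l);
  std::printf("\n");
  // Real filtering, e.g. --top_k=2 --top_p=0.9: the tail becomes -inf.
  for (float l : filter_logits(logits, /*k=*/2, /*p=*/0.9f)) std::printf("%g ", l);
  std::printf("\n");
  return 0;
}
```

Because the filtering is baked into the exported graph, the runner side only has to feed these scalar tensors; changing the flags does not require any host-side resampling logic.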
@@ -228,6 +252,8 @@ int main(int argc, char** argv) {
   prefill_inputs.push_back(tokens_tensor);
   prefill_inputs.push_back(pos_tensor);
   prefill_inputs.push_back(temp_tensor);
+  prefill_inputs.push_back(top_k_tensor);
+  prefill_inputs.push_back(top_p_tensor);

   auto prefill_result = module->execute(run_method, prefill_inputs);
   if (prefill_result.error() != Error::Ok) {
@@ -276,6 +302,8 @@ int main(int argc, char** argv) {
     decode_inputs.push_back(EValue(decode_tokens));
     decode_inputs.push_back(EValue(decode_pos));
     decode_inputs.push_back(EValue(temp_tensor));
+    decode_inputs.push_back(EValue(top_k_tensor));
+    decode_inputs.push_back(EValue(top_p_tensor));

     auto decode_result = module->execute("decode", decode_inputs);
     if (decode_result.error() != Error::Ok) {