Skip to content

Commit eff4294

Browse files
committed
Add top-p and top-k args
1 parent 6f411af commit eff4294

1 file changed

Lines changed: 29 additions & 1 deletion

File tree

examples/models/qwen3_5_moe/main.cpp

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ DEFINE_string(
3737
DEFINE_double(temperature, 0.8, "Sampling temperature (0 = greedy).");
3838
DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate.");
3939
DEFINE_bool(cuda_graph, false, "Enable CUDA graph for decode method.");
40+
DEFINE_int64(
41+
top_k,
42+
-1,
43+
"Top-k sampling cutoff (<=0 = no-op default of vocab_size, keeps all tokens).");
44+
DEFINE_double(
45+
top_p,
46+
1.0,
47+
"Top-p (nucleus) sampling threshold. 1.0 = no-op (nucleus covers all tokens).");
4048

4149
namespace llm = ::executorch::extension::llm;
4250
using ::executorch::extension::from_blob;
@@ -187,7 +195,7 @@ int main(int argc, char** argv) {
187195
stats.inference_start_ms = llm::time_in_ms();
188196

189197
// ---------------------------------------------------------------
190-
// Temperature tensor (shared between prefill and decode)
198+
// Sampling tensors (shared between prefill and decode)
191199
// ---------------------------------------------------------------
192200
auto S = [](int64_t v) -> SizesType { return static_cast<SizesType>(v); };
193201

@@ -198,6 +206,22 @@ int main(int argc, char** argv) {
198206
auto temp_tensor =
199207
from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float);
200208

209+
// top_k / top_p are 0-D scalar tensors matching the export-time signature
210+
// (see examples/models/qwen3_5_moe/export.py). The default flag values
211+
// (top_k = vocab_size, top_p = 1.0) are mathematical no-ops: the sort+
212+
// scatter subgraph still runs (it was traced into the graph at export
213+
// time), but produces all-False filter masks so logits pass through
214+
// unchanged. Override at runtime to enable real filtering.
215+
int64_t vocab_size = metadata.count(llm::kVocabSize)
216+
? metadata[llm::kVocabSize]
217+
: static_cast<int64_t>(tokenizer->vocab_size());
218+
int64_t top_k_val = (FLAGS_top_k <= 0) ? vocab_size : FLAGS_top_k;
219+
float top_p_val = static_cast<float>(FLAGS_top_p);
220+
auto top_k_tensor =
221+
from_blob(&top_k_val, {}, executorch::aten::ScalarType::Long);
222+
auto top_p_tensor =
223+
from_blob(&top_p_val, {}, executorch::aten::ScalarType::Float);
224+
201225
// ---------------------------------------------------------------
202226
// Prefill
203227
// ---------------------------------------------------------------
@@ -228,6 +252,8 @@ int main(int argc, char** argv) {
228252
prefill_inputs.push_back(tokens_tensor);
229253
prefill_inputs.push_back(pos_tensor);
230254
prefill_inputs.push_back(temp_tensor);
255+
prefill_inputs.push_back(top_k_tensor);
256+
prefill_inputs.push_back(top_p_tensor);
231257

232258
auto prefill_result = module->execute(run_method, prefill_inputs);
233259
if (prefill_result.error() != Error::Ok) {
@@ -276,6 +302,8 @@ int main(int argc, char** argv) {
276302
decode_inputs.push_back(EValue(decode_tokens));
277303
decode_inputs.push_back(EValue(decode_pos));
278304
decode_inputs.push_back(EValue(temp_tensor));
305+
decode_inputs.push_back(EValue(top_k_tensor));
306+
decode_inputs.push_back(EValue(top_p_tensor));
279307

280308
auto decode_result = module->execute("decode", decode_inputs);
281309
if (decode_result.error() != Error::Ok) {

0 commit comments

Comments (0)