
Commit 1d4608b

add model temperature input (CUDA backend only)
1 parent: be3e950

1 file changed: 6 additions & 4 deletions

examples/models/qwen3_5_moe/main.cpp
@@ -39,7 +39,10 @@ DEFINE_string(
     "Path to file containing prompt text (overrides --prompt).");
 DEFINE_double(temperature, 0.8, "Sampling temperature (0 = greedy).");
 DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate.");
-DEFINE_bool(cuda_graph, false, "Enable CUDA graph for decode method. CUDA only.");
+DEFINE_bool(
+    cuda_graph,
+    false,
+    "Enable CUDA graph for decode method. CUDA only.");
 
 namespace llm = ::executorch::extension::llm;
 using ::executorch::extension::from_blob;
@@ -86,9 +89,8 @@ static uint64_t read_token(const executorch::aten::Tensor& output) {
 #else
   // logits_to_token handles 2D / 3D logits and Float / Half / BFloat16 /
   // UInt16 dtypes. Negative temperatures are clamped to 0 (greedy).
-  const float temp = FLAGS_temperature <= 0.0
-      ? 0.0f
-      : static_cast<float>(FLAGS_temperature);
+  const float temp =
+      FLAGS_temperature <= 0.0 ? 0.0f : static_cast<float>(FLAGS_temperature);
   return static_cast<uint64_t>(llm::logits_to_token(output, temp));
 #endif
 }
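For readers skimming the diff, here is a minimal standalone sketch of the flag handling it touches. This is an illustration, not the actual file: main(), the printf, and the binary name in the usage line below are hypothetical; only the DEFINE_* declarations and the clamp expression come from the diff.

// Sketch: the gflags definitions and temperature clamp from the diff,
// compiled in isolation (assumes gflags is installed).
#include <cstdio>

#include <gflags/gflags.h>

DEFINE_double(temperature, 0.8, "Sampling temperature (0 = greedy).");
DEFINE_bool(
    cuda_graph,
    false,
    "Enable CUDA graph for decode method. CUDA only.");

int main(int argc, char** argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);
  // Same clamp as the diff: any temperature <= 0 means greedy decoding.
  const float temp =
      FLAGS_temperature <= 0.0 ? 0.0f : static_cast<float>(FLAGS_temperature);
  // The real main.cpp passes `temp` to llm::logits_to_token; this sketch
  // just prints the effective values.
  std::printf("temperature=%.2f cuda_graph=%d\n", temp, FLAGS_cuda_graph);
  return 0;
}

Invoked as, e.g., ./flags_sketch --temperature=-1 --cuda_graph (hypothetical binary name), this would print temperature=0.00 cuda_graph=1, showing the negative temperature falling back to greedy decoding.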
