Commit be3e950

add model temperature input (CUDA backend only)

1 parent 08777af · commit be3e950

1 file changed: examples/models/qwen3_5_moe/main.cpp (40 additions, 6 deletions)
@@ -25,6 +25,8 @@
 
 #ifdef EXECUTORCH_BUILD_CUDA
 #include <cuda_runtime.h>
+#else
+#include <executorch/extension/llm/sampler/util.h>
 #endif
 
 DEFINE_string(model_path, "", "Model .pte file path.");
@@ -37,7 +39,7 @@ DEFINE_string(
     "Path to file containing prompt text (overrides --prompt).");
 DEFINE_double(temperature, 0.8, "Sampling temperature (0 = greedy).");
 DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate.");
-DEFINE_bool(cuda_graph, false, "Enable CUDA graph for decode method.");
+DEFINE_bool(cuda_graph, false, "Enable CUDA graph for decode method. CUDA only.");
 
 namespace llm = ::executorch::extension::llm;
 using ::executorch::extension::from_blob;
@@ -48,10 +50,18 @@ using ::executorch::runtime::EValue;
 
 using SizesType = executorch::aten::SizesType;
 
-// Read a sampled token from the model output tensor [B, 1].
-// The model performs Gumbel-max sampling on-device and returns a single
-// float token ID. This function copies it from GPU and casts to uint64.
+// Convert a model output tensor to the next sampled token id.
+//
+// On the CUDA build, the model fuses the sampler in (see sampler.py /
+// Qwen35MoE.forward) and returns a single sampled token id as a [B, 1]
+// float tensor; we just copy that scalar back from device.
+//
+// On non-CUDA builds (Metal / MLX / CPU), the model returns raw logits
+// of shape [B, T, V] in the model dtype (typically bf16). We sample on
+// CPU via the shared `llm::logits_to_token` helper, which accepts a
+// temperature (0 = greedy / argmax).
 static uint64_t read_token(const executorch::aten::Tensor& output) {
+#ifdef EXECUTORCH_BUILD_CUDA
   const void* ptr = output.const_data_ptr();
 
   cudaPointerAttributes attrs;
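The hunk above elides the middle of read_token's CUDA path, but its shape is clear from the surrounding context: check where the output pointer lives, then copy one float back. A minimal sketch of that pattern (not the commit's exact code; error handling simplified):

#include <cuda_runtime.h>

#include <cstdint>
#include <cstring>

// Sketch: read one sampled float token id back from a [B, 1] output whose
// data may live on the GPU. Mirrors the pattern read_token uses above.
static uint64_t copy_scalar_token(const void* ptr) {
  float val = 0.0f;
  cudaPointerAttributes attrs;
  if (cudaPointerGetAttributes(&attrs, ptr) == cudaSuccess &&
      attrs.type == cudaMemoryTypeDevice) {
    // Device-resident output: synchronous scalar copy back to the host.
    cudaMemcpy(&val, ptr, sizeof(float), cudaMemcpyDeviceToHost);
  } else {
    // Host-accessible memory (host, pinned, or managed): plain memcpy.
    memcpy(&val, ptr, sizeof(float));
  }
  return static_cast<uint64_t>(val);
}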
@@ -73,6 +83,14 @@ static uint64_t read_token(const executorch::aten::Tensor& output) {
     memcpy(&val, ptr, sizeof(float));
   }
   return static_cast<uint64_t>(val);
+#else
+  // logits_to_token handles 2D / 3D logits and Float / Half / BFloat16 /
+  // UInt16 dtypes. Negative temperatures are clamped to 0 (greedy).
+  const float temp = FLAGS_temperature <= 0.0
+      ? 0.0f
+      : static_cast<float>(FLAGS_temperature);
+  return static_cast<uint64_t>(llm::logits_to_token(output, temp));
+#endif
 }
 
 int main(int argc, char** argv) {
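For intuition about the non-CUDA branch: at temperature 0, llm::logits_to_token reduces to an argmax over the vocabulary. A simplified stand-in (not the ExecuTorch helper, which per the comment above also handles [B, T, V] shapes and Half / BFloat16 data) would be:

#include <cstddef>
#include <cstdint>

// Simplified stand-in: greedy (temperature 0) token selection over one
// row of float logits of length vocab_size.
static uint64_t greedy_token(const float* logits, size_t vocab_size) {
  size_t best = 0;
  for (size_t i = 1; i < vocab_size; ++i) {
    if (logits[i] > logits[best]) {
      best = i;  // highest logit wins; no noise, no division by T
    }
  }
  return static_cast<uint64_t>(best);
}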
@@ -133,16 +151,23 @@ int main(int argc, char** argv) {
   }
   auto metadata = metadata_result.get();
 
+#ifdef EXECUTORCH_BUILD_CUDA
   // Set CUDA graph option if requested (must be before load_method)
   if (FLAGS_cuda_graph) {
     executorch::runtime::BackendOptions<2> cuda_opts;
     cuda_opts.set_option("enable_cuda_graph_for_method", "decode");
     executorch::runtime::set_option("CudaBackend", cuda_opts.view());
     printf("CUDA graph enabled for decode method\n");
   }
+#else
+  if (FLAGS_cuda_graph) {
+    ET_LOG(Info, "--cuda_graph ignored on non-CUDA build");
+  }
+#endif
 
   printf("Loading methods...\n");
 
+#ifdef EXECUTORCH_BUILD_CUDA
   // Enable cross-method per-FQN weight sharing in the CUDA backend so that
   // prefill and decode (which share KV cache and other mutable buffers /
   // weights) avoid duplicate GPU allocations. This is critical for fitting
@@ -170,6 +195,7 @@ int main(int argc, char** argv) {
       return 1;
     }
   }
+#endif
 
   auto err = module->load_method("prefill");
   if (err != Error::Ok) {
@@ -224,12 +250,16 @@ int main(int argc, char** argv) {
   // ---------------------------------------------------------------
   auto S = [](int64_t v) -> SizesType { return static_cast<SizesType>(v); };
 
-  // Use a very small temperature for greedy to avoid division by zero
-  // while keeping the Gumbel noise negligible relative to logit differences.
+#ifdef EXECUTORCH_BUILD_CUDA
+  // CUDA build: model fuses the sampler in. Pass a temperature tensor as
+  // a third input. Use a very small temperature for greedy to avoid
+  // division by zero while keeping the Gumbel noise negligible relative
+  // to logit differences.
   float temp_val =
       FLAGS_temperature <= 0.0 ? 1e-6f : static_cast<float>(FLAGS_temperature);
   auto temp_tensor =
       from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float);
+#endif
 
   // ---------------------------------------------------------------
   // Prefill
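The "Gumbel noise" comment refers to the Gumbel-max trick: drawing from softmax(logits / T) is equivalent to computing argmax_i(logits_i / T + g_i) with i.i.d. noise g_i = -log(-log(u_i)), u_i ~ Uniform(0, 1). The noise is O(1) while logits / T grows without bound as T -> 0, so a tiny T such as 1e-6 makes the argmax effectively greedy without dividing by zero. An illustrative CPU version (the fused on-device sampler is assumed to do the equivalent):

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

// Illustrative Gumbel-max sampler: argmax(logits / T + g) with
// g ~ Gumbel(0, 1) is distributed as softmax(logits / T).
static uint64_t gumbel_max_sample(const std::vector<float>& logits,
                                  float temperature,
                                  std::mt19937& rng) {
  std::uniform_real_distribution<float> uniform(1e-20f, 1.0f);
  size_t best = 0;
  float best_score = -INFINITY;
  for (size_t i = 0; i < logits.size(); ++i) {
    const float g = -std::log(-std::log(uniform(rng)));  // Gumbel(0, 1)
    const float score = logits[i] / temperature + g;
    if (score > best_score) {
      best_score = score;
      best = i;
    }
  }
  return static_cast<uint64_t>(best);
}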
@@ -260,7 +290,9 @@ int main(int argc, char** argv) {
   std::vector<EValue> prefill_inputs;
   prefill_inputs.push_back(tokens_tensor);
   prefill_inputs.push_back(pos_tensor);
+#ifdef EXECUTORCH_BUILD_CUDA
   prefill_inputs.push_back(temp_tensor);
+#endif
 
   auto prefill_result = module->execute(run_method, prefill_inputs);
   if (prefill_result.error() != Error::Ok) {
@@ -308,7 +340,9 @@ int main(int argc, char** argv) {
   std::vector<EValue> decode_inputs;
   decode_inputs.push_back(EValue(decode_tokens));
   decode_inputs.push_back(EValue(decode_pos));
+#ifdef EXECUTORCH_BUILD_CUDA
   decode_inputs.push_back(EValue(temp_tensor));
+#endif
 
   auto decode_result = module->execute("decode", decode_inputs);
   if (decode_result.error() != Error::Ok) {
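Taken together with read_token above, the generation loop these hunks imply is roughly the following sketch. Identifiers follow the diff; the loop bookkeeping and the result-unpacking calls are assumptions, not the commit's exact code:

// Sketch of the decode loop: one execute() per generated token, with the
// temperature tensor appended only on the CUDA build.
for (int step = 0; step < FLAGS_max_new_tokens; ++step) {
  std::vector<EValue> decode_inputs;
  decode_inputs.push_back(EValue(decode_tokens));  // current token, [1, 1]
  decode_inputs.push_back(EValue(decode_pos));     // current position
#ifdef EXECUTORCH_BUILD_CUDA
  decode_inputs.push_back(EValue(temp_tensor));    // fused-sampler input
#endif
  auto decode_result = module->execute("decode", decode_inputs);
  if (decode_result.error() != Error::Ok) {
    break;
  }
  // read_token handles both builds: scalar copy-back on CUDA,
  // CPU sampling over raw logits otherwise.
  uint64_t token = read_token(decode_result.get()[0].toTensor());
  // ... feed `token` back into decode_tokens and advance decode_pos ...
}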
