 #include <cuda_runtime.h>
 #endif

-DEFINE_string(model_path, "", "Path to model .pte.");
-DEFINE_string(data_path, "", "Path to model .ptd (CUDA tensor data).");
+DEFINE_string(model_path, "", "Model .pte file path.");
+DEFINE_string(data_path, "", "Data file (.ptd) for CUDA backend.");
 DEFINE_string(tokenizer_path, "", "HuggingFace tokenizer.json path.");
 DEFINE_string(prompt, "Hello", "Prompt text.");
 DEFINE_string(
     prompt_file,
     "",
-    "Optional path to a file with the prompt text (overrides --prompt).");
+    "Path to file containing prompt text (overrides --prompt).");
 DEFINE_double(temperature, 0.8, "Sampling temperature (0 = near-greedy).");
 DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate.");
 DEFINE_bool(
     cuda_graph,
     false,
-    "Enable CUDA graph capture for the decode method.");
+    "Enable CUDA graph capture for the decode method. CUDA only.");

 namespace llm = ::executorch::extension::llm;
 using ::executorch::extension::from_blob;
@@ -57,8 +57,6 @@ using ::executorch::runtime::EValue;

 using SizesType = executorch::aten::SizesType;

-// The model performs sampling on-device and returns a [B, 1] float tensor
-// holding a token ID. Copy it to host and convert to uint64.
 static uint64_t read_token(const executorch::aten::Tensor& output) {
   const void* ptr = output.const_data_ptr();
   float val = 0.0f;
@@ -135,12 +133,14 @@ int main(int argc, char** argv) {
       /*temp_allocator=*/nullptr,
       /*share_memory_arenas=*/true);

+  // Get metadata
   auto metadata_result = llm::get_llm_metadata(tokenizer.get(), module.get());
   if (metadata_result.error() != Error::Ok) {
     ET_LOG(Error, "Failed to read model metadata");
     return 1;
   }

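+  // CUDA-only: configure CudaBackend options (CUDA graph capture, weight
+  // sharing) before the methods are loaded below.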
+#ifdef EXECUTORCH_BUILD_CUDA
   if (FLAGS_cuda_graph) {
     executorch::runtime::BackendOptions<2> cuda_opts;
     cuda_opts.set_option("enable_cuda_graph_for_method", "decode");
@@ -154,14 +154,30 @@ int main(int argc, char** argv) {
   // load_method.
   {
     executorch::runtime::BackendOptions<1> backend_options;
-    if (backend_options.set_option("weight_sharing_across_methods", true) !=
-            Error::Ok ||
-        executorch::runtime::set_option(
-            "CudaBackend", backend_options.view()) != Error::Ok) {
-      ET_LOG(Error, "Failed to enable weight_sharing_across_methods");
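+    // Split the combined check so each failure reports its own error code.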
+    auto set_err =
+        backend_options.set_option("weight_sharing_across_methods", true);
+    if (set_err != Error::Ok) {
+      ET_LOG(
+          Error,
+          "Failed to construct weight_sharing_across_methods option: %d",
+          static_cast<int>(set_err));
+      return 1;
+    }
+    auto opt_err =
+        executorch::runtime::set_option("CudaBackend", backend_options.view());
+    if (opt_err != Error::Ok) {
+      ET_LOG(
+          Error,
+          "Failed to enable weight_sharing_across_methods: %d",
+          static_cast<int>(opt_err));
       return 1;
     }
   }
+#else
+  if (FLAGS_cuda_graph) {
+    ET_LOG(Info, "--cuda_graph ignored on non-CUDA build");
+  }
+#endif

   printf("Loading methods...\n");
   if (module->load_method("prefill") != Error::Ok) {
@@ -181,6 +197,7 @@ int main(int argc, char** argv) {

   auto eos_ids = llm::get_eos_ids(tokenizer.get(), module.get());

+  // Read prompt from file or flag
   std::string prompt_text = FLAGS_prompt;
   if (!FLAGS_prompt_file.empty()) {
     std::ifstream f(FLAGS_prompt_file);
@@ -189,10 +206,11 @@ int main(int argc, char** argv) {
           Error, "Failed to open prompt file: %s", FLAGS_prompt_file.c_str());
       return 1;
     }
-    prompt_text.assign(
+    prompt_text = std::string(
         (std::istreambuf_iterator<char>(f)), std::istreambuf_iterator<char>());
   }

+  // Encode prompt
   auto encode_result = tokenizer->encode(prompt_text);
   if (!encode_result.ok()) {
     ET_LOG(Error, "Failed to encode prompt");
@@ -207,49 +225,66 @@ int main(int argc, char** argv) {

   auto S = [](int64_t v) -> SizesType { return static_cast<SizesType>(v); };

-  // Temperature: clamp 0 to a tiny epsilon so the divide in the exported
-  // sampler stays well-defined. Gumbel noise then becomes negligible
-  // relative to logit gaps and we get effectively-greedy sampling.
+#ifdef EXECUTORCH_BUILD_CUDA
+  // CUDA build: model fuses the sampler. Pass temperature as a third input.
   float temp_val =
       FLAGS_temperature <= 0.0 ? 1e-6f : static_cast<float>(FLAGS_temperature);
   auto temp_tensor =
       from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float);
+#endif

   // ---------------------------------------------------------------
-  // Prefill
+  // Prefill (chunked to respect ring-buffer KV cache limit)
   // ---------------------------------------------------------------
-  std::string run_method = "prefill";
-  if (num_prompt_tokens == 1) {
-    // prefill was exported with min seq_len=2; decode handles T==1.
-    run_method = "decode";
+  // Sliding layers use a ring buffer sized to 2×sliding_window. A single
+  // prefill call must not exceed this size, otherwise index_copy_ with
+  // wrapped indices produces non-deterministic results on CUDA.
+  int64_t max_prefill_chunk = (*metadata_result)[llm::kMaxSeqLen] - 1;
+  {
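+    // If the model exports get_max_prefill_chunk, use its value instead of
+    // the kMaxSeqLen-derived default above.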
+    auto get_result = module->get("get_max_prefill_chunk");
+    if (get_result.ok()) {
+      max_prefill_chunk = get_result->toScalar().to<int64_t>();
+    }
   }

-  std::vector<int64_t> token_data(prompt_tokens.begin(), prompt_tokens.end());
-  std::vector<int64_t> pos_data(num_prompt_tokens);
-  for (int64_t i = 0; i < num_prompt_tokens; i++) {
-    pos_data[i] = i;
-  }
-  auto tokens_tensor = from_blob(
-      token_data.data(),
-      {1, S(num_prompt_tokens)},
-      executorch::aten::ScalarType::Long);
-  auto pos_tensor = from_blob(
-      pos_data.data(),
-      {S(num_prompt_tokens)},
-      executorch::aten::ScalarType::Long);
-
-  std::vector<EValue> prefill_inputs = {
-      EValue(tokens_tensor),
-      EValue(pos_tensor),
-      EValue(temp_tensor),
-  };
-
-  auto prefill_result = module->execute(run_method, prefill_inputs);
-  if (prefill_result.error() != Error::Ok) {
-    ET_LOG(Error, "%s failed", run_method.c_str());
-    return 1;
+  uint64_t cur_token = 0;
+  int64_t prefill_pos = 0;
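+  // Feed the prompt in chunks of at most max_prefill_chunk tokens, advancing
+  // prefill_pos so each chunk lands at the right KV-cache positions.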
+  while (prefill_pos < num_prompt_tokens) {
+    int64_t chunk_len =
+        std::min(num_prompt_tokens - prefill_pos, max_prefill_chunk);
+
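+    // prefill is exported with min seq_len 2, so a single-token chunk goes
+    // through decode instead.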
+    std::string run_method = (chunk_len == 1) ? "decode" : "prefill";
+
+    std::vector<int64_t> token_data(
+        prompt_tokens.begin() + prefill_pos,
+        prompt_tokens.begin() + prefill_pos + chunk_len);
+    std::vector<int64_t> pos_data(chunk_len);
+    for (int64_t i = 0; i < chunk_len; i++) {
+      pos_data[i] = prefill_pos + i;
+    }
+    auto tokens_tensor = from_blob(
+        token_data.data(),
+        {1, S(chunk_len)},
+        executorch::aten::ScalarType::Long);
+    auto pos_tensor = from_blob(
+        pos_data.data(), {S(chunk_len)}, executorch::aten::ScalarType::Long);
+
+    std::vector<EValue> prefill_inputs;
+    prefill_inputs.push_back(EValue(tokens_tensor));
+    prefill_inputs.push_back(EValue(pos_tensor));
+#ifdef EXECUTORCH_BUILD_CUDA
+    prefill_inputs.push_back(EValue(temp_tensor));
+#endif
+
+    auto prefill_result = module->execute(run_method, prefill_inputs);
+    if (prefill_result.error() != Error::Ok) {
+      ET_LOG(
+          Error, "%s failed at pos %" PRId64, run_method.c_str(), prefill_pos);
+      return 1;
+    }
+    cur_token = read_token(prefill_result.get()[0].toTensor());
+    prefill_pos += chunk_len;
   }
-  uint64_t cur_token = read_token(prefill_result.get()[0].toTensor());

   stats.prompt_eval_end_ms = llm::time_in_ms();
   double prefill_ms =
@@ -261,8 +296,9 @@ int main(int argc, char** argv) {
       num_prompt_tokens * 1000.0 / prefill_ms);

 #ifdef EXECUTORCH_BUILD_CUDA
-  // Make prefill's writes to the shared KV cache visible before decode
-  // potentially runs on a different stream.
+  // Synchronize CUDA device to ensure prefill's writes to shared mutable
+  // buffers (KV cache) are visible to the decode method, which may run on
+  // a different CUDA stream.
   cudaDeviceSynchronize();
 #endif

@@ -282,11 +318,12 @@ int main(int argc, char** argv) {
     decode_token_data[0] = static_cast<int64_t>(cur_token);
     decode_pos_data[0] = pos;

-    std::vector<EValue> decode_inputs = {
-        EValue(decode_tokens),
-        EValue(decode_pos),
-        EValue(temp_tensor),
-    };
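+    // The temperature input only exists in the CUDA export, so append it
+    // conditionally.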
+    std::vector<EValue> decode_inputs;
+    decode_inputs.push_back(EValue(decode_tokens));
+    decode_inputs.push_back(EValue(decode_pos));
+#ifdef EXECUTORCH_BUILD_CUDA
+    decode_inputs.push_back(EValue(temp_tensor));
+#endif

     auto decode_result = module->execute("decode", decode_inputs);
     if (decode_result.error() != Error::Ok) {