#include <gflags/gflags.h>

#include <executorch/extension/llm/runner/text_llm_runner.h>
#include <executorch/extension/module/module.h>
#include <executorch/runtime/platform/log.h>
#include <pytorch/tokenizers/hf_tokenizer.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

1819DEFINE_string (model_path, " " , " Model .pte file path." );
1920DEFINE_string (data_path, " " , " Data file (.ptd) for CUDA backend." );
@@ -23,7 +24,6 @@ DEFINE_double(temperature, 0.8, "Sampling temperature (0 = greedy).");
2324DEFINE_int32 (max_new_tokens, 128 , " Maximum tokens to generate." );
2425
2526namespace llm = ::executorch::extension::llm;
26- using ::executorch::runtime::Error;
2727
2828int main (int argc, char ** argv) {
2929 gflags::ParseCommandLineFlags (&argc, &argv, true );
@@ -37,6 +37,11 @@ int main(int argc, char** argv) {
3737 return 1 ;
3838 }
3939
40+ std::vector<std::string> data_files;
41+ if (!FLAGS_data_path.empty ()) {
42+ data_files.push_back (FLAGS_data_path);
43+ }
44+
4045 // Load tokenizer
4146 auto tokenizer = std::make_unique<tokenizers::HFTokenizer>();
4247 auto tok_status = tokenizer->load (FLAGS_tokenizer_path);
@@ -48,37 +53,23 @@ int main(int argc, char** argv) {
4853 return 1 ;
4954 }
5055
51- // Single-method runner: "forward" handles both prefill (T>1) and decode (T=1)
52- // via torch.cond dispatch inside the model.
53- fprintf (stderr, " Loading model from %s...\n " , FLAGS_model_path.c_str ());
54- std::optional<const std::string> data_path =
55- FLAGS_data_path.empty () ? std::nullopt
56- : std::optional<const std::string>(FLAGS_data_path);
56+ // Create LLM runner
5757 auto runner = llm::create_text_llm_runner (
58- FLAGS_model_path,
59- std::move (tokenizer),
60- data_path,
61- FLAGS_temperature);
62- fprintf (stderr, " Runner created successfully\n " );
58+ FLAGS_model_path, std::move (tokenizer), data_files, FLAGS_temperature);
59+
60+ if (runner == nullptr ) {
61+ ET_LOG (Error, " Failed to create runner" );
62+ return 1 ;
63+ }
6364
6465 // Generate
6566 llm::GenerationConfig config;
6667 config.temperature = FLAGS_temperature;
6768 config.max_new_tokens = FLAGS_max_new_tokens;
6869
69- fprintf (stderr, " Starting generation with prompt: %s\n " , FLAGS_prompt.c_str ());
70- try {
71- auto error = runner->generate (FLAGS_prompt.c_str (), config);
72- if (error != Error::Ok) {
73- fprintf (stderr, " Generation failed with error code: %d\n " , static_cast <int >(error));
74- return 1 ;
75- }
76- fprintf (stderr, " Generation completed successfully\n " );
77- } catch (const std::exception& e) {
78- fprintf (stderr, " Exception during generation: %s\n " , e.what ());
79- return 1 ;
80- } catch (...) {
81- fprintf (stderr, " Unknown exception during generation\n " );
70+ auto error = runner->generate (FLAGS_prompt.c_str (), config);
71+ if (error != executorch::runtime::Error::Ok) {
72+ ET_LOG (Error, " Generation failed" );
8273 return 1 ;
8374 }
8475
// (end of file; stray web-page residue "0 commit comments" removed)