3030#include < string>
3131#include < vector>
3232
33+ #include < executorch/runtime/platform/platform.h>
34+ #include < executorch/runtime/platform/types.h>
35+ extern " C" void et_pal_emit_log_message (
36+ ET_UNUSED et_timestamp_t timestamp,
37+ et_pal_log_level_t level,
38+ const char * filename,
39+ ET_UNUSED const char * function,
40+ size_t line,
41+ const char * message,
42+ ET_UNUSED size_t length) {
43+ if (level < ' W' ) {
44+ return ;
45+ }
46+ fprintf (stderr, " %c [%s:%zu] %s\n " , (char )level, filename, line, message);
47+ }
48+
3349#ifdef EXECUTORCH_BUILD_CUDA
3450#include < cuda_runtime.h>
3551#endif
@@ -44,6 +60,8 @@ DEFINE_string(
4460 " Path to file containing prompt text (overrides --prompt)." );
// Command-line flags for generation (DEFINE_* flag macros; each is read
// elsewhere in this file through its FLAGS_<name> variable).
//   temperature     - sampling temperature, default 0.8.
//   max_new_tokens  - generation length cap, default 128.
//   bos_id          - token id that main() inserts at the front of the encoded
//                     prompt, default 2.
//   eos_id          - token id that main() adds to the set of EOS ids,
//                     default 1.
// NOTE(review): bos_id/eos_id are int32 flags that main() casts to uint64_t;
// a negative value would wrap to a huge id - confirm whether validation is
// wanted.
4561DEFINE_double (temperature, 0.8 , " Sampling temperature (0 = near-greedy)." );
4662DEFINE_int32 (max_new_tokens, 128 , " Maximum tokens to generate." );
63+ DEFINE_int32 (bos_id, 2 , " BOS token id to prepend (Gemma convention: 2)." );
64+ DEFINE_int32 (eos_id, 1 , " EOS token id (Gemma convention: 1)." );
4765DEFINE_bool (
4866 cuda_graph,
4967 false ,
@@ -196,6 +214,7 @@ int main(int argc, char** argv) {
196214#endif
197215
198216 auto eos_ids = llm::get_eos_ids (tokenizer.get (), module .get ());
217+ eos_ids.insert (static_cast <uint64_t >(FLAGS_eos_id));
199218
200219 // Read prompt from file or flag
201220 std::string prompt_text = FLAGS_prompt;
@@ -217,6 +236,9 @@ int main(int argc, char** argv) {
217236 return 1 ;
218237 }
219238 auto prompt_tokens = std::move (*encode_result);
239+ // Gemma models require BOS at the start of the sequence.
240+ prompt_tokens.insert (
241+ prompt_tokens.begin (), static_cast <uint64_t >(FLAGS_bos_id));
220242 int64_t num_prompt_tokens = static_cast <int64_t >(prompt_tokens.size ());
221243 printf (" Prompt tokens: %" PRId64 " \n " , num_prompt_tokens);
222244 stats.num_prompt_tokens = num_prompt_tokens;
0 commit comments