Skip to content

Commit e79f101

Browse files
committed
Add BOS/EOS token handling to C++ runner
Gemma's HuggingFace tokenizer does not auto-prepend BOS. Without it the model's logits collapse. Add --bos_id (default 2) to prepend and --eos_id (default 1) as a fallback stop token.
1 parent 0ea6812 commit e79f101

1 file changed

Lines changed: 22 additions & 0 deletions

File tree

examples/models/gemma4_31b/main.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,22 @@
3030
#include <string>
3131
#include <vector>
3232

33+
#include <executorch/runtime/platform/platform.h>
34+
#include <executorch/runtime/platform/types.h>
35+
/**
 * Replacement for the ExecuTorch platform log hook that keeps the runner's
 * console output quiet: only error and fatal messages are written to stderr,
 * formatted as "<level> [<file>:<line>] <message>".
 *
 * @param timestamp Unused.
 * @param level     Severity of the message, carried as a printable letter.
 * @param filename  Source file that emitted the message.
 * @param function  Unused.
 * @param line      Source line that emitted the message.
 * @param message   NUL-terminated message text.
 * @param length    Unused (message is printed via %s).
 */
extern "C" void et_pal_emit_log_message(
    ET_UNUSED et_timestamp_t timestamp,
    et_pal_log_level_t level,
    const char* filename,
    ET_UNUSED const char* function,
    size_t line,
    const char* message,
    ET_UNUSED size_t length) {
  // The previous filter was `level < 'W'`. Severity letters are not ordered by
  // importance in ASCII ('D' < 'E' < 'F' < 'I'), and 'W' is greater than every
  // one of them, so that test discarded errors and fatals as well.
  // NOTE(review): assumes et_pal_log_level_t uses 'D'/'I'/'E'/'F'/'?' as
  // declared in executorch/runtime/platform/types.h - confirm.
  const char severity = static_cast<char>(level);
  const bool is_error_or_fatal = (severity == 'E') || (severity == 'F');
  if (!is_error_or_fatal) {
    return;
  }
  fprintf(stderr, "%c [%s:%zu] %s\n", severity, filename, line, message);
}
48+
3349
#ifdef EXECUTORCH_BUILD_CUDA
3450
#include <cuda_runtime.h>
3551
#endif
@@ -44,6 +60,8 @@ DEFINE_string(
4460
"Path to file containing prompt text (overrides --prompt).");
4561
DEFINE_double(temperature, 0.8, "Sampling temperature (0 = near-greedy).");
4662
DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate.");
63+
// Token-id flags: bos_id is inserted at the front of the encoded prompt and
// eos_id is added to the set of stop-token ids.
// NOTE(review): both are int32 flags that are static_cast to uint64_t where
// they are consumed, so a negative value wraps to a very large id instead of
// being rejected.
DEFINE_int32(bos_id, 2, "BOS token id to prepend (Gemma convention: 2).");
64+
DEFINE_int32(eos_id, 1, "EOS token id (Gemma convention: 1).");
4765
DEFINE_bool(
4866
cuda_graph,
4967
false,
@@ -196,6 +214,7 @@ int main(int argc, char** argv) {
196214
#endif
197215

198216
auto eos_ids = llm::get_eos_ids(tokenizer.get(), module.get());
217+
eos_ids.insert(static_cast<uint64_t>(FLAGS_eos_id));
199218

200219
// Read prompt from file or flag
201220
std::string prompt_text = FLAGS_prompt;
@@ -217,6 +236,9 @@ int main(int argc, char** argv) {
217236
return 1;
218237
}
219238
auto prompt_tokens = std::move(*encode_result);
239+
// Gemma models require BOS at the start of the sequence.
240+
prompt_tokens.insert(
241+
prompt_tokens.begin(), static_cast<uint64_t>(FLAGS_bos_id));
220242
int64_t num_prompt_tokens = static_cast<int64_t>(prompt_tokens.size());
221243
printf("Prompt tokens: %" PRId64 "\n", num_prompt_tokens);
222244
stats.num_prompt_tokens = num_prompt_tokens;

0 commit comments

Comments
 (0)