Skip to content

Commit 427f846

Browse files
authored
Add ignore_eos option to GenerationConfig
Differential Revision: D91183072 Pull Request resolved: pytorch#16912
1 parent d277202 commit 427f846

7 files changed

Lines changed: 24 additions & 1 deletion

File tree

examples/models/llama/main.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ DEFINE_int32(
6767

6868
DEFINE_bool(warmup, false, "Whether to run a warmup run.");
6969

70+
DEFINE_bool(
71+
ignore_eos,
72+
false,
73+
"Whether to ignore EOS token and continue generating until max_new_tokens is reached.");
74+
7075
DEFINE_string(
7176
etdump_path,
7277
"etdump.in",
@@ -165,6 +170,8 @@ int32_t main(int32_t argc, char** argv) {
165170
executorch::extension::llm::GenerationConfig config{
166171
.temperature = temperature};
167172

173+
config.ignore_eos = FLAGS_ignore_eos;
174+
168175
if (FLAGS_max_new_tokens != -1) {
169176
config.max_new_tokens = FLAGS_max_new_tokens;
170177
} else {

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ void start_runner(
238238
};
239239
executorch::extension::llm::GenerationConfig config{
240240
true,
241+
false,
241242
-1,
242243
false,
243244
FLAGS_seq_len,

examples/qualcomm/oss_scripts/llama/qnn_multimodal_runner.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ void start_multimodal_runner(
301301
// Configure generation
302302
executorch::extension::llm::GenerationConfig config{
303303
true,
304+
false,
304305
-1,
305306
false,
306307
FLAGS_seq_len,

extension/llm/runner/irunner.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ struct GenerationConfig {
2828
// Whether to echo the input prompt in the output
2929
bool echo = true;
3030

31+
// If true, ignore EOS and keep generating until max_new_tokens is reached
32+
bool ignore_eos = false;
33+
3134
// Maximum number of new tokens to generate
3235
// If the max_context_len metadata that's serialized in the .pte file exists,
3336
// then the number of prompt tokens + max_new_tokens won't exceed

extension/llm/runner/multimodal_runner.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,9 @@ Error MultimodalRunner::generate(
194194
"Max new tokens %d is less than or equal to 0",
195195
max_new_tokens);
196196

197+
// Set ignore_eos based on config
198+
text_token_generator_->set_ignore_eos(config.ignore_eos);
199+
197200
// Generate tokens using the text token generator
198201
std::vector<uint64_t> prompt_tokens = {prefill_next_token};
199202
auto generate_result = text_token_generator_->generate(

extension/llm/runner/text_llm_runner.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,9 @@ Error TextLLMRunner::generate(
193193
// start the main loop
194194
prompt_tokens.push_back(cur_token);
195195

196+
// Set ignore_eos based on config
197+
text_token_generator_->set_ignore_eos(config.ignore_eos);
198+
196199
// Generate max_new_tokens - 1 because prefill already generated 1 token.
197200
auto generate_result = text_token_generator_->generate(
198201
prompt_tokens,

extension/llm/runner/text_token_generator.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ class ET_EXPERIMENTAL TextTokenGenerator {
3232
use_kv_cache_(use_kv_cache),
3333
stats_(stats) {}
3434

35+
void set_ignore_eos(bool ignore_eos) {
36+
ignore_eos_ = ignore_eos;
37+
}
38+
3539
virtual ~TextTokenGenerator() = default;
3640

3741
/**
@@ -125,7 +129,7 @@ class ET_EXPERIMENTAL TextTokenGenerator {
125129
}
126130

127131
// data-dependent terminating condition: we have n_eos_ number of EOS
128-
if (eos_ids_->find(cur_token) != eos_ids_->end()) {
132+
if (!ignore_eos_ && eos_ids_->find(cur_token) != eos_ids_->end()) {
129133
printf("\n");
130134
ET_LOG(Info, "\nReached to the end of generation");
131135
break;
@@ -169,6 +173,7 @@ class ET_EXPERIMENTAL TextTokenGenerator {
169173
TextDecoderRunner* text_decoder_runner_;
170174
std::unique_ptr<std::unordered_set<uint64_t>> eos_ids_;
171175
bool use_kv_cache_;
176+
bool ignore_eos_ = false;
172177

173178
// state machine
174179
bool should_stop_ = false;

0 commit comments

Comments
 (0)