
Commit 009b11d

kirklandsign authored and facebook-github-bot committed
Fix double-subtraction of pos_ in TextLLMRunner::generate() (#18727)
Summary: When seq_len is set and pos_ > 0 (multi-turn conversations), max_context_len was pre-adjusted by subtracting pos_, but resolve_max_new_tokens then subtracted only num_prompt_tokens rather than the full count of occupied positions. When seq_len is the smaller term in min(seq_len, max_context_len), the pos_ adjustment is discarded by the min, so the runner generates more tokens than seq_len allows.

Fix: use the raw metadata value for max_context_len and pass pos_ (which includes the prompt tokens after prefill) to resolve_max_new_tokens, matching multimodal_runner's behavior.

Differential Revision: D99742232
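For illustration, here is a minimal, self-contained sketch of the arithmetic described above, using the numbers from the regression test added in test_text_llm_runner.cpp below (context length 128, seq_len = 20, a 3-token prompt, and pos_ = 7 carried over from a first turn). The resolve_sketch helper is hypothetical; it only mirrors the min(seq_len, max_context_len) minus occupied-positions behavior the summary attributes to resolve_max_new_tokens:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for GenerationConfig::resolve_max_new_tokens when
// seq_len is set: budget = min(seq_len, max_context_len) minus positions
// already occupied in the KV cache.
int64_t resolve_sketch(int64_t max_context_len, int64_t num_used, int64_t seq_len) {
  return std::min(seq_len, max_context_len) - num_used;
}

int main() {
  const int64_t kMaxContextLen = 128, seq_len = 20;
  const int64_t pos_before_prefill = 7, num_prompt_tokens = 3;

  // Buggy path: max_context_len pre-adjusted by pos_, then only the prompt
  // tokens subtracted -> min(20, 121) - 3 = 17.
  int64_t buggy = resolve_sketch(
      kMaxContextLen - pos_before_prefill, num_prompt_tokens, seq_len);

  // Fixed path: raw context length, and pos_ after prefill (7 + 3 = 10)
  // counts every occupied position -> min(20, 128) - 10 = 10.
  int64_t fixed = resolve_sketch(
      kMaxContextLen, pos_before_prefill + num_prompt_tokens, seq_len);

  std::printf("buggy=%lld fixed=%lld\n", (long long)buggy, (long long)fixed);
  return 0;
}
```

The buggy path lands on 17 because the 7 positions already held by the first turn are subtracted from the context length, where the min against seq_len = 20 simply discards them, instead of from the generation budget itself.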
1 parent 19f7ff2 commit 009b11d

2 files changed

Lines changed: 76 additions & 8 deletions


extension/llm/runner/test/test_text_llm_runner.cpp

Lines changed: 69 additions & 0 deletions
```diff
@@ -549,4 +549,73 @@ TEST_F(RunnerTest, NonKvCacheGenerateCompletesSuccessfully) {
   EXPECT_EQ(step_count, max_new_tokens);
 }
 
+// Test that multi-turn generation with seq_len correctly accounts for pos_.
+// Regression test for a bug where max_context_len was pre-adjusted by pos_,
+// causing resolve_max_new_tokens to under-count occupied positions when
+// seq_len is set.
+TEST_F(RunnerTest, MultiTurnWithSeqLenRespectsPos) {
+  auto tokenizer = createMockTokenizer();
+  auto text_decoder_runner = createMockTextDecoderRunner();
+  auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get());
+
+  ON_CALL(*tokenizer, encode(_, _, _))
+      .WillByDefault([&](const std::string&, int8_t, int8_t) {
+        return ::tokenizers::Result<std::vector<uint64_t>>(
+            std::vector<uint64_t>{1, 2, 3});
+      });
+
+  ON_CALL(*text_prefiller, prefill(_, _))
+      .WillByDefault([&](std::vector<uint64_t>& tokens, int64_t& pos) {
+        pos += tokens.size();
+        return Result<uint64_t>(4);
+      });
+
+  ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true));
+
+  std::unique_ptr<executorch::llm::Stats> stats =
+      std::make_unique<executorch::llm::Stats>();
+  auto text_token_generator = createTextTokenGenerator(
+      tokenizer.get(), text_decoder_runner.get(), stats.get());
+
+  auto module = std::make_unique<MockModule>();
+  auto io_manager =
+      std::make_unique<executorch::extension::llm::IOManager>(*module);
+  TextLLMRunner runner(
+      createDefaultMetadata(), // kMaxContextLen = 128
+      std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()),
+      std::move(module),
+      std::move(text_decoder_runner),
+      std::unique_ptr<::executorch::extension::llm::TextPrefiller>(
+          text_prefiller.release()),
+      std::move(io_manager),
+      std::move(text_token_generator),
+      std::move(stats));
+
+  runner.load();
+
+  // First turn: advance pos_ to 7 (3 prompt + 4 generated)
+  GenerationConfig config1;
+  config1.max_new_tokens = 5; // prefill generates 1, loop generates 4
+  config1.echo = false;
+  Error err1 = runner.generate("first turn", config1);
+  EXPECT_EQ(err1, Error::Ok);
+
+  // Second turn with seq_len=20: pos_ is now 7, prompt adds 3 more → pos_=10
+  // Correct max_new_tokens = min(20, 128) - 10 = 10
+  // Bug would give: min(20, 128-7) - 3 = 17
+  GenerationConfig config2;
+  config2.seq_len = 20;
+  config2.echo = false;
+
+  CallbackCounter counter;
+  Error err2 = runner.generate(
+      "second turn", config2, [&counter](const std::string& token) {
+        counter.callback(token);
+      });
+
+  EXPECT_EQ(err2, Error::Ok);
+  // With correct pos_ accounting: min(20, 128) - 10 = 10 new tokens
+  EXPECT_EQ(counter.getCount(), 10);
+}
+
 } // namespace
```
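As a usage illustration (not part of the commit), the multi-turn pattern this test exercises looks roughly like the fragment below when driving a real runner. The header path, the llm namespace alias, and the GenerationConfig spelling are assumptions inferred from the test; the generate() overload with a token callback matches the one used above:

```cpp
#include <iostream>
#include <string>

// Assumed header location for TextLLMRunner and GenerationConfig.
#include <executorch/extension/llm/runner/text_llm_runner.h>

namespace llm = executorch::extension::llm;

// Sketch of a two-turn conversation sharing one KV cache via the runner's pos_.
void run_two_turns(llm::TextLLMRunner& runner) {
  llm::GenerationConfig turn1;
  turn1.max_new_tokens = 5; // first turn: small explicit budget
  auto err1 = runner.generate("first turn", turn1);
  (void)err1; // error handling elided in this sketch

  // Second turn: with the fix, at most min(seq_len, max_context_len) - pos_
  // new tokens are produced, where pos_ already counts the first turn and
  // this turn's prompt.
  llm::GenerationConfig turn2;
  turn2.seq_len = 20;
  auto err2 = runner.generate("second turn", turn2, [](const std::string& piece) {
    std::cout << piece;
  });
  (void)err2;
}
```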

extension/llm/runner/text_llm_runner.cpp

Lines changed: 7 additions & 8 deletions
```diff
@@ -110,8 +110,7 @@ Error TextLLMRunner::generate(
   stats_->inference_start_ms = time_in_ms();
   shouldStop_ = false;
 
-  // Capture remaining KV cache capacity before prefill (pos_ will change)
-  int64_t max_context_len = metadata_.at(kMaxContextLen) - pos_;
+  int64_t max_context_len = metadata_.at(kMaxContextLen);
 
   uint64_t cur_token = 0;
   int num_prompt_tokens = 0;
@@ -139,10 +138,11 @@ Error TextLLMRunner::generate(
       InvalidArgument,
       "Expected at least 1 prompt token");
   ET_CHECK_OR_RETURN_ERROR(
-      num_prompt_tokens < max_context_len,
+      pos_ + num_prompt_tokens < max_context_len,
       InvalidArgument,
-      "num_prompt_tokens %d >= max_context_len %" PRId64
+      "pos_ %" PRId64 " + num_prompt_tokens %d >= max_context_len %" PRId64
       ", Max seq length exceeded - please increase max seq len value in your export script",
+      pos_,
       num_prompt_tokens,
       max_context_len);
 
@@ -168,10 +168,9 @@
     prefill_next_token_.reset();
   }
 
-  // Determine max_new_tokens using the GenerationConfig's resolve method,
-  // then subtract pos_ for max_new_tokens.
-  int max_new_tokens =
-      config.resolve_max_new_tokens(max_context_len, num_prompt_tokens);
+  // Resolve max_new_tokens. pos_ now reflects all occupied positions
+  // (including prompt tokens just prefilled).
+  int max_new_tokens = config.resolve_max_new_tokens(max_context_len, pos_);
 
   ET_LOG(
       Info,
```
