Skip to content

Commit 6b8cca8

Browse files
kirklandsign authored and facebook-github-bot committed
Fix double-subtraction of pos_ in TextLLMRunner::generate()
Summary: When seq_len is set and pos_ > 0 (multi-turn conversations), max_context_len was pre-adjusted by subtracting pos_, but resolve_max_new_tokens then only subtracted num_prompt_tokens instead of the full occupied position count. This caused min(seq_len, max_context_len) to use a too-large max_context_len, producing more tokens than allowed by seq_len. Fix: use raw metadata value for max_context_len and pass pos_ (which includes prompt tokens after prefill) to resolve_max_new_tokens, matching multimodal_runner's behavior. Differential Revision: D99742232
1 parent 19bbeac commit 6b8cca8

2 files changed

Lines changed: 76 additions & 7 deletions

File tree

extension/llm/runner/test/test_text_llm_runner.cpp

Lines changed: 69 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -497,4 +497,73 @@ TEST_F(RunnerTest, GenerateEmptyWithoutPrefillFails) {
497497
EXPECT_EQ(err, Error::InvalidState);
498498
}
499499

500+
// Test that multi-turn generation with seq_len correctly accounts for pos_.
501+
// Regression test for a bug where max_context_len was pre-adjusted by pos_,
502+
// causing resolve_max_new_tokens to under-count occupied positions when
503+
// seq_len is set.
504+
TEST_F(RunnerTest, MultiTurnWithSeqLenRespectsPos) {
505+
auto tokenizer = createMockTokenizer();
506+
auto text_decoder_runner = createMockTextDecoderRunner();
507+
auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get());
508+
509+
ON_CALL(*tokenizer, encode(_, _, _))
510+
.WillByDefault([&](const std::string&, int8_t, int8_t) {
511+
return ::tokenizers::Result<std::vector<uint64_t>>(
512+
std::vector<uint64_t>{1, 2, 3});
513+
});
514+
515+
ON_CALL(*text_prefiller, prefill(_, _))
516+
.WillByDefault([&](std::vector<uint64_t>& tokens, int64_t& pos) {
517+
pos += tokens.size();
518+
return Result<uint64_t>(4);
519+
});
520+
521+
ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true));
522+
523+
std::unique_ptr<executorch::llm::Stats> stats =
524+
std::make_unique<executorch::llm::Stats>();
525+
auto text_token_generator = createTextTokenGenerator(
526+
tokenizer.get(), text_decoder_runner.get(), stats.get());
527+
528+
auto module = std::make_unique<MockModule>();
529+
auto io_manager =
530+
std::make_unique<executorch::extension::llm::IOManager>(*module);
531+
TextLLMRunner runner(
532+
createDefaultMetadata(), // kMaxContextLen = 128
533+
std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()),
534+
std::move(module),
535+
std::move(text_decoder_runner),
536+
std::unique_ptr<::executorch::extension::llm::TextPrefiller>(
537+
text_prefiller.release()),
538+
std::move(io_manager),
539+
std::move(text_token_generator),
540+
std::move(stats));
541+
542+
runner.load();
543+
544+
// First turn: advance pos_ to 7 (3 prompt + 4 generated)
545+
GenerationConfig config1;
546+
config1.max_new_tokens = 5; // prefill generates 1, loop generates 4
547+
config1.echo = false;
548+
Error err1 = runner.generate("first turn", config1);
549+
EXPECT_EQ(err1, Error::Ok);
550+
551+
// Second turn with seq_len=20: pos_ is now 7, prompt adds 3 more → pos_=10
552+
// Correct max_new_tokens = min(20, 128) - 10 = 10
553+
// Bug would give: min(20, 128-7) - 3 = 17
554+
GenerationConfig config2;
555+
config2.seq_len = 20;
556+
config2.echo = false;
557+
558+
CallbackCounter counter;
559+
Error err2 = runner.generate(
560+
"second turn", config2, [&counter](const std::string& token) {
561+
counter.callback(token);
562+
});
563+
564+
EXPECT_EQ(err2, Error::Ok);
565+
// With correct pos_ accounting: min(20, 128) - 10 = 10 new tokens
566+
EXPECT_EQ(counter.getCount(), 10);
567+
}
568+
500569
} // namespace

extension/llm/runner/text_llm_runner.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,7 @@ Error TextLLMRunner::generate(
110110
stats_->inference_start_ms = time_in_ms();
111111
shouldStop_ = false;
112112

113-
// Capture remaining KV cache capacity before prefill (pos_ will change)
114-
int64_t max_context_len = metadata_.at(kMaxContextLen) - pos_;
113+
int64_t max_context_len = metadata_.at(kMaxContextLen);
115114

116115
uint64_t cur_token = 0;
117116
int num_prompt_tokens = 0;
@@ -139,10 +138,11 @@ Error TextLLMRunner::generate(
139138
InvalidArgument,
140139
"Expected at least 1 prompt token");
141140
ET_CHECK_OR_RETURN_ERROR(
142-
num_prompt_tokens < max_context_len,
141+
pos_ + num_prompt_tokens < max_context_len,
143142
InvalidArgument,
144-
"num_prompt_tokens %d >= max_context_len %" PRId64
143+
"pos_ %" PRId64 " + num_prompt_tokens %d >= max_context_len %" PRId64
145144
", Max seq length exceeded - please increase max seq len value in your export script",
145+
pos_,
146146
num_prompt_tokens,
147147
max_context_len);
148148

@@ -168,10 +168,10 @@ Error TextLLMRunner::generate(
168168
prefill_next_token_.reset();
169169
}
170170

171-
// Determine max_new_tokens using the GenerationConfig's resolve method,
172-
// then subtract pos_ for max_new_tokens.
171+
// Resolve max_new_tokens. pos_ now reflects all occupied positions
172+
// (including prompt tokens just prefilled).
173173
int max_new_tokens =
174-
config.resolve_max_new_tokens(max_context_len, num_prompt_tokens);
174+
config.resolve_max_new_tokens(max_context_len, pos_);
175175

176176
ET_LOG(
177177
Info,

0 commit comments

Comments (0)