69 changes: 69 additions & 0 deletions extension/llm/runner/test/test_text_llm_runner.cpp
@@ -549,4 +549,73 @@ TEST_F(RunnerTest, NonKvCacheGenerateCompletesSuccessfully) {
EXPECT_EQ(step_count, max_new_tokens);
}

// Test that multi-turn generation with seq_len correctly accounts for pos_.
// Regression test for a bug where max_context_len was pre-adjusted by pos_,
// causing resolve_max_new_tokens to under-count occupied positions when
// seq_len is set.
TEST_F(RunnerTest, MultiTurnWithSeqLenRespectsPos) {
auto tokenizer = createMockTokenizer();
auto text_decoder_runner = createMockTextDecoderRunner();
auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get());

ON_CALL(*tokenizer, encode(_, _, _))
.WillByDefault([&](const std::string&, int8_t, int8_t) {
return ::tokenizers::Result<std::vector<uint64_t>>(
std::vector<uint64_t>{1, 2, 3});
});

ON_CALL(*text_prefiller, prefill(_, _))
.WillByDefault([&](std::vector<uint64_t>& tokens, int64_t& pos) {
pos += tokens.size();
return Result<uint64_t>(4);
});

ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true));

std::unique_ptr<executorch::llm::Stats> stats =
std::make_unique<executorch::llm::Stats>();
auto text_token_generator = createTextTokenGenerator(
tokenizer.get(), text_decoder_runner.get(), stats.get());

auto module = std::make_unique<MockModule>();
auto io_manager =
std::make_unique<executorch::extension::llm::IOManager>(*module);
TextLLMRunner runner(
createDefaultMetadata(), // kMaxContextLen = 128
std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()),
std::move(module),
std::move(text_decoder_runner),
std::unique_ptr<::executorch::extension::llm::TextPrefiller>(
text_prefiller.release()),
std::move(io_manager),
std::move(text_token_generator),
std::move(stats));

runner.load();

// First turn: advance pos_ to 7 (3 prompt + 4 generated)
GenerationConfig config1;
config1.max_new_tokens = 5; // prefill generates 1, loop generates 4
config1.echo = false;
Error err1 = runner.generate("first turn", config1);
EXPECT_EQ(err1, Error::Ok);

// Second turn with seq_len=20: pos_ is now 7, prompt adds 3 more → pos_=10
// Correct max_new_tokens = min(20, 128) - 10 = 10
// Bug would give: min(20, 128-7) - 3 = 17
GenerationConfig config2;
config2.seq_len = 20;
config2.echo = false;

CallbackCounter counter;
Error err2 = runner.generate(
"second turn", config2, [&counter](const std::string& token) {
counter.callback(token);
});

EXPECT_EQ(err2, Error::Ok);
// With correct pos_ accounting: min(20, 128) - 10 = 10 new tokens
EXPECT_EQ(counter.getCount(), 10);
}

} // namespace
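To make the arithmetic in the test's comments concrete, here is a minimal standalone sketch of the budget rule the assertions encode. `resolve_budget_sketch` is a hypothetical helper, not the ExecuTorch API; it assumes `seq_len` is always set, whereas the real `GenerationConfig::resolve_max_new_tokens` also covers the unset cases.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical sketch: the generation budget is whatever fits under
// min(seq_len, max_context_len) once all occupied positions (prior turns
// plus the current prompt) are subtracted, clamped at zero.
int32_t resolve_budget_sketch(
    int64_t max_context_len, int64_t seq_len, int64_t occupied) {
  const int64_t cap = std::min(seq_len, max_context_len);
  return static_cast<int32_t>(std::max<int64_t>(cap - occupied, 0));
}

int main() {
  // Second turn of the test: pos_ = 7 from turn one, +3 prompt tokens = 10
  // occupied positions, so min(20, 128) - 10 = 10 new tokens.
  assert(resolve_budget_sketch(128, 20, 10) == 10);
  // The pre-fix behavior folded pos_ into the context length up front:
  // min(20, 128 - 7) - 3 = 17, over-budgeting by 7 tokens.
  assert(std::min<int64_t>(20, 128 - 7) - 3 == 17);
  return 0;
}
```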
15 changes: 7 additions & 8 deletions extension/llm/runner/text_llm_runner.cpp
@@ -110,8 +110,7 @@ Error TextLLMRunner::generate(
stats_->inference_start_ms = time_in_ms();
shouldStop_ = false;

-  // Capture remaining KV cache capacity before prefill (pos_ will change)
-  int64_t max_context_len = metadata_.at(kMaxContextLen) - pos_;
+  int64_t max_context_len = metadata_.at(kMaxContextLen);

uint64_t cur_token = 0;
int num_prompt_tokens = 0;
@@ -139,10 +138,11 @@ Error TextLLMRunner::generate(
InvalidArgument,
"Expected at least 1 prompt token");
   ET_CHECK_OR_RETURN_ERROR(
-      num_prompt_tokens < max_context_len,
+      pos_ + num_prompt_tokens < max_context_len,
       InvalidArgument,
-      "num_prompt_tokens %d >= max_context_len %" PRId64
+      "pos_ %" PRId64 " + num_prompt_tokens %d >= max_context_len %" PRId64
       ", Max seq length exceeded - please increase max seq len value in your export script",
+      pos_,
       num_prompt_tokens,
       max_context_len);

@@ -168,10 +168,9 @@
prefill_next_token_.reset();
}

-  // Determine max_new_tokens using the GenerationConfig's resolve method,
-  // then subtract pos_ for max_new_tokens.
-  int max_new_tokens =
-      config.resolve_max_new_tokens(max_context_len, num_prompt_tokens);
+  // Resolve max_new_tokens. pos_ now reflects all occupied positions
+  // (including prompt tokens just prefilled).
+  int max_new_tokens = config.resolve_max_new_tokens(max_context_len, pos_);

Copilot AI commented on lines +171 to 174 (Apr 6, 2026):

GenerationConfig::resolve_max_new_tokens() takes int32_t parameters documented as num_prompt_tokens, but this call passes pos_ (an int64_t occupied-position count). This relies on implicit narrowing conversions and on a broader interpretation of the parameter than the API/docstring (and the pybinding arg name) suggests. Consider updating resolve_max_new_tokens to accept an int64_t occupied token count (or adding a new helper with clearer naming) and adjusting the documentation/bindings to avoid truncation risk and confusion.
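A minimal sketch of the int64_t-accepting helper the comment suggests, assuming -1 marks "unset" for seq_len and max_new_tokens (that sentinel is not confirmed by this diff); GenerationConfigSketch and resolve_max_new_tokens_for_positions are hypothetical names, not existing API:

```cpp
#include <algorithm>
#include <cstdint>

struct GenerationConfigSketch {
  int32_t seq_len = -1;         // assumed -1 means "no per-request cap"
  int32_t max_new_tokens = -1;  // assumed -1 means "no explicit limit"

  // Takes the occupied-position count as int64_t (matching pos_) and clamps
  // explicitly, instead of narrowing implicitly at the call site.
  int32_t resolve_max_new_tokens_for_positions(
      int64_t max_context_len, int64_t occupied_positions) const {
    const int64_t cap = (seq_len == -1)
        ? max_context_len
        : std::min<int64_t>(seq_len, max_context_len);
    int64_t budget = cap - occupied_positions;
    if (max_new_tokens != -1) {
      budget = std::min<int64_t>(budget, max_new_tokens);
    }
    return static_cast<int32_t>(std::max<int64_t>(budget, 0));
  }
};
```

Clamping to zero keeps a full cache from producing a negative budget, and doing the narrowing once at the return keeps the call site free of implicit conversions.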
ET_LOG(
Info,