diff --git a/extension/llm/runner/test/test_text_llm_runner.cpp b/extension/llm/runner/test/test_text_llm_runner.cpp
index 7e9468a5f01..0faf98985e0 100644
--- a/extension/llm/runner/test/test_text_llm_runner.cpp
+++ b/extension/llm/runner/test/test_text_llm_runner.cpp
@@ -549,4 +549,73 @@ TEST_F(RunnerTest, NonKvCacheGenerateCompletesSuccessfully) {
   EXPECT_EQ(step_count, max_new_tokens);
 }
 
+// Test that multi-turn generation with seq_len correctly accounts for pos_.
+// Regression test for a bug where max_context_len was pre-adjusted by pos_,
+// causing resolve_max_new_tokens to under-count occupied positions when
+// seq_len is set.
+TEST_F(RunnerTest, MultiTurnWithSeqLenRespectsPos) {
+  auto tokenizer = createMockTokenizer();
+  auto text_decoder_runner = createMockTextDecoderRunner();
+  auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get());
+
+  ON_CALL(*tokenizer, encode(_, _, _))
+      .WillByDefault([&](const std::string&, int8_t, int8_t) {
+        return ::tokenizers::Result<std::vector<uint64_t>>(
+            std::vector<uint64_t>{1, 2, 3});
+      });
+
+  ON_CALL(*text_prefiller, prefill(_, _))
+      .WillByDefault([&](std::vector<uint64_t>& tokens, int64_t& pos) {
+        pos += tokens.size();
+        return Result<uint64_t>(4);
+      });
+
+  ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true));
+
+  std::unique_ptr<::executorch::extension::llm::Stats> stats =
+      std::make_unique<::executorch::extension::llm::Stats>();
+  auto text_token_generator = createTextTokenGenerator(
+      tokenizer.get(), text_decoder_runner.get(), stats.get());
+
+  auto module = std::make_unique<MockModule>();
+  auto io_manager =
+      std::make_unique<::executorch::extension::llm::IOManager>(*module);
+  TextLLMRunner runner(
+      createDefaultMetadata(), // kMaxContextLen = 128
+      std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()),
+      std::move(module),
+      std::move(text_decoder_runner),
+      std::unique_ptr<::executorch::extension::llm::TextPrefiller>(
+          text_prefiller.release()),
+      std::move(io_manager),
+      std::move(text_token_generator),
+      std::move(stats));
+
+  runner.load();
+
+  // First turn: advance pos_ to 7 (3 prompt + 4 generated)
+  GenerationConfig config1;
+  config1.max_new_tokens = 5; // prefill generates 1, loop generates 4
+  config1.echo = false;
+  Error err1 = runner.generate("first turn", config1);
+  EXPECT_EQ(err1, Error::Ok);
+
+  // Second turn with seq_len=20: pos_ is now 7, prompt adds 3 more → pos_=10
+  // Correct max_new_tokens = min(20, 128) - 10 = 10
+  // Bug would give: min(20, 128-7) - 3 = 17
+  GenerationConfig config2;
+  config2.seq_len = 20;
+  config2.echo = false;
+
+  CallbackCounter counter;
+  Error err2 = runner.generate(
+      "second turn", config2, [&counter](const std::string& token) {
+        counter.callback(token);
+      });
+
+  EXPECT_EQ(err2, Error::Ok);
+  // With correct pos_ accounting: min(20, 128) - 10 = 10 new tokens
+  EXPECT_EQ(counter.getCount(), 10);
+}
+
 } // namespace
diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp
index 45ab6edd79c..e8b37ba8863 100644
--- a/extension/llm/runner/text_llm_runner.cpp
+++ b/extension/llm/runner/text_llm_runner.cpp
@@ -110,8 +110,7 @@ Error TextLLMRunner::generate(
   stats_->inference_start_ms = time_in_ms();
   shouldStop_ = false;
 
-  // Capture remaining KV cache capacity before prefill (pos_ will change)
-  int64_t max_context_len = metadata_.at(kMaxContextLen) - pos_;
+  int64_t max_context_len = metadata_.at(kMaxContextLen);
 
   uint64_t cur_token = 0;
   int num_prompt_tokens = 0;
@@ -139,10 +138,11 @@ Error TextLLMRunner::generate(
       InvalidArgument,
       "Expected at least 1 prompt token");
   ET_CHECK_OR_RETURN_ERROR(
-      num_prompt_tokens < max_context_len,
+      pos_ + num_prompt_tokens < max_context_len,
       InvalidArgument,
-      "num_prompt_tokens %d >= max_context_len %" PRId64
+      "pos_ %" PRId64 " + num_prompt_tokens %d >= max_context_len %" PRId64
      ", Max seq length exceeded - please increase max seq len value in your export script",
+      pos_,
       num_prompt_tokens,
       max_context_len);
 
@@ -168,10 +168,9 @@ Error TextLLMRunner::generate(
     prefill_next_token_.reset();
   }
 
-  // Determine max_new_tokens using the GenerationConfig's resolve method,
-  // then subtract pos_ for max_new_tokens.
-  int max_new_tokens =
-      config.resolve_max_new_tokens(max_context_len, num_prompt_tokens);
+  // Resolve max_new_tokens. pos_ now reflects all occupied positions
+  // (including prompt tokens just prefilled).
+  int max_new_tokens = config.resolve_max_new_tokens(max_context_len, pos_);
   ET_LOG(
       Info,