@@ -549,4 +549,73 @@ TEST_F(RunnerTest, NonKvCacheGenerateCompletesSuccessfully) {
  EXPECT_EQ(step_count, max_new_tokens);
}
551551
552+ // Test that multi-turn generation with seq_len correctly accounts for pos_.
553+ // Regression test for a bug where max_context_len was pre-adjusted by pos_,
554+ // causing resolve_max_new_tokens to under-count occupied positions when
555+ // seq_len is set.
556+ TEST_F (RunnerTest, MultiTurnWithSeqLenRespectsPos) {
557+ auto tokenizer = createMockTokenizer ();
558+ auto text_decoder_runner = createMockTextDecoderRunner ();
559+ auto text_prefiller = createMockTextPrefiller (text_decoder_runner.get ());
560+
561+ ON_CALL (*tokenizer, encode (_, _, _))
562+ .WillByDefault ([&](const std::string&, int8_t , int8_t ) {
563+ return ::tokenizers::Result<std::vector<uint64_t >>(
564+ std::vector<uint64_t >{1 , 2 , 3 });
565+ });
566+
567+ ON_CALL (*text_prefiller, prefill (_, _))
568+ .WillByDefault ([&](std::vector<uint64_t >& tokens, int64_t & pos) {
569+ pos += tokens.size ();
570+ return Result<uint64_t >(4 );
571+ });
572+
573+ ON_CALL (*text_prefiller, is_loaded ()).WillByDefault (Return (true ));
574+
575+ std::unique_ptr<executorch::llm::Stats> stats =
576+ std::make_unique<executorch::llm::Stats>();
577+ auto text_token_generator = createTextTokenGenerator (
578+ tokenizer.get (), text_decoder_runner.get (), stats.get ());
579+
580+ auto module = std::make_unique<MockModule>();
581+ auto io_manager =
582+ std::make_unique<executorch::extension::llm::IOManager>(*module );
583+ TextLLMRunner runner (
584+ createDefaultMetadata (), // kMaxContextLen = 128
585+ std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release ()),
586+ std::move (module ),
587+ std::move (text_decoder_runner),
588+ std::unique_ptr<::executorch::extension::llm::TextPrefiller>(
589+ text_prefiller.release ()),
590+ std::move (io_manager),
591+ std::move (text_token_generator),
592+ std::move (stats));
593+
594+ runner.load ();
595+
596+ // First turn: advance pos_ to 7 (3 prompt + 4 generated)
597+ GenerationConfig config1;
598+ config1.max_new_tokens = 5 ; // prefill generates 1, loop generates 4
599+ config1.echo = false ;
600+ Error err1 = runner.generate (" first turn" , config1);
601+ EXPECT_EQ (err1, Error::Ok);
602+
603+ // Second turn with seq_len=20: pos_ is now 7, prompt adds 3 more → pos_=10
604+ // Correct max_new_tokens = min(20, 128) - 10 = 10
605+ // Bug would give: min(20, 128-7) - 3 = 17
606+ GenerationConfig config2;
607+ config2.seq_len = 20 ;
608+ config2.echo = false ;
609+
610+ CallbackCounter counter;
611+ Error err2 = runner.generate (
612+ " second turn" , config2, [&counter](const std::string& token) {
613+ counter.callback (token);
614+ });
615+
616+ EXPECT_EQ (err2, Error::Ok);
617+ // With correct pos_ accounting: min(20, 128) - 10 = 10 new tokens
618+ EXPECT_EQ (counter.getCount (), 10 );
619+ }
620+
} // namespace
0 commit comments