@@ -357,4 +357,145 @@ TEST_F(RunnerTest, IsLoadedReturnsTrueWhenComponentsInitialized) {
  EXPECT_TRUE(runner.is_loaded());
}
359359
360+ // Test that prefill() returns the predicted next token
361+ TEST_F (RunnerTest, PrefillReturnsNextToken) {
362+ auto tokenizer = createMockTokenizer ();
363+ auto text_decoder_runner = createMockTextDecoderRunner ();
364+ auto text_prefiller = createMockTextPrefiller (text_decoder_runner.get ());
365+
366+ ON_CALL (*tokenizer, encode (_, _, _))
367+ .WillByDefault ([&](const std::string&, int8_t , int8_t ) {
368+ return ::tokenizers::Result<std::vector<uint64_t >>(
369+ std::vector<uint64_t >{1 , 2 , 3 });
370+ });
371+
372+ ON_CALL (*text_prefiller, prefill (_, _))
373+ .WillByDefault ([&](std::vector<uint64_t >& tokens, int64_t & pos) {
374+ pos += tokens.size ();
375+ return Result<uint64_t >(42 );
376+ });
377+
378+ ON_CALL (*text_prefiller, is_loaded ()).WillByDefault (Return (true ));
379+
380+ std::unique_ptr<executorch::llm::Stats> stats =
381+ std::make_unique<executorch::llm::Stats>();
382+ auto text_token_generator = createTextTokenGenerator (
383+ tokenizer.get (), text_decoder_runner.get (), stats.get ());
384+
385+ auto module = std::make_unique<MockModule>();
386+ auto io_manager =
387+ std::make_unique<executorch::extension::llm::IOManager>(*module );
388+ TextLLMRunner runner (
389+ createDefaultMetadata (),
390+ std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release ()),
391+ std::move (module ),
392+ std::move (text_decoder_runner),
393+ std::unique_ptr<::executorch::extension::llm::TextPrefiller>(
394+ text_prefiller.release ()),
395+ std::move (io_manager),
396+ std::move (text_token_generator),
397+ std::move (stats));
398+
399+ runner.load ();
400+
401+ auto result = runner.prefill (" system prompt" , 1 , 0 );
402+ EXPECT_TRUE (result.ok ());
403+ EXPECT_EQ (result.get (), 42 );
404+ }
405+
406+ // Test the prefill() → generate("") workflow
407+ TEST_F (RunnerTest, PrefillThenGenerateEmpty) {
408+ auto tokenizer = createMockTokenizer ();
409+ auto text_decoder_runner = createMockTextDecoderRunner ();
410+ auto text_prefiller = createMockTextPrefiller (text_decoder_runner.get ());
411+
412+ ON_CALL (*tokenizer, encode (_, _, _))
413+ .WillByDefault ([&](const std::string&, int8_t , int8_t ) {
414+ return ::tokenizers::Result<std::vector<uint64_t >>(
415+ std::vector<uint64_t >{1 , 2 , 3 });
416+ });
417+
418+ ON_CALL (*text_prefiller, prefill (_, _))
419+ .WillByDefault ([&](std::vector<uint64_t >& tokens, int64_t & pos) {
420+ pos += tokens.size ();
421+ return Result<uint64_t >(4 );
422+ });
423+
424+ ON_CALL (*text_prefiller, is_loaded ()).WillByDefault (Return (true ));
425+
426+ std::unique_ptr<executorch::llm::Stats> stats =
427+ std::make_unique<executorch::llm::Stats>();
428+ auto text_token_generator = createTextTokenGenerator (
429+ tokenizer.get (), text_decoder_runner.get (), stats.get ());
430+
431+ auto module = std::make_unique<MockModule>();
432+ auto io_manager =
433+ std::make_unique<executorch::extension::llm::IOManager>(*module );
434+ TextLLMRunner runner (
435+ createDefaultMetadata (),
436+ std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release ()),
437+ std::move (module ),
438+ std::move (text_decoder_runner),
439+ std::unique_ptr<::executorch::extension::llm::TextPrefiller>(
440+ text_prefiller.release ()),
441+ std::move (io_manager),
442+ std::move (text_token_generator),
443+ std::move (stats));
444+
445+ runner.load ();
446+
447+ // Prefill first
448+ auto prefill_result = runner.prefill (" system prompt" , 1 , 0 );
449+ EXPECT_TRUE (prefill_result.ok ());
450+
451+ // Generate with empty prompt — should consume prefill_next_token_
452+ GenerationConfig config;
453+ config.max_new_tokens = 5 ;
454+ config.echo = false ;
455+
456+ CallbackCounter counter;
457+ Error err = runner.generate (
458+ " " , config, [&counter](const std::string& token) {
459+ counter.callback (token);
460+ });
461+
462+ EXPECT_EQ (err, Error::Ok);
463+ // First token from prefill + remaining from decode loop
464+ EXPECT_EQ (counter.getCount (), config.max_new_tokens );
465+ }
466+
467+ // Test that generate("") without prior prefill() returns an error
468+ TEST_F (RunnerTest, GenerateEmptyWithoutPrefillFails) {
469+ auto tokenizer = createMockTokenizer ();
470+ auto text_decoder_runner = createMockTextDecoderRunner ();
471+ auto text_prefiller = createMockTextPrefiller (text_decoder_runner.get ());
472+
473+ ON_CALL (*text_prefiller, is_loaded ()).WillByDefault (Return (true ));
474+
475+ std::unique_ptr<executorch::llm::Stats> stats =
476+ std::make_unique<executorch::llm::Stats>();
477+ auto text_token_generator = createTextTokenGenerator (
478+ tokenizer.get (), text_decoder_runner.get (), stats.get ());
479+
480+ auto module = std::make_unique<MockModule>();
481+ auto io_manager =
482+ std::make_unique<executorch::extension::llm::IOManager>(*module );
483+ TextLLMRunner runner (
484+ createDefaultMetadata (),
485+ std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release ()),
486+ std::move (module ),
487+ std::move (text_decoder_runner),
488+ std::unique_ptr<::executorch::extension::llm::TextPrefiller>(
489+ text_prefiller.release ()),
490+ std::move (io_manager),
491+ std::move (text_token_generator),
492+ std::move (stats));
493+
494+ runner.load ();
495+
496+ GenerationConfig config;
497+ Error err = runner.generate (" " , config);
498+ EXPECT_EQ (err, Error::InvalidState);
499+ }
500+
} // namespace