Skip to content

Commit 7e944e9

Browse files
committed
Add unit tests for prefill() and prefill-then-generate workflow
Adds three new tests: (1) PrefillReturnsNextToken verifies that prefill() returns the predicted next token from the mock text_prefiller; (2) PrefillThenGenerateEmpty verifies that the prefill() → generate("") workflow produces the expected number of tokens; (3) GenerateEmptyWithoutPrefillFails verifies that generate("") without a prior prefill() returns an InvalidState error. This PR was authored with the assistance of Claude.
1 parent e83bd7e commit 7e944e9

1 file changed

Lines changed: 141 additions & 0 deletions

File tree

extension/llm/runner/test/test_text_llm_runner.cpp

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,4 +357,145 @@ TEST_F(RunnerTest, IsLoadedReturnsTrueWhenComponentsInitialized) {
357357
EXPECT_TRUE(runner.is_loaded());
358358
}
359359

360+
// Test that prefill() returns the predicted next token
361+
TEST_F(RunnerTest, PrefillReturnsNextToken) {
362+
auto tokenizer = createMockTokenizer();
363+
auto text_decoder_runner = createMockTextDecoderRunner();
364+
auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get());
365+
366+
ON_CALL(*tokenizer, encode(_, _, _))
367+
.WillByDefault([&](const std::string&, int8_t, int8_t) {
368+
return ::tokenizers::Result<std::vector<uint64_t>>(
369+
std::vector<uint64_t>{1, 2, 3});
370+
});
371+
372+
ON_CALL(*text_prefiller, prefill(_, _))
373+
.WillByDefault([&](std::vector<uint64_t>& tokens, int64_t& pos) {
374+
pos += tokens.size();
375+
return Result<uint64_t>(42);
376+
});
377+
378+
ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true));
379+
380+
std::unique_ptr<executorch::llm::Stats> stats =
381+
std::make_unique<executorch::llm::Stats>();
382+
auto text_token_generator = createTextTokenGenerator(
383+
tokenizer.get(), text_decoder_runner.get(), stats.get());
384+
385+
auto module = std::make_unique<MockModule>();
386+
auto io_manager =
387+
std::make_unique<executorch::extension::llm::IOManager>(*module);
388+
TextLLMRunner runner(
389+
createDefaultMetadata(),
390+
std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()),
391+
std::move(module),
392+
std::move(text_decoder_runner),
393+
std::unique_ptr<::executorch::extension::llm::TextPrefiller>(
394+
text_prefiller.release()),
395+
std::move(io_manager),
396+
std::move(text_token_generator),
397+
std::move(stats));
398+
399+
runner.load();
400+
401+
auto result = runner.prefill("system prompt", 1, 0);
402+
EXPECT_TRUE(result.ok());
403+
EXPECT_EQ(result.get(), 42);
404+
}
405+
406+
// Test the prefill() → generate("") workflow
407+
TEST_F(RunnerTest, PrefillThenGenerateEmpty) {
408+
auto tokenizer = createMockTokenizer();
409+
auto text_decoder_runner = createMockTextDecoderRunner();
410+
auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get());
411+
412+
ON_CALL(*tokenizer, encode(_, _, _))
413+
.WillByDefault([&](const std::string&, int8_t, int8_t) {
414+
return ::tokenizers::Result<std::vector<uint64_t>>(
415+
std::vector<uint64_t>{1, 2, 3});
416+
});
417+
418+
ON_CALL(*text_prefiller, prefill(_, _))
419+
.WillByDefault([&](std::vector<uint64_t>& tokens, int64_t& pos) {
420+
pos += tokens.size();
421+
return Result<uint64_t>(4);
422+
});
423+
424+
ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true));
425+
426+
std::unique_ptr<executorch::llm::Stats> stats =
427+
std::make_unique<executorch::llm::Stats>();
428+
auto text_token_generator = createTextTokenGenerator(
429+
tokenizer.get(), text_decoder_runner.get(), stats.get());
430+
431+
auto module = std::make_unique<MockModule>();
432+
auto io_manager =
433+
std::make_unique<executorch::extension::llm::IOManager>(*module);
434+
TextLLMRunner runner(
435+
createDefaultMetadata(),
436+
std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()),
437+
std::move(module),
438+
std::move(text_decoder_runner),
439+
std::unique_ptr<::executorch::extension::llm::TextPrefiller>(
440+
text_prefiller.release()),
441+
std::move(io_manager),
442+
std::move(text_token_generator),
443+
std::move(stats));
444+
445+
runner.load();
446+
447+
// Prefill first
448+
auto prefill_result = runner.prefill("system prompt", 1, 0);
449+
EXPECT_TRUE(prefill_result.ok());
450+
451+
// Generate with empty prompt — should consume prefill_next_token_
452+
GenerationConfig config;
453+
config.max_new_tokens = 5;
454+
config.echo = false;
455+
456+
CallbackCounter counter;
457+
Error err = runner.generate(
458+
"", config, [&counter](const std::string& token) {
459+
counter.callback(token);
460+
});
461+
462+
EXPECT_EQ(err, Error::Ok);
463+
// First token from prefill + remaining from decode loop
464+
EXPECT_EQ(counter.getCount(), config.max_new_tokens);
465+
}
466+
467+
// Test that generate("") without prior prefill() returns an error
468+
TEST_F(RunnerTest, GenerateEmptyWithoutPrefillFails) {
469+
auto tokenizer = createMockTokenizer();
470+
auto text_decoder_runner = createMockTextDecoderRunner();
471+
auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get());
472+
473+
ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true));
474+
475+
std::unique_ptr<executorch::llm::Stats> stats =
476+
std::make_unique<executorch::llm::Stats>();
477+
auto text_token_generator = createTextTokenGenerator(
478+
tokenizer.get(), text_decoder_runner.get(), stats.get());
479+
480+
auto module = std::make_unique<MockModule>();
481+
auto io_manager =
482+
std::make_unique<executorch::extension::llm::IOManager>(*module);
483+
TextLLMRunner runner(
484+
createDefaultMetadata(),
485+
std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()),
486+
std::move(module),
487+
std::move(text_decoder_runner),
488+
std::unique_ptr<::executorch::extension::llm::TextPrefiller>(
489+
text_prefiller.release()),
490+
std::move(io_manager),
491+
std::move(text_token_generator),
492+
std::move(stats));
493+
494+
runner.load();
495+
496+
GenerationConfig config;
497+
Error err = runner.generate("", config);
498+
EXPECT_EQ(err, Error::InvalidState);
499+
}
500+
360501
} // namespace

0 commit comments

Comments
 (0)