|
12 | 12 | #include <executorch/extension/llm/runner/text_llm_runner.h> |
13 | 13 | #include <executorch/extension/llm/runner/text_prefiller.h> |
14 | 14 | #include <executorch/extension/llm/runner/text_token_generator.h> |
| 15 | +#include <executorch/extension/llm/sampler/logit_processor.h> |
15 | 16 | #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h> |
16 | 17 | #include <gmock/gmock.h> |
17 | 18 | #include <gtest/gtest.h> |
18 | 19 |
|
| 20 | +#include <limits> |
| 21 | + |
19 | 22 | using namespace ::testing; |
20 | 23 | using executorch::extension::llm::GenerationConfig; |
| 24 | +using executorch::extension::llm::LogitProcessor; |
21 | 25 | using executorch::extension::llm::Stats; |
22 | 26 | using executorch::extension::llm::TextDecoderRunner; |
23 | 27 | using executorch::extension::llm::TextLLMRunner; |
@@ -97,6 +101,29 @@ class MockTextPrefiller : public TextPrefiller { |
97 | 101 | MOCK_METHOD(bool, is_loaded, (), ()); |
98 | 102 | }; |
99 | 103 |
|
| 104 | +class MaskTokenProcessor : public LogitProcessor { |
| 105 | + public: |
| 106 | + explicit MaskTokenProcessor(int32_t banned_token) |
| 107 | + : banned_token_(banned_token) {} |
| 108 | + |
| 109 | + ::executorch::runtime::Error process( |
| 110 | + ::executorch::aten::Tensor logits) override { |
| 111 | + const int32_t vocab_size = logits.size(logits.dim() - 1); |
| 112 | + int32_t offset = 0; |
| 113 | + if (logits.dim() == 3) { |
| 114 | + offset = (logits.size(1) - 1) * vocab_size; |
| 115 | + } |
| 116 | + float* data = logits.mutable_data_ptr<float>(); |
| 117 | + if (banned_token_ >= 0 && banned_token_ < vocab_size) { |
| 118 | + data[offset + banned_token_] = -std::numeric_limits<float>::infinity(); |
| 119 | + } |
| 120 | + return ::executorch::runtime::Error::Ok; |
| 121 | + } |
| 122 | + |
| 123 | + private: |
| 124 | + int32_t banned_token_; |
| 125 | +}; |
| 126 | + |
100 | 127 | // Callback counter class for tests |
101 | 128 | class CallbackCounter { |
102 | 129 | public: |
@@ -618,4 +645,95 @@ TEST_F(RunnerTest, MultiTurnWithSeqLenRespectsPos) { |
618 | 645 | EXPECT_EQ(counter.getCount(), 10); |
619 | 646 | } |
620 | 647 |
|
| 648 | +// Verify that a LogitProcessor injected into TextTokenGenerator actually |
| 649 | +// affects token selection. Without the processor, greedy argmax of |
| 650 | +// {0.1, 0.2, 0.3, 0.4} picks token 3. Masking token 3 should pick token 2. |
| 651 | +TEST_F(RunnerTest, TextTokenGeneratorWithProcessorMasksToken) { |
| 652 | + auto tokenizer = createMockTokenizer(); |
| 653 | + auto text_decoder_runner = createMockTextDecoderRunner(); |
| 654 | + Stats stats; |
| 655 | + auto generator = createTextTokenGenerator( |
| 656 | + tokenizer.get(), text_decoder_runner.get(), &stats); |
| 657 | + |
| 658 | + generator->add_logit_processor( |
| 659 | + std::make_shared<MaskTokenProcessor>(/*banned_token=*/3)); |
| 660 | + |
| 661 | + std::vector<uint64_t> generated_tokens; |
| 662 | + ON_CALL(*tokenizer, decode) |
| 663 | + .WillByDefault( |
| 664 | + [&](uint64_t, |
| 665 | + uint64_t cur, |
| 666 | + bool) -> ::tokenizers::Result<std::string> { |
| 667 | + generated_tokens.push_back(cur); |
| 668 | + return ::tokenizers::Result<std::string>(std::string("token")); |
| 669 | + }); |
| 670 | + |
| 671 | + std::vector<uint64_t> tokens = {1, 2, 3}; |
| 672 | + auto result = |
| 673 | + generator->generate(tokens, 3, 3, 0.0f, [](const std::string&) {}); |
| 674 | + |
| 675 | + EXPECT_TRUE(result.ok()); |
| 676 | + const std::vector<uint64_t> expected(3, 2); |
| 677 | + EXPECT_EQ(generated_tokens, expected); |
| 678 | +} |
| 679 | + |
| 680 | +// Multiple processors in chain should all take effect. |
| 681 | +TEST_F(RunnerTest, TextTokenGeneratorProcessorChainMasksMultipleTokens) { |
| 682 | + auto tokenizer = createMockTokenizer(); |
| 683 | + auto text_decoder_runner = createMockTextDecoderRunner(); |
| 684 | + Stats stats; |
| 685 | + auto generator = createTextTokenGenerator( |
| 686 | + tokenizer.get(), text_decoder_runner.get(), &stats); |
| 687 | + |
| 688 | + generator->add_logit_processor( |
| 689 | + std::make_shared<MaskTokenProcessor>(/*banned_token=*/3)); |
| 690 | + generator->add_logit_processor( |
| 691 | + std::make_shared<MaskTokenProcessor>(/*banned_token=*/2)); |
| 692 | + |
| 693 | + std::vector<uint64_t> generated_tokens; |
| 694 | + ON_CALL(*tokenizer, decode) |
| 695 | + .WillByDefault( |
| 696 | + [&](uint64_t, |
| 697 | + uint64_t cur, |
| 698 | + bool) -> ::tokenizers::Result<std::string> { |
| 699 | + generated_tokens.push_back(cur); |
| 700 | + return ::tokenizers::Result<std::string>(std::string("token")); |
| 701 | + }); |
| 702 | + |
| 703 | + std::vector<uint64_t> tokens = {1, 2, 3}; |
| 704 | + auto result = |
| 705 | + generator->generate(tokens, 3, 3, 0.0f, [](const std::string&) {}); |
| 706 | + |
| 707 | + EXPECT_TRUE(result.ok()); |
| 708 | + const std::vector<uint64_t> expected(3, 1); |
| 709 | + EXPECT_EQ(generated_tokens, expected); |
| 710 | +} |
| 711 | + |
| 712 | +// Without any processors, greedy argmax picks token 3 (zero-overhead path). |
| 713 | +TEST_F(RunnerTest, TextTokenGeneratorWithoutProcessorPicksArgmax) { |
| 714 | + auto tokenizer = createMockTokenizer(); |
| 715 | + auto text_decoder_runner = createMockTextDecoderRunner(); |
| 716 | + Stats stats; |
| 717 | + auto generator = createTextTokenGenerator( |
| 718 | + tokenizer.get(), text_decoder_runner.get(), &stats); |
| 719 | + |
| 720 | + std::vector<uint64_t> generated_tokens; |
| 721 | + ON_CALL(*tokenizer, decode) |
| 722 | + .WillByDefault( |
| 723 | + [&](uint64_t, |
| 724 | + uint64_t cur, |
| 725 | + bool) -> ::tokenizers::Result<std::string> { |
| 726 | + generated_tokens.push_back(cur); |
| 727 | + return ::tokenizers::Result<std::string>(std::string("token")); |
| 728 | + }); |
| 729 | + |
| 730 | + std::vector<uint64_t> tokens = {1, 2, 3}; |
| 731 | + auto result = |
| 732 | + generator->generate(tokens, 3, 3, 0.0f, [](const std::string&) {}); |
| 733 | + |
| 734 | + EXPECT_TRUE(result.ok()); |
| 735 | + const std::vector<uint64_t> expected(3, 3); |
| 736 | + EXPECT_EQ(generated_tokens, expected); |
| 737 | +} |
| 738 | + |
621 | 739 | } // namespace |
0 commit comments