Skip to content

Commit 174d3ad

Browse files
authored
Add LogitProcessor interface for pre-sampling logit transforms (pytorch#19517)
Differential Revision: D104767967 Pull Request resolved: pytorch#19517
1 parent d8e4ffd commit 174d3ad

7 files changed

Lines changed: 421 additions & 0 deletions

File tree

extension/llm/runner/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def define_common_targets():
6868
visibility = ["PUBLIC"],
6969
exported_deps = [
7070
":text_decoder_runner" + aten_suffix,
71+
"//executorch/extension/llm/sampler:sampler" + aten_suffix,
7172
"//pytorch/tokenizers:headers",
7273
"//executorch/extension/module:module" + aten_suffix,
7374
"//executorch/extension/tensor:tensor" + aten_suffix,

extension/llm/runner/test/test_text_llm_runner.cpp

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,16 @@
1212
#include <executorch/extension/llm/runner/text_llm_runner.h>
1313
#include <executorch/extension/llm/runner/text_prefiller.h>
1414
#include <executorch/extension/llm/runner/text_token_generator.h>
15+
#include <executorch/extension/llm/sampler/logit_processor.h>
1516
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
1617
#include <gmock/gmock.h>
1718
#include <gtest/gtest.h>
1819

20+
#include <limits>
21+
1922
using namespace ::testing;
2023
using executorch::extension::llm::GenerationConfig;
24+
using executorch::extension::llm::LogitProcessor;
2125
using executorch::extension::llm::Stats;
2226
using executorch::extension::llm::TextDecoderRunner;
2327
using executorch::extension::llm::TextLLMRunner;
@@ -97,6 +101,29 @@ class MockTextPrefiller : public TextPrefiller {
97101
MOCK_METHOD(bool, is_loaded, (), ());
98102
};
99103

104+
class MaskTokenProcessor : public LogitProcessor {
105+
public:
106+
explicit MaskTokenProcessor(int32_t banned_token)
107+
: banned_token_(banned_token) {}
108+
109+
::executorch::runtime::Error process(
110+
::executorch::aten::Tensor logits) override {
111+
const int32_t vocab_size = logits.size(logits.dim() - 1);
112+
int32_t offset = 0;
113+
if (logits.dim() == 3) {
114+
offset = (logits.size(1) - 1) * vocab_size;
115+
}
116+
float* data = logits.mutable_data_ptr<float>();
117+
if (banned_token_ >= 0 && banned_token_ < vocab_size) {
118+
data[offset + banned_token_] = -std::numeric_limits<float>::infinity();
119+
}
120+
return ::executorch::runtime::Error::Ok;
121+
}
122+
123+
private:
124+
int32_t banned_token_;
125+
};
126+
100127
// Callback counter class for tests
101128
class CallbackCounter {
102129
public:
@@ -618,4 +645,95 @@ TEST_F(RunnerTest, MultiTurnWithSeqLenRespectsPos) {
618645
EXPECT_EQ(counter.getCount(), 10);
619646
}
620647

648+
// Verify that a LogitProcessor injected into TextTokenGenerator actually
649+
// affects token selection. Without the processor, greedy argmax of
650+
// {0.1, 0.2, 0.3, 0.4} picks token 3. Masking token 3 should pick token 2.
651+
TEST_F(RunnerTest, TextTokenGeneratorWithProcessorMasksToken) {
652+
auto tokenizer = createMockTokenizer();
653+
auto text_decoder_runner = createMockTextDecoderRunner();
654+
Stats stats;
655+
auto generator = createTextTokenGenerator(
656+
tokenizer.get(), text_decoder_runner.get(), &stats);
657+
658+
generator->add_logit_processor(
659+
std::make_shared<MaskTokenProcessor>(/*banned_token=*/3));
660+
661+
std::vector<uint64_t> generated_tokens;
662+
ON_CALL(*tokenizer, decode)
663+
.WillByDefault(
664+
[&](uint64_t,
665+
uint64_t cur,
666+
bool) -> ::tokenizers::Result<std::string> {
667+
generated_tokens.push_back(cur);
668+
return ::tokenizers::Result<std::string>(std::string("token"));
669+
});
670+
671+
std::vector<uint64_t> tokens = {1, 2, 3};
672+
auto result =
673+
generator->generate(tokens, 3, 3, 0.0f, [](const std::string&) {});
674+
675+
EXPECT_TRUE(result.ok());
676+
const std::vector<uint64_t> expected(3, 2);
677+
EXPECT_EQ(generated_tokens, expected);
678+
}
679+
680+
// Multiple processors in chain should all take effect.
681+
TEST_F(RunnerTest, TextTokenGeneratorProcessorChainMasksMultipleTokens) {
682+
auto tokenizer = createMockTokenizer();
683+
auto text_decoder_runner = createMockTextDecoderRunner();
684+
Stats stats;
685+
auto generator = createTextTokenGenerator(
686+
tokenizer.get(), text_decoder_runner.get(), &stats);
687+
688+
generator->add_logit_processor(
689+
std::make_shared<MaskTokenProcessor>(/*banned_token=*/3));
690+
generator->add_logit_processor(
691+
std::make_shared<MaskTokenProcessor>(/*banned_token=*/2));
692+
693+
std::vector<uint64_t> generated_tokens;
694+
ON_CALL(*tokenizer, decode)
695+
.WillByDefault(
696+
[&](uint64_t,
697+
uint64_t cur,
698+
bool) -> ::tokenizers::Result<std::string> {
699+
generated_tokens.push_back(cur);
700+
return ::tokenizers::Result<std::string>(std::string("token"));
701+
});
702+
703+
std::vector<uint64_t> tokens = {1, 2, 3};
704+
auto result =
705+
generator->generate(tokens, 3, 3, 0.0f, [](const std::string&) {});
706+
707+
EXPECT_TRUE(result.ok());
708+
const std::vector<uint64_t> expected(3, 1);
709+
EXPECT_EQ(generated_tokens, expected);
710+
}
711+
712+
// Without any processors, greedy argmax picks token 3 (zero-overhead path).
713+
TEST_F(RunnerTest, TextTokenGeneratorWithoutProcessorPicksArgmax) {
714+
auto tokenizer = createMockTokenizer();
715+
auto text_decoder_runner = createMockTextDecoderRunner();
716+
Stats stats;
717+
auto generator = createTextTokenGenerator(
718+
tokenizer.get(), text_decoder_runner.get(), &stats);
719+
720+
std::vector<uint64_t> generated_tokens;
721+
ON_CALL(*tokenizer, decode)
722+
.WillByDefault(
723+
[&](uint64_t,
724+
uint64_t cur,
725+
bool) -> ::tokenizers::Result<std::string> {
726+
generated_tokens.push_back(cur);
727+
return ::tokenizers::Result<std::string>(std::string("token"));
728+
});
729+
730+
std::vector<uint64_t> tokens = {1, 2, 3};
731+
auto result =
732+
generator->generate(tokens, 3, 3, 0.0f, [](const std::string&) {});
733+
734+
EXPECT_TRUE(result.ok());
735+
const std::vector<uint64_t> expected(3, 3);
736+
EXPECT_EQ(generated_tokens, expected);
737+
}
738+
621739
} // namespace

extension/llm/runner/text_token_generator.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,12 @@
1010
#pragma once
1111

1212
#include <atomic>
13+
#include <memory>
14+
#include <vector>
1315

1416
#include <executorch/extension/llm/runner/stats.h>
1517
#include <executorch/extension/llm/runner/text_decoder_runner.h>
18+
#include <executorch/extension/llm/sampler/logit_processor.h>
1619
#include <executorch/extension/tensor/tensor.h>
1720
#include <pytorch/tokenizers/tokenizer.h>
1821

@@ -38,6 +41,20 @@ class ET_EXPERIMENTAL TextTokenGenerator {
3841
ignore_eos_ = ignore_eos;
3942
}
4043

44+
void add_logit_processor(std::shared_ptr<LogitProcessor> processor) {
45+
if (processor) {
46+
logit_processors_.push_back(std::move(processor));
47+
}
48+
}
49+
50+
/// Drops every registered processor, restoring the unprocessed
/// (zero-overhead) sampling path.
void clear_logit_processors() {
  logit_processors_.clear();
}
53+
54+
size_t num_logit_processors() const {
55+
return logit_processors_.size();
56+
}
57+
4158
virtual ~TextTokenGenerator() = default;
4259

4360
/**
@@ -109,6 +126,10 @@ class ET_EXPERIMENTAL TextTokenGenerator {
109126

110127
prev_token = cur_token;
111128

129+
for (auto& processor : logit_processors_) {
130+
ET_CHECK_OK_OR_RETURN_ERROR(processor->process(logits_tensor));
131+
}
132+
112133
stats_->on_sampling_begin();
113134
cur_token =
114135
text_decoder_runner_->logits_to_token(logits_tensor, temperature);
@@ -189,6 +210,8 @@ class ET_EXPERIMENTAL TextTokenGenerator {
189210
bool use_kv_cache_;
190211
bool ignore_eos_ = false;
191212

213+
std::vector<std::shared_ptr<LogitProcessor>> logit_processors_;
214+
192215
// state machine
193216
std::atomic<bool> should_stop_{false};
194217

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/runtime/core/error.h>
12+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
13+
#include <executorch/runtime/platform/compiler.h>
14+
15+
namespace executorch {
16+
namespace extension {
17+
namespace llm {
18+
19+
/**
20+
* In-place logit transform applied between the model forward pass and the
21+
* sampler. Examples: grammar masks, logit bias, repetition penalty.
22+
*
23+
* `TextTokenGenerator` runs registered processors in order; each sees
24+
* prior processors' edits. Called once per decoded token — keep it cheap.
25+
*
26+
* Tensor contract:
27+
* rank 2 [batch, vocab] — operate on the full last dim
28+
* rank 3 [batch, seq, vocab] — operate on the LAST sequence position
29+
* other ranks — undefined behavior
30+
*
31+
* Implementations dispatch their own dtype (the chain runner neither casts
32+
* nor copies the tensor). Return non-Ok to abort the chain.
33+
*/
34+
class ET_EXPERIMENTAL LogitProcessor {
35+
public:
36+
virtual ~LogitProcessor() = default;
37+
38+
virtual ::executorch::runtime::Error process(
39+
::executorch::aten::Tensor logits) = 0;
40+
};
41+
42+
} // namespace llm
43+
} // namespace extension
44+
} // namespace executorch

extension/llm/sampler/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ def define_common_targets():
77
runtime.cxx_library(
88
name = "sampler" + aten_suffix,
99
exported_headers = [
10+
"logit_processor.h",
1011
"sampler.h",
1112
"util.h",
1213
],

extension/llm/sampler/test/targets.bzl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,14 @@ def define_common_targets():
2222
"//caffe2:torch-cpp",
2323
],
2424
)
25+
26+
# Unit test target for the LogitProcessor interface.
runtime.cxx_test(
    name = "test_logit_processor",
    srcs = ["test_logit_processor.cpp"],
    deps = [
        "//executorch/extension/llm/sampler:sampler",
        "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
    ],
)

0 commit comments

Comments
 (0)