Skip to content

Commit dd91e4f

Browse files
committed
Change prefill() to return Result<uint64_t>
prefill() now returns the sampled next token directly, making the API more natural. Callers get the token if they want it, and internally it's also stored in prefill_next_token_ for the generate("") workflow. This eliminates the need for the private prefill_and_sample() helper since prefill() itself returns the token — generate(vector) can call prefill() directly. This PR was authored with the assistance of Claude.
1 parent f03b171 commit dd91e4f

6 files changed

Lines changed: 26 additions & 38 deletions

File tree

extension/llm/runner/irunner.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include <executorch/extension/llm/runner/stats.h>
2020
#include <executorch/runtime/core/error.h>
21+
#include <executorch/runtime/core/result.h>
2122

2223
namespace executorch {
2324
namespace extension {
@@ -140,7 +141,7 @@ class ET_EXPERIMENTAL IRunner {
140141
* @param num_eos Number of EOS tokens to append during encoding
141142
* @return The next token predicted after prefill, or an error.
142143
*/
143-
virtual runtime::Error prefill(
144+
virtual runtime::Result<uint64_t> prefill(
144145
const std::vector<MultimodalInput>& inputs,
145146
int32_t num_bos = 0,
146147
int32_t num_eos = 0) {

extension/llm/runner/multimodal_runner.cpp

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,13 @@ Error MultimodalRunner::load() {
8585
ET_LOG(Info, format, __VA_ARGS__); \
8686
}
8787

88-
Result<uint64_t> MultimodalRunner::prefill_and_sample(
88+
Result<uint64_t> MultimodalRunner::prefill(
8989
const std::vector<MultimodalInput>& inputs,
9090
int32_t num_bos,
9191
int32_t num_eos) {
92+
if (!is_loaded()) {
93+
ET_CHECK_OK_OR_RETURN_ERROR(load());
94+
}
9295
uint64_t last_token = 0;
9396
for (size_t i = 0; i < inputs.size(); ++i) {
9497
const auto& input = inputs[i];
@@ -119,24 +122,10 @@ Result<uint64_t> MultimodalRunner::prefill_and_sample(
119122
}
120123
last_token = prefill_result.get();
121124
}
125+
prefill_next_token_ = last_token;
122126
return last_token;
123127
}
124128

125-
Error MultimodalRunner::prefill(
126-
const std::vector<MultimodalInput>& inputs,
127-
int32_t num_bos,
128-
int32_t num_eos) {
129-
if (!is_loaded()) {
130-
ET_CHECK_OK_OR_RETURN_ERROR(load());
131-
}
132-
auto result = prefill_and_sample(inputs, num_bos, num_eos);
133-
if (!result.ok()) {
134-
return result.error();
135-
}
136-
prefill_next_token_ = result.get();
137-
return Error::Ok;
138-
}
139-
140129
Error MultimodalRunner::generate(
141130
const std::string& prompt,
142131
const GenerationConfig& config,
@@ -282,8 +271,7 @@ Error MultimodalRunner::generate(
282271
}
283272

284273
// Prefill all inputs and get the first decode token
285-
auto prefill_result =
286-
prefill_and_sample(inputs, config.num_bos, config.num_eos);
274+
auto prefill_result = prefill(inputs, config.num_bos, config.num_eos);
287275
ET_CHECK_OK_OR_RETURN_ERROR(prefill_result.error());
288276
uint64_t cur_token = prefill_result.get();
289277

extension/llm/runner/multimodal_runner.h

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -147,17 +147,18 @@ class ET_EXPERIMENTAL MultimodalRunner : public IRunner {
147147
* text.
148148
* @param num_bos Number of BOS tokens to prepend during encoding.
149149
* @param num_eos Number of EOS tokens to append during encoding.
150-
* @return The error code. KV cache position is tracked internally in pos_.
150+
* @return The next token predicted after prefill, or an error.
151+
* KV cache position is tracked internally in pos_.
151152
*/
152-
::executorch::runtime::Error prefill(
153+
::executorch::runtime::Result<uint64_t> prefill(
153154
const std::vector<MultimodalInput>& inputs,
154155
int32_t num_bos = 0,
155156
int32_t num_eos = 0) override;
156157

157158
/**
158159
* Convenience overload: prefill a single text prompt.
159160
*/
160-
::executorch::runtime::Error
161+
::executorch::runtime::Result<uint64_t>
161162
prefill(const std::string& prompt, int32_t num_bos = 0, int32_t num_eos = 0) {
162163
std::vector<MultimodalInput> inputs;
163164
inputs.emplace_back(MultimodalInput(prompt));
@@ -195,12 +196,6 @@ class ET_EXPERIMENTAL MultimodalRunner : public IRunner {
195196
// Internal state
196197
std::optional<uint64_t> prefill_next_token_;
197198
int64_t pos_;
198-
199-
private:
200-
::executorch::runtime::Result<uint64_t> prefill_and_sample(
201-
const std::vector<MultimodalInput>& inputs,
202-
int32_t num_bos,
203-
int32_t num_eos);
204199
};
205200

206201
} // namespace llm

extension/llm/runner/pybindings.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ class PyTextLLMRunner {
121121
}
122122
{
123123
py::gil_scoped_release release;
124-
Error error = runner_->prefill(prompt, config.num_bos, config.num_eos);
125-
THROW_IF_ERROR(error, "Prefill failed");
124+
auto result = runner_->prefill(prompt, config.num_bos, config.num_eos);
125+
THROW_IF_ERROR(result.error(), "Prefill failed");
126126
}
127127
}
128128

@@ -234,8 +234,8 @@ class PyMultimodalRunner {
234234
}
235235
{
236236
py::gil_scoped_release release;
237-
Error error = runner_->prefill(inputs, 0, 0);
238-
THROW_IF_ERROR(error, "Prefill failed");
237+
auto result = runner_->prefill(inputs, 0, 0);
238+
THROW_IF_ERROR(result.error(), "Prefill failed");
239239
}
240240
}
241241

extension/llm/runner/text_llm_runner.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ Error TextLLMRunner::generate(
257257
return Error::Ok;
258258
}
259259

260-
Error TextLLMRunner::prefill(
260+
Result<uint64_t> TextLLMRunner::prefill(
261261
const std::vector<MultimodalInput>& inputs,
262262
int32_t num_bos,
263263
int32_t num_eos) {
@@ -283,7 +283,10 @@ Error TextLLMRunner::prefill(
283283
// Skip non-text inputs — text-only runner
284284
}
285285

286-
return Error::Ok;
286+
if (!prefill_next_token_.has_value()) {
287+
return Error::InvalidArgument;
288+
}
289+
return prefill_next_token_.value();
287290
}
288291

289292
Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) {

extension/llm/runner/text_llm_runner.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,17 +109,18 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
109109
* @param inputs A vector of MultimodalInput objects.
110110
* @param num_bos Number of BOS tokens to prepend during text encoding.
111111
* @param num_eos Number of EOS tokens to append during text encoding.
112-
* @return The error code. KV cache position is tracked internally in pos_.
112+
* @return The next token predicted after prefill, or an error.
113+
* KV cache position is tracked internally in pos_.
113114
*/
114-
::executorch::runtime::Error prefill(
115+
::executorch::runtime::Result<uint64_t> prefill(
115116
const std::vector<MultimodalInput>& inputs,
116117
int32_t num_bos = 0,
117118
int32_t num_eos = 0) override;
118119

119120
/**
120121
* Convenience overload: prefill a single text prompt.
121122
*/
122-
::executorch::runtime::Error
123+
::executorch::runtime::Result<uint64_t>
123124
prefill(const std::string& prompt, int32_t num_bos = 0, int32_t num_eos = 0) {
124125
std::vector<MultimodalInput> inputs;
125126
inputs.emplace_back(MultimodalInput(prompt));
@@ -130,7 +131,7 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
130131
* Prefill a text prompt using GenerationConfig.
131132
* Deprecated: prefer prefill(prompt, num_bos, num_eos).
132133
*/
133-
::executorch::runtime::Error prefill(
134+
::executorch::runtime::Result<uint64_t> prefill(
134135
const std::string& prompt,
135136
const GenerationConfig& config) {
136137
return prefill(prompt, config.num_bos, config.num_eos);

0 commit comments

Comments (0)