
Commit 67bc28b
Unify MultimodalRunner under IRunner with multimodal prefill (#17741)
Make `MultimodalRunner` inherit from `IRunner` so callers can hold a single `IRunner*` regardless of model type. Add `prefill(vector<MultimodalInput>, num_bos, num_eos)` to `IRunner` returning `Result<uint64_t>` (the predicted next token), with a default `NotSupported` implementation.

Resolves #17728

## Changes

- **`irunner.h`** — Forward-declare `MultimodalInput`. Add virtual `prefill(vector<MultimodalInput>)` returning `Result<uint64_t>` with a default `NotSupported` implementation.
- **`multimodal_runner.h/cpp`** — `MultimodalRunner` now inherits `IRunner`. The `generate(string)` override is a pure wrapper that delegates to `generate(vector)`. `generate(vector)` handles both non-empty inputs (prefill + decode) and empty inputs (consume `prefill_next_token_` from a prior `prefill()` call). The decode loop is extracted into a private `decode_from_token()` to avoid duplication. `is_loaded()` becomes `const override`; `stop()`/`reset()`/`load()` gain `override`. A string convenience `prefill(string)` is provided inline.
- **`text_llm_runner.h/cpp`** — New `prefill(vector<MultimodalInput>)` override handles text inputs (encode + prefill KV cache) and returns the predicted next token. `generate("")` is allowed after `prefill()` — it consumes the stored `prefill_next_token_`. The old `prefill(string, GenerationConfig)` is preserved as a deprecated wrapper. String convenience methods are defined in the .cpp to avoid a header dependency on `multimodal_input.h`.
- **`pybindings.cpp`** — Adapts to the `Result<uint64_t>` return type from `prefill()`.
- **`_llm_runner.pyi`** — Updated `MultimodalRunner.prefill` docstring.

## Design decisions

- `prefill()` returns `Result<uint64_t>` — the sampled next token from the final forward pass. This is stored internally in `prefill_next_token_` for the `prefill()` → `generate("")`/`generate({})` workflow, and also returned to callers who may want the token directly.
- `prefill()` takes `num_bos`/`num_eos` instead of `GenerationConfig` — those are the only fields relevant to prefill (for tokenizer encoding).
- BOS is only applied when `pos_ == 0` (start of conversation) in `MultimodalRunner`. `TextLLMRunner` trusts the caller's `num_bos` value.
- `generate(string)` is always a pure wrapper to `generate(vector)` in `MultimodalRunner` — an empty string passes an empty vector; a non-empty string is wrapped as a `MultimodalInput`.
- `text_llm_runner.h` avoids `#include multimodal_input.h` — only the forward declaration from `irunner.h` is needed in the header. String convenience methods are defined in the .cpp.

## Backward compatibility

- **C++ source-compatible**: `MultimodalRunner::prefill(inputs)` still compiles (new parameters have defaults). `TextLLMRunner::prefill(string, GenerationConfig)` is preserved as a deprecated wrapper. The return type changed from `Error` to `Result<uint64_t>` — callers checking `.error()` still work.
- **ABI-breaking**: Expected for `ET_EXPERIMENTAL` APIs.
- **Python**: Fully compatible; interface signatures unchanged.

## Test plan

- Existing C++ tests: `test_text_llm_runner.cpp`, `test_text_prefiller.cpp`, `test_generation_config.cpp` — verify no regressions.
- Build: `cmake --build` for the `extension_llm_runner` target.
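The `prefill()` → `generate("")` hand-off in the commit message can be sketched as a small state machine. The class below is a hypothetical stand-in, not the ExecuTorch API: it mirrors only the control flow around `prefill_next_token_`, with trivial token arithmetic standing in for a model forward pass.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <vector>

// Illustrative sketch: prefill() stores the predicted next token; a later
// generate() with empty inputs consumes it instead of prefilling again.
class RunnerSketch {
 public:
  // "Prefill" token inputs; return the predicted next token (here: last
  // input token + 1, standing in for model forward + sampling).
  uint64_t prefill(const std::vector<uint64_t>& inputs) {
    uint64_t last = 0;
    for (uint64_t t : inputs) {
      last = t + 1;
      ++pos_;
    }
    prefill_next_token_ = last;
    return last;
  }

  // Decode from either fresh inputs or a token stored by a prior prefill().
  uint64_t generate(const std::vector<uint64_t>& inputs) {
    uint64_t cur_token = 0;
    if (!inputs.empty()) {
      cur_token = prefill(inputs);
      prefill_next_token_.reset();
    } else {
      if (!prefill_next_token_.has_value()) {
        throw std::logic_error("empty inputs require a prior prefill()");
      }
      cur_token = *prefill_next_token_;
      prefill_next_token_.reset();
    }
    return cur_token; // the real runner enters the decode loop here
  }

  int64_t pos() const { return pos_; }

 private:
  std::optional<uint64_t> prefill_next_token_;
  int64_t pos_ = 0;
};
```

Note that the stored token is single-use: `generate({})` consumes and resets it, matching the `InvalidState` check in the real implementation.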
1 parent 4528ae2 commit 67bc28b

8 files changed

Lines changed: 453 additions & 149 deletions


extension/llm/runner/_llm_runner.pyi

Lines changed: 2 additions & 1 deletion
```diff
@@ -479,7 +479,8 @@ class MultimodalRunner:
     def prefill(self, inputs: List[MultimodalInput]) -> None:
         """
         Prefill multimodal inputs (e.g., to rebuild KV cache from chat history)
-        without generating tokens.
+        without generating tokens. After prefill, call generate() with a
+        non-empty final text input to start decoding.

         Args:
             inputs: List of multimodal inputs to prefill
```

extension/llm/runner/irunner.h

Lines changed: 20 additions & 0 deletions
```diff
@@ -14,14 +14,18 @@
 #include <functional>
 #include <memory>
 #include <string>
+#include <vector>

 #include <executorch/extension/llm/runner/stats.h>
 #include <executorch/runtime/core/error.h>
+#include <executorch/runtime/core/result.h>

 namespace executorch {
 namespace extension {
 namespace llm {

+class MultimodalInput; // Forward declaration
+
 // Configuration struct for generation parameters, fields should be sorted in
 // alphabetic order
 struct GenerationConfig {
@@ -128,6 +132,22 @@ class ET_EXPERIMENTAL IRunner {
       std::function<void(const std::string&)> token_callback,
       std::function<void(const Stats&)> stats_callback) = 0;

+  /**
+   * Prefill multimodal inputs into the KV cache without generating.
+   *
+   * @param inputs A vector of MultimodalInput objects (text, tokens, images,
+   * audio)
+   * @param num_bos Number of BOS tokens to prepend during encoding
+   * @param num_eos Number of EOS tokens to append during encoding
+   * @return The next token predicted after prefill, or an error
+   */
+  virtual runtime::Result<uint64_t> prefill(
+      const std::vector<MultimodalInput>& inputs,
+      int32_t num_bos = 0,
+      int32_t num_eos = 0) {
+    return runtime::Error::NotSupported;
+  }
+
   /**
    * Stop the generation process.
    */
```
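The default `NotSupported` body on the new virtual is what lets callers hold one polymorphic pointer and probe prefill support at runtime, while existing runners compile unchanged. A minimal sketch with simplified stand-ins for `Result`/`Error` (the names below are illustrative, not the real executorch types):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Simplified stand-ins for runtime::Result / runtime::Error.
enum class Err { Ok, NotSupported };
struct TokenResult {
  Err err;
  uint64_t token;
  bool ok() const { return err == Err::Ok; }
};

struct Input {}; // stand-in for MultimodalInput

// Base interface: prefill() has a default "NotSupported" body, so runners
// that never override it still compile, and callers holding a base pointer
// get a clean error instead of a link-time or pure-virtual failure.
class Runner {
 public:
  virtual ~Runner() = default;
  virtual TokenResult prefill(const std::vector<Input>& inputs) {
    (void)inputs;
    return {Err::NotSupported, 0};
  }
};

// A runner that actually implements prefill.
class MultimodalSketch : public Runner {
 public:
  TokenResult prefill(const std::vector<Input>& inputs) override {
    // Stand-in for a real forward pass over the inputs.
    return {Err::Ok, static_cast<uint64_t>(42 + inputs.size())};
  }
};

class LegacySketch : public Runner {}; // inherits the default
```

Callers can then branch on `prefill(...).ok()` without knowing the concrete runner type, which is the point of unifying `MultimodalRunner` under `IRunner`.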

extension/llm/runner/multimodal_runner.cpp

Lines changed: 106 additions & 69 deletions
```diff
@@ -53,7 +53,7 @@ MultimodalRunner::MultimodalRunner(
 #endif
 }

-bool MultimodalRunner::is_loaded() {
+bool MultimodalRunner::is_loaded() const {
   return multimodal_prefiller_->is_method_loaded() &&
       text_token_generator_->is_loaded();
 }
@@ -85,89 +85,57 @@ Error MultimodalRunner::load() {
     ET_LOG(Info, format, __VA_ARGS__); \
   }

-Error MultimodalRunner::prefill(const std::vector<MultimodalInput>& inputs) {
-  if (!is_loaded()) {
-    ET_CHECK_OK_OR_RETURN_ERROR(load());
-  }
-  for (auto& input : inputs) {
-    auto prefill_result = multimodal_prefiller_->prefill(input, pos_);
-    if (!prefill_result.ok()) {
-      return prefill_result.error();
-    }
-  }
-  return Error::Ok;
-}
-
-Error MultimodalRunner::generate(
+Result<uint64_t> MultimodalRunner::prefill(
     const std::vector<MultimodalInput>& inputs,
-    const GenerationConfig& config,
-    std::function<void(const std::string&)> token_callback,
-    std::function<void(const Stats&)> stats_callback) {
-  if (inputs.empty()) {
-    ET_LOG(Error, "MultimodalInput vector cannot be empty");
-    return Error::InvalidArgument;
-  }
-
+    int32_t num_bos,
+    int32_t num_eos) {
   if (!is_loaded()) {
     ET_CHECK_OK_OR_RETURN_ERROR(load());
   }
-
-  if (config.warming) {
-    ET_LOG(Info, "Doing a warmup run...");
-  }
-
-  RUNNER_ET_LOG(
-      config.warming,
-      "RSS after loading model: %f MiB (0 if unsupported)",
-      get_rss_bytes() / 1024.0 / 1024.0);
-
-  // Wrap the token_callback with print function
-  std::function<void(const std::string&)> wrapped_callback =
-      [token_callback, config](const std::string& piece) {
-        if (!config.warming) {
-          safe_printf(piece.c_str());
-          fflush(stdout);
-        }
-        if (token_callback) {
-          token_callback(piece);
-        }
-      };
-
-  // Reset internal state and start inference
-  stats_->inference_start_ms = time_in_ms();
-
-  uint64_t prefill_next_token = 0;
-  // Process multimodal inputs in order
+  uint64_t last_token = 0;
   for (size_t i = 0; i < inputs.size(); ++i) {
-    const MultimodalInput& input = inputs[i];
-    ET_LOG(
-        Info,
-        "Prefilling input %zu/%zu, type: %s",
-        i,
-        inputs.size(),
-        input.type_name());
-    if (config.echo && i == inputs.size() - 1 && input.is_text()) {
-      wrapped_callback(input.get_text());
-    }
+    const auto& input = inputs[i];
     int32_t bos = 0;
     int32_t eos = 0;
-    if (i == 0 && input.is_text()) {
-      bos = config.num_bos;
-      eos = config.num_eos;
+    if (i == 0 && pos_ == 0) {
+      if (input.is_text() || input.is_tokens()) {
+        bos = num_bos;
+        eos = num_eos;
+      } else if (num_bos > 0) {
+        // Non-text first input: prepend BOS via a token input
+        auto it = metadata_.find(kBosId);
+        if (it != metadata_.end()) {
+          std::vector<uint64_t> bos_tokens(
+              num_bos, static_cast<uint64_t>(it->second));
+          MultimodalInput bos_input(std::move(bos_tokens));
+          auto bos_result = multimodal_prefiller_->prefill(bos_input, pos_);
+          if (!bos_result.ok()) {
+            return bos_result.error();
+          }
+          last_token = bos_result.get();
+        }
+      }
     }
     auto prefill_result = multimodal_prefiller_->prefill(input, pos_, bos, eos);
     if (!prefill_result.ok()) {
       return prefill_result.error();
     }
-    prefill_next_token = prefill_result.get();
+    last_token = prefill_result.get();
   }
+  prefill_next_token_ = last_token;
+  return last_token;
+}

+Error MultimodalRunner::decode_from_token(
+    uint64_t cur_token,
+    const GenerationConfig& config,
+    std::function<void(const std::string&)> wrapped_callback,
+    std::function<void(const Stats&)> stats_callback) {
   stats_->first_token_ms = time_in_ms();
   stats_->prompt_eval_end_ms = time_in_ms();
   stats_->num_prompt_tokens = pos_;

-  auto decode_result =
-      tokenizer_->decode(prefill_next_token, prefill_next_token);
+  auto decode_result = tokenizer_->decode(cur_token, cur_token);
   if (!decode_result.ok()) {
     ET_LOG(
         Error,
@@ -183,8 +151,7 @@ Error MultimodalRunner::generate(
       get_rss_bytes() / 1024.0 / 1024.0);

   // Resolve max_new_tokens based on config
-  int64_t max_context_len =
-      metadata_.at(kMaxContextLen) - 0; // No start_pos offset
+  int64_t max_context_len = metadata_.at(kMaxContextLen);
   int32_t max_new_tokens = config.resolve_max_new_tokens(max_context_len, pos_);

   ET_LOG(
@@ -204,7 +171,7 @@ Error MultimodalRunner::generate(
   text_token_generator_->set_ignore_eos(config.ignore_eos);

   // Generate tokens using the text token generator
-  std::vector<uint64_t> prompt_tokens = {prefill_next_token};
+  std::vector<uint64_t> prompt_tokens = {cur_token};
   auto generate_result = text_token_generator_->generate(
       /*tokens=*/prompt_tokens,
       /*start_pos=*/pos_,
@@ -249,4 +216,74 @@ Error MultimodalRunner::generate(
   return Error::Ok;
 }

+Error MultimodalRunner::generate(
+    const std::string& prompt,
+    const GenerationConfig& config,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
+  std::vector<MultimodalInput> inputs;
+  if (!prompt.empty()) {
+    inputs.emplace_back(MultimodalInput(prompt));
+  }
+  return generate(inputs, config, token_callback, stats_callback);
+}
+
+Error MultimodalRunner::generate(
+    const std::vector<MultimodalInput>& inputs,
+    const GenerationConfig& config,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
+  if (!is_loaded()) {
+    ET_CHECK_OK_OR_RETURN_ERROR(load());
+  }
+
+  if (config.warming) {
+    ET_LOG(Info, "Doing a warmup run...");
+  }
+
+  RUNNER_ET_LOG(
+      config.warming,
+      "RSS after loading model: %f MiB (0 if unsupported)",
+      get_rss_bytes() / 1024.0 / 1024.0);
+
+  // Wrap the token_callback with print function
+  std::function<void(const std::string&)> wrapped_callback =
+      [token_callback, config](const std::string& piece) {
+        if (!config.warming) {
+          safe_printf(piece.c_str());
+          fflush(stdout);
+        }
+        if (token_callback) {
+          token_callback(piece);
+        }
+      };
+
+  // Reset internal state and start inference
+  stats_->inference_start_ms = time_in_ms();
+
+  uint64_t cur_token = 0;
+  if (!inputs.empty()) {
+    // Echo the last text input if enabled
+    if (config.echo && inputs.back().is_text()) {
+      wrapped_callback(inputs.back().get_text());
+    }
+
+    // Prefill all inputs and get the first decode token
+    auto prefill_result = prefill(inputs, config.num_bos, config.num_eos);
+    ET_CHECK_OK_OR_RETURN_ERROR(prefill_result.error());
+    cur_token = prefill_result.get();
+    prefill_next_token_.reset();
+  } else {
+    // Empty inputs: consume token from a prior prefill() call
+    ET_CHECK_OR_RETURN_ERROR(
+        prefill_next_token_.has_value(),
+        InvalidState,
+        "Empty inputs requires a prior prefill() call");
+    cur_token = prefill_next_token_.value();
+    prefill_next_token_.reset();
+  }
+
+  return decode_from_token(cur_token, config, wrapped_callback, stats_callback);
+}

 } // namespace executorch::extension::llm
```
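The BOS/EOS decision inside the rewritten `prefill()` can be isolated as a small helper for illustration (a hypothetical function, not part of the diff): counts are forwarded to the prefiller only for the first input of a fresh conversation (`pos_ == 0`) and only when that input is text or tokens; for a non-text first input the real code instead injects BOS through a separate token input built from `kBosId` metadata.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

struct BosEos {
  int32_t bos;
  int32_t eos;
};

// Sketch of the BOS/EOS gating in MultimodalRunner::prefill():
// forward the caller's counts only at the true start of a conversation,
// and only when the first input can absorb them during tokenization.
BosEos resolve_bos_eos(
    size_t input_index,
    int64_t pos,
    bool is_text_or_tokens,
    int32_t num_bos,
    int32_t num_eos) {
  if (input_index == 0 && pos == 0 && is_text_or_tokens) {
    return {num_bos, num_eos};
  }
  return {0, 0};
}
```

This mirrors the design decision above: mid-conversation prefills (`pos_ > 0`) never re-apply BOS, so reloading chat history in chunks does not corrupt the KV cache.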

extension/llm/runner/multimodal_runner.h

Lines changed: 53 additions & 10 deletions
```diff
@@ -15,6 +15,7 @@
 #include <cstdint>
 #include <functional>
 #include <memory>
+#include <optional>
 #include <string>
 #include <unordered_map>

@@ -74,7 +75,7 @@ namespace llm {
  *
  * runner->generate(inputs, config, token_callback, stats_callback);
  */
-class ET_EXPERIMENTAL MultimodalRunner {
+class ET_EXPERIMENTAL MultimodalRunner : public IRunner {
  public:
  /**
   * @brief Constructor for MultimodalRunner with dependency injection
@@ -105,8 +106,24 @@ class ET_EXPERIMENTAL MultimodalRunner {
       std::unique_ptr<TextTokenGenerator> text_token_generator,
       std::unique_ptr<Stats> stats);

-  virtual bool is_loaded();
-  virtual ::executorch::runtime::Error load();
+  bool is_loaded() const override;
+  ::executorch::runtime::Error load() override;
+
+  /**
+   * Generate tokens from a text prompt. Wraps the prompt as a MultimodalInput
+   * and delegates to generate(vector). Empty prompt is allowed if prefill()
+   * was called beforehand.
+   * @param prompt The text prompt to generate from.
+   * @param config Generation configuration parameters.
+   * @param token_callback Callback function called for each generated token.
+   * @param stats_callback Callback function for generation statistics.
+   * @return The error code. KV cache position is tracked internally in pos_.
+   */
+  ::executorch::runtime::Error generate(
+      const std::string& prompt,
+      const GenerationConfig& config,
+      std::function<void(const std::string&)> token_callback = {},
+      std::function<void(const Stats&)> stats_callback = {}) override;

   /**
    * Generate tokens from the given multimodal inputs using GenerationConfig.
@@ -124,24 +141,42 @@ class ET_EXPERIMENTAL MultimodalRunner {
       std::function<void(const Stats&)> stats_callback = {});

   /**
-   * Prefill multimodal inputs, for example to reload chat history.
+   * Prefill multimodal inputs to fill the KV cache, for example to reload
+   * chat history. Call generate() with a non-empty prompt afterwards to
+   * start decoding.
    * @param inputs A vector of MultimodalInput objects containing images and
    * text.
-   * @return The error code. KV cache position is tracked internally in pos_.
+   * @param num_bos Number of BOS tokens to prepend during encoding.
+   * @param num_eos Number of EOS tokens to append during encoding.
+   * @return The next token predicted after prefill, or an error.
+   * KV cache position is tracked internally in pos_.
+   */
+  ::executorch::runtime::Result<uint64_t> prefill(
+      const std::vector<MultimodalInput>& inputs,
+      int32_t num_bos = 0,
+      int32_t num_eos = 0) override;
+
+  /**
+   * Convenience overload: prefill a single text prompt.
    */
-  virtual ::executorch::runtime::Error prefill(
-      const std::vector<MultimodalInput>& inputs);
+  ::executorch::runtime::Result<uint64_t>
+  prefill(const std::string& prompt, int32_t num_bos = 0, int32_t num_eos = 0) {
+    std::vector<MultimodalInput> inputs;
+    inputs.emplace_back(MultimodalInput(prompt));
+    return prefill(inputs, num_bos, num_eos);
+  }

-  inline void stop() {
+  void stop() override {
     text_token_generator_->stop();
   }

-  inline void reset() {
+  void reset() override {
     pos_ = 0;
     stats_->reset();
+    prefill_next_token_.reset();
   }

-  virtual ~MultimodalRunner() = default;
+  ~MultimodalRunner() override = default;

  protected:
  // Components
@@ -160,7 +195,15 @@ class ET_EXPERIMENTAL MultimodalRunner {
 #endif

   // Internal state
+  std::optional<uint64_t> prefill_next_token_;
   int64_t pos_;
+
+ private:
+  ::executorch::runtime::Error decode_from_token(
+      uint64_t cur_token,
+      const GenerationConfig& config,
+      std::function<void(const std::string&)> wrapped_callback,
+      std::function<void(const Stats&)> stats_callback);
 };

 } // namespace llm
```
