Skip to content

Commit e068f7e

Browse files
committed
Simplify generate(string) to always delegate to generate(vector)
generate(string) is now a pure wrapper — an empty string delegates with an empty vector, and a non-empty string is wrapped as a MultimodalInput. The "decode from prior prefill" logic moves into generate(vector), which now handles the empty-inputs case, giving both overloads consistent semantics. This PR was authored with the assistance of Claude.
1 parent 2c44df7 commit e068f7e

1 file changed

Lines changed: 22 additions & 45 deletions

File tree

extension/llm/runner/multimodal_runner.cpp

Lines changed: 22 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -221,52 +221,18 @@ Error MultimodalRunner::generate(
221221
const GenerationConfig& config,
222222
std::function<void(const std::string&)> token_callback,
223223
std::function<void(const Stats&)> stats_callback) {
224+
std::vector<MultimodalInput> inputs;
224225
if (!prompt.empty()) {
225-
std::vector<MultimodalInput> inputs;
226226
inputs.emplace_back(MultimodalInput(prompt));
227-
return generate(inputs, config, token_callback, stats_callback);
228227
}
229-
230-
// Empty prompt: consume prefill_next_token_ and go straight to decode
231-
ET_CHECK_OR_RETURN_ERROR(
232-
prefill_next_token_.has_value(),
233-
InvalidState,
234-
"Empty prompt requires a prior prefill() call");
235-
236-
if (!is_loaded()) {
237-
ET_CHECK_OK_OR_RETURN_ERROR(load());
238-
}
239-
240-
// Wrap the token_callback with print function
241-
std::function<void(const std::string&)> wrapped_callback =
242-
[token_callback, config](const std::string& piece) {
243-
if (!config.warming) {
244-
safe_printf(piece.c_str());
245-
fflush(stdout);
246-
}
247-
if (token_callback) {
248-
token_callback(piece);
249-
}
250-
};
251-
252-
stats_->inference_start_ms = time_in_ms();
253-
254-
uint64_t cur_token = prefill_next_token_.value();
255-
prefill_next_token_.reset();
256-
257-
return decode_from_token(cur_token, config, wrapped_callback, stats_callback);
228+
return generate(inputs, config, token_callback, stats_callback);
258229
}
259230

260231
Error MultimodalRunner::generate(
261232
const std::vector<MultimodalInput>& inputs,
262233
const GenerationConfig& config,
263234
std::function<void(const std::string&)> token_callback,
264235
std::function<void(const Stats&)> stats_callback) {
265-
if (inputs.empty()) {
266-
ET_LOG(Error, "MultimodalInput vector cannot be empty");
267-
return Error::InvalidArgument;
268-
}
269-
270236
if (!is_loaded()) {
271237
ET_CHECK_OK_OR_RETURN_ERROR(load());
272238
}
@@ -295,16 +261,27 @@ Error MultimodalRunner::generate(
295261
// Reset internal state and start inference
296262
stats_->inference_start_ms = time_in_ms();
297263

298-
// Echo the last text input if enabled
299-
if (config.echo && inputs.back().is_text()) {
300-
wrapped_callback(inputs.back().get_text());
301-
}
264+
uint64_t cur_token;
265+
if (!inputs.empty()) {
266+
// Echo the last text input if enabled
267+
if (config.echo && inputs.back().is_text()) {
268+
wrapped_callback(inputs.back().get_text());
269+
}
302270

303-
// Prefill all inputs and get the first decode token
304-
auto prefill_result = prefill(inputs, config.num_bos, config.num_eos);
305-
ET_CHECK_OK_OR_RETURN_ERROR(prefill_result.error());
306-
uint64_t cur_token = prefill_result.get();
307-
prefill_next_token_.reset();
271+
// Prefill all inputs and get the first decode token
272+
auto prefill_result = prefill(inputs, config.num_bos, config.num_eos);
273+
ET_CHECK_OK_OR_RETURN_ERROR(prefill_result.error());
274+
cur_token = prefill_result.get();
275+
prefill_next_token_.reset();
276+
} else {
277+
// Empty inputs: consume token from a prior prefill() call
278+
ET_CHECK_OR_RETURN_ERROR(
279+
prefill_next_token_.has_value(),
280+
InvalidState,
281+
"Empty inputs requires a prior prefill() call");
282+
cur_token = prefill_next_token_.value();
283+
prefill_next_token_.reset();
284+
}
308285

309286
return decode_from_token(cur_token, config, wrapped_callback, stats_callback);
310287
}

0 commit comments

Comments (0)