Commit f03b171

Store prefill output token to enable prefill() → generate("") workflow
Both TextLLMRunner and MultimodalRunner now store the sampled next token from prefill() in prefill_next_token_. When generate() is called with an empty prompt, it consumes this token and starts decoding directly without re-prefilling. This enables the workflow:

    runner->prefill("system prompt", 1, 0);
    runner->prefill("user turn", 0, 0);
    runner->generate("", config, callback); // decode from KV cache

This PR was authored with the assistance of Claude.
1 parent bcc7616 commit f03b171
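To make the call sequence concrete, here is a minimal sketch of the multi-turn workflow, expanded from the snippet in the commit message. It assumes an already-constructed TextLLMRunner; the include path, the GenerationConfig field value, and the callback body are illustrative assumptions, not part of this commit, and error handling is omitted for brevity.

    // Sketch only: `runner` is an already-loaded TextLLMRunner.
    #include <executorch/extension/llm/runner/text_llm_runner.h>

    using executorch::extension::llm::GenerationConfig;
    using executorch::extension::llm::TextLLMRunner;

    void multi_turn_decode(TextLLMRunner& runner) {
      // Prefill the KV cache turn by turn; each call stores the sampled
      // next token in prefill_next_token_.
      runner.prefill("system prompt", /*num_bos=*/1, /*num_eos=*/0);
      runner.prefill("user turn", /*num_bos=*/0, /*num_eos=*/0);

      GenerationConfig config;
      config.max_new_tokens = 128; // placeholder value

      // Empty prompt: consume the stored token and decode from the KV cache.
      runner.generate(
          "",
          config,
          [](const std::string& piece) { /* stream generated text */ },
          /*stats_callback=*/nullptr);
    }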

4 files changed

Lines changed: 164 additions & 49 deletions

extension/llm/runner/multimodal_runner.cpp

Lines changed: 94 additions & 4 deletions
@@ -133,6 +133,7 @@ Error MultimodalRunner::prefill(
   if (!result.ok()) {
     return result.error();
   }
+  prefill_next_token_ = result.get();
   return Error::Ok;
 }
 
@@ -141,11 +142,100 @@ Error MultimodalRunner::generate(
     const GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
     std::function<void(const Stats&)> stats_callback) {
+  if (!prompt.empty()) {
+    std::vector<MultimodalInput> inputs;
+    inputs.emplace_back(MultimodalInput(prompt));
+    return generate(inputs, config, token_callback, stats_callback);
+  }
+
+  // Empty prompt: consume prefill_next_token_ and go straight to decode
   ET_CHECK_OR_RETURN_ERROR(
-      !prompt.empty(), InvalidArgument, "Prompt cannot be empty");
-  std::vector<MultimodalInput> inputs;
-  inputs.emplace_back(MultimodalInput(prompt));
-  return generate(inputs, config, token_callback, stats_callback);
+      prefill_next_token_.has_value(),
+      InvalidState,
+      "Empty prompt requires a prior prefill() call");
+
+  if (!is_loaded()) {
+    ET_CHECK_OK_OR_RETURN_ERROR(load());
+  }
+
+  // Wrap the token_callback with print function
+  std::function<void(const std::string&)> wrapped_callback =
+      [token_callback, config](const std::string& piece) {
+        if (!config.warming) {
+          safe_printf(piece.c_str());
+          fflush(stdout);
+        }
+        if (token_callback) {
+          token_callback(piece);
+        }
+      };
+
+  stats_->inference_start_ms = time_in_ms();
+
+  uint64_t cur_token = prefill_next_token_.value();
+  prefill_next_token_.reset();
+
+  stats_->first_token_ms = time_in_ms();
+  stats_->prompt_eval_end_ms = time_in_ms();
+  stats_->num_prompt_tokens = pos_;
+
+  auto decode_result = tokenizer_->decode(cur_token, cur_token);
+  if (!decode_result.ok()) {
+    ET_LOG(
+        Error,
+        "Tokenizers error code %d",
+        static_cast<uint32_t>(decode_result.error()));
+    return Error::InvalidArgument;
+  }
+  wrapped_callback(std::move(*decode_result));
+
+  // Resolve max_new_tokens based on config
+  int64_t max_context_len = metadata_.at(kMaxContextLen);
+  int32_t max_new_tokens = config.resolve_max_new_tokens(max_context_len, pos_);
+
+  ET_CHECK_OR_RETURN_ERROR(
+      max_new_tokens > 0,
+      InvalidArgument,
+      "Max new tokens %d is less than or equal to 0",
+      max_new_tokens);
+
+  // Set ignore_eos based on config
+  text_token_generator_->set_ignore_eos(config.ignore_eos);
+
+  // Generate tokens using the text token generator
+  std::vector<uint64_t> prompt_tokens = {cur_token};
+  auto generate_result = text_token_generator_->generate(
+      prompt_tokens,
+      pos_,
+      max_new_tokens -
+          1, // Subtract 1 because prefill already generated 1 token
+      config.temperature,
+      wrapped_callback);
+  if (!generate_result.ok()) {
+    return generate_result.error();
+  }
+  int64_t num_generated_tokens = generate_result.get();
+
+  pos_ += num_generated_tokens;
+  // Update stats
+  stats_->num_generated_tokens = num_generated_tokens;
+  // Finalize stats and call callback
+  stats_->inference_end_ms = time_in_ms();
+
+  if (!config.warming) {
+    printf("\n");
+  }
+  if (config.warming) {
+    ET_LOG(Info, "Warmup run finished!");
+  } else {
+    // Do not print report during warmup
+    print_report(*stats_);
+  }
+  if (stats_callback) {
+    stats_callback(*stats_);
+  }
+
+  return Error::Ok;
 }
 
 Error MultimodalRunner::generate(

extension/llm/runner/multimodal_runner.h

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,7 @@
 #include <cstdint>
 #include <functional>
 #include <memory>
+#include <optional>
 #include <string>
 #include <unordered_map>
 
@@ -170,6 +171,7 @@ class ET_EXPERIMENTAL MultimodalRunner : public IRunner {
   void reset() override {
     pos_ = 0;
     stats_->reset();
+    prefill_next_token_.reset();
   }
 
   ~MultimodalRunner() override = default;
@@ -191,6 +193,7 @@ class ET_EXPERIMENTAL MultimodalRunner : public IRunner {
 #endif
 
   // Internal state
+  std::optional<uint64_t> prefill_next_token_;
   int64_t pos_;
 
 private:

extension/llm/runner/text_llm_runner.cpp

Lines changed: 63 additions & 45 deletions
@@ -76,9 +76,6 @@ Error TextLLMRunner::generate(
     const GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
     std::function<void(const Stats&)> stats_callback) {
-  // Prepare the inputs.
-  // Use ones-initialized inputs.
-  ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
   if (!is_loaded()) {
     stats_->model_load_start_ms = time_in_ms();
     ET_CHECK_OK_OR_RETURN_ERROR(load());
@@ -112,37 +109,66 @@ Error TextLLMRunner::generate(
   stats_->inference_start_ms = time_in_ms();
   shouldStop_ = false;
 
-  ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
-      prompt,
-      /*bos=*/config.num_bos,
-      /*eos=*/config.num_eos);
+  // Capture remaining KV cache capacity before prefill (pos_ will change)
+  int64_t max_context_len = metadata_.at(kMaxContextLen) - pos_;
 
-  if (!encode_res.ok()) {
-    ET_LOG(
-        Error,
-        "Failed to encode prompt %s. Tokenizers error code %d",
-        prompt.c_str(),
-        static_cast<uint32_t>(encode_res.error()));
-    return Error::InvalidArgument;
-  }
+  uint64_t cur_token = 0;
+  int num_prompt_tokens = 0;
+  std::vector<uint64_t> prompt_tokens;
+
+  if (!prompt.empty()) {
+    ::tokenizers::Result<std::vector<uint64_t>> encode_res =
+        tokenizer_->encode(
+            prompt,
+            /*bos=*/config.num_bos,
+            /*eos=*/config.num_eos);
+
+    if (!encode_res.ok()) {
+      ET_LOG(
+          Error,
+          "Failed to encode prompt %s. Tokenizers error code %d",
+          prompt.c_str(),
+          static_cast<uint32_t>(encode_res.error()));
+      return Error::InvalidArgument;
+    }
 
-  // encode the (string) prompt into tokens sequence
-  std::vector<uint64_t> prompt_tokens = encode_res.get();
-  int num_prompt_tokens = prompt_tokens.size();
+    // encode the (string) prompt into tokens sequence
+    prompt_tokens = encode_res.get();
+    num_prompt_tokens = prompt_tokens.size();
+
+    ET_CHECK_OR_RETURN_ERROR(
+        num_prompt_tokens >= 1,
+        InvalidArgument,
+        "Expected at least 1 prompt token");
+    ET_CHECK_OR_RETURN_ERROR(
+        num_prompt_tokens < max_context_len,
+        InvalidArgument,
+        "num_prompt_tokens %d >= max_context_len %" PRId64
+        ", Max seq length exceeded - please increase max seq len value in your export script",
+        num_prompt_tokens,
+        max_context_len);
+
+    // print prompts
+    if (config.echo) {
+      wrapped_callback(prompt);
+    }
 
-  // Reduce max_context_len by pos_
-  int64_t max_context_len = metadata_.at(kMaxContextLen) - pos_;
-  ET_CHECK_OR_RETURN_ERROR(
-      num_prompt_tokens >= 1,
-      InvalidArgument,
-      "Expected at least 1 prompt token");
-  ET_CHECK_OR_RETURN_ERROR(
-      num_prompt_tokens < max_context_len,
-      InvalidArgument,
-      "num_prompt_tokens %d >= max_context_len %" PRId64
-      ", Max seq length exceeded - please increase max seq len value in your export script",
-      num_prompt_tokens,
-      max_context_len);
+    // Prefill first
+    // Here feed all tokens to the model and get the next predicted token
+    // after the prompt. After that we will enter generate loop.
+    auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_);
+    ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
+    cur_token = prefill_res.get();
+    prefill_next_token_.reset();
+  } else {
+    // Empty prompt: consume token from a prior prefill() call
+    ET_CHECK_OR_RETURN_ERROR(
+        prefill_next_token_.has_value(),
+        InvalidState,
+        "Empty prompt requires a prior prefill() call");
+    cur_token = prefill_next_token_.value();
+    prefill_next_token_.reset();
+  }
 
   // Determine max_new_tokens using the GenerationConfig's resolve method,
   // then subtract pos_ for max_new_tokens.
@@ -152,28 +178,17 @@ Error TextLLMRunner::generate(
   ET_LOG(
       Info,
       "Max new tokens resolved: %d, given pos_ %" PRId64
-      ", num_prompt_tokens %zu, max_context_len %" PRId64,
+      ", num_prompt_tokens %d, max_context_len %" PRId64,
       max_new_tokens,
      pos_,
-      prompt_tokens.size(),
+      num_prompt_tokens,
      max_context_len);
   ET_CHECK_OR_RETURN_ERROR(
       max_new_tokens > 0,
       InvalidArgument,
       "Max new tokens %d is less than or equal to 0",
       max_new_tokens);
 
-  // Prefill first
-  // Here feed all tokens to the model and get the next predicted token
-  // after the prompt. After that we will enter generate loop.
-
-  // print prompts
-  if (config.echo) {
-    wrapped_callback(prompt);
-  }
-  auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_);
-  ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
-  uint64_t cur_token = prefill_res.get();
   stats_->first_token_ms = time_in_ms();
   stats_->prompt_eval_end_ms = time_in_ms();
 
@@ -225,7 +240,8 @@ Error TextLLMRunner::generate(
     RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens);
   }
 
-  stats_->num_prompt_tokens = num_prompt_tokens;
+  stats_->num_prompt_tokens =
+      prompt.empty() ? static_cast<int64_t>(pos_) : num_prompt_tokens;
   stats_->num_generated_tokens = num_generated_tokens;
 
   if (config.warming) {
@@ -260,6 +276,7 @@ Error TextLLMRunner::prefill(
     std::vector<uint64_t> tokens = encode_res.get();
     auto prefill_res = text_prefiller_->prefill(tokens, pos_);
     ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
+    prefill_next_token_ = prefill_res.get();
    num_bos = 0;
    num_eos = 0;
   }
@@ -295,6 +312,7 @@ void TextLLMRunner::stop() {
 void TextLLMRunner::reset() {
   stats_->reset();
   pos_ = 0;
+  prefill_next_token_.reset();
 }
 
 } // namespace executorch::extension::llm

extension/llm/runner/text_llm_runner.h

Lines changed: 4 additions & 0 deletions
@@ -14,6 +14,7 @@
 #include <cstdint>
 #include <functional>
 #include <memory>
+#include <optional>
 #include <string>
 #include <unordered_map>
 
@@ -189,6 +190,9 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
   // Deprecated, we should rely on the temperature in GenerationConfig instead.
   float temperature_ = -1.0f;
 
+  // Token predicted by the last prefill() call, consumed by generate("").
+  std::optional<uint64_t> prefill_next_token_;
+
   // The position in KV cache of the input, starting from 0.
   int64_t pos_ = 0;
 };

0 commit comments
