Commit 72d6034
Extract decode_from_token to eliminate duplicate decode loop
The generate(string) empty-prompt path and the generate(vector) overload shared an identical decode loop (decode the first token, resolve max_new_tokens, run the text_token_generator, update stats). Extract this loop into a private decode_from_token() method called by both paths. This PR was authored with the assistance of Claude.
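For orientation, here is a minimal caller-side sketch of the two code paths this commit unifies. The overload signatures and the MultimodalInput construction follow the diff below; the run_both_paths function, the runner reference, the prompt text, the callback bodies, the namespace qualifications, and the include path are illustrative assumptions, not part of this commit.

// Hypothetical usage sketch; names flagged in comments are assumptions.
#include <functional>
#include <string>
#include <vector>

#include <executorch/extension/llm/runner/multimodal_runner.h> // assumed include path

using executorch::extension::llm::GenerationConfig;
using executorch::extension::llm::MultimodalInput;
using executorch::extension::llm::MultimodalRunner;
using executorch::extension::llm::Stats;
using executorch::runtime::Error;

void run_both_paths(MultimodalRunner& runner) { // hypothetical helper
  GenerationConfig config;
  auto on_token = [](const std::string& piece) { /* stream each decoded piece */ };
  auto on_stats = [](const Stats& stats) { /* inspect timing / token counts */ };

  // Path 1: generate(vector) prefills the inputs itself, then decodes.
  std::vector<MultimodalInput> inputs;
  inputs.emplace_back(MultimodalInput("Describe the image.")); // assumed prompt
  Error err = runner.generate(inputs, config, on_token, on_stats);
  if (err != Error::Ok) {
    return;
  }

  // Path 2: generate(string) with an empty prompt resumes decoding from the
  // token cached by a prior prefill() call (prefill_next_token_); without
  // that earlier call it fails with InvalidState, per the diff below.
  err = runner.generate(/*prompt=*/"", config, on_token, on_stats);
  (void)err;
}

Either way, both overloads now funnel into the same private decode_from_token() helper, so the first-token decode, max_new_tokens resolution, token generation, and stats finalization live in one place.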
1 parent f865cff commit 72d6034

2 files changed: 79 additions & 124 deletions

File tree
- extension/llm/runner/multimodal_runner.cpp
- extension/llm/runner/multimodal_runner.h

extension/llm/runner/multimodal_runner.cpp
Lines changed: 72 additions & 124 deletions
@@ -126,44 +126,11 @@ Result<uint64_t> MultimodalRunner::prefill(
   return last_token;
 }
 
-Error MultimodalRunner::generate(
-    const std::string& prompt,
+Error MultimodalRunner::decode_from_token(
+    uint64_t cur_token,
     const GenerationConfig& config,
-    std::function<void(const std::string&)> token_callback,
+    std::function<void(const std::string&)> wrapped_callback,
     std::function<void(const Stats&)> stats_callback) {
-  if (!prompt.empty()) {
-    std::vector<MultimodalInput> inputs;
-    inputs.emplace_back(MultimodalInput(prompt));
-    return generate(inputs, config, token_callback, stats_callback);
-  }
-
-  // Empty prompt: consume prefill_next_token_ and go straight to decode
-  ET_CHECK_OR_RETURN_ERROR(
-      prefill_next_token_.has_value(),
-      InvalidState,
-      "Empty prompt requires a prior prefill() call");
-
-  if (!is_loaded()) {
-    ET_CHECK_OK_OR_RETURN_ERROR(load());
-  }
-
-  // Wrap the token_callback with print function
-  std::function<void(const std::string&)> wrapped_callback =
-      [token_callback, config](const std::string& piece) {
-        if (!config.warming) {
-          safe_printf(piece.c_str());
-          fflush(stdout);
-        }
-        if (token_callback) {
-          token_callback(piece);
-        }
-      };
-
-  stats_->inference_start_ms = time_in_ms();
-
-  uint64_t cur_token = prefill_next_token_.value();
-  prefill_next_token_.reset();
-
   stats_->first_token_ms = time_in_ms();
   stats_->prompt_eval_end_ms = time_in_ms();
   stats_->num_prompt_tokens = pos_;
@@ -178,10 +145,22 @@ Error MultimodalRunner::generate(
   }
   wrapped_callback(std::move(*decode_result));
 
+  RUNNER_ET_LOG(
+      config.warming,
+      "RSS after multimodal input processing: %f MiB (0 if unsupported)",
+      get_rss_bytes() / 1024.0 / 1024.0);
+
   // Resolve max_new_tokens based on config
   int64_t max_context_len = metadata_.at(kMaxContextLen);
   int32_t max_new_tokens = config.resolve_max_new_tokens(max_context_len, pos_);
 
+  ET_LOG(
+      Info,
+      "Max new tokens resolved: %d, pos_ %" PRId64 ", max_context_len %" PRId64,
+      max_new_tokens,
+      pos_,
+      max_context_len);
+
   ET_CHECK_OR_RETURN_ERROR(
       max_new_tokens > 0,
       InvalidArgument,
@@ -194,12 +173,12 @@ Error MultimodalRunner::generate(
   // Generate tokens using the text token generator
   std::vector<uint64_t> prompt_tokens = {cur_token};
   auto generate_result = text_token_generator_->generate(
-      prompt_tokens,
-      pos_,
-      max_new_tokens -
+      /*tokens=*/prompt_tokens,
+      /*start_pos=*/pos_,
+      /*max_new_tokens=*/max_new_tokens -
          1, // Subtract 1 because prefill already generated 1 token
-      config.temperature,
-      wrapped_callback);
+      /*temperature=*/config.temperature,
+      /*token_callback=*/wrapped_callback);
   if (!generate_result.ok()) {
     return generate_result.error();
   }
@@ -211,22 +190,73 @@ Error MultimodalRunner::generate(
   // Finalize stats and call callback
   stats_->inference_end_ms = time_in_ms();
 
+#ifdef CUDA_AVAILABLE
+  cuda_memory_tracker_->log_sample("after_generate");
+  stats_->gpu_free_after_generate_bytes =
+      cuda_memory_tracker_->last_free_bytes();
+  // update peak in case it changed after generation
+  stats_->gpu_peak_usage_mb = cuda_memory_tracker_->peak_usage_mb();
+#endif
+
   if (!config.warming) {
     printf("\n");
   }
+
   if (config.warming) {
     ET_LOG(Info, "Warmup run finished!");
   } else {
     // Do not print report during warmup
     print_report(*stats_);
   }
+
   if (stats_callback) {
     stats_callback(*stats_);
   }
 
   return Error::Ok;
 }
 
+Error MultimodalRunner::generate(
+    const std::string& prompt,
+    const GenerationConfig& config,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
+  if (!prompt.empty()) {
+    std::vector<MultimodalInput> inputs;
+    inputs.emplace_back(MultimodalInput(prompt));
+    return generate(inputs, config, token_callback, stats_callback);
+  }
+
+  // Empty prompt: consume prefill_next_token_ and go straight to decode
+  ET_CHECK_OR_RETURN_ERROR(
+      prefill_next_token_.has_value(),
+      InvalidState,
+      "Empty prompt requires a prior prefill() call");
+
+  if (!is_loaded()) {
+    ET_CHECK_OK_OR_RETURN_ERROR(load());
+  }
+
+  // Wrap the token_callback with print function
+  std::function<void(const std::string&)> wrapped_callback =
+      [token_callback, config](const std::string& piece) {
+        if (!config.warming) {
+          safe_printf(piece.c_str());
+          fflush(stdout);
+        }
+        if (token_callback) {
+          token_callback(piece);
+        }
+      };
+
+  stats_->inference_start_ms = time_in_ms();
+
+  uint64_t cur_token = prefill_next_token_.value();
+  prefill_next_token_.reset();
+
+  return decode_from_token(cur_token, config, wrapped_callback, stats_callback);
+}
+
 Error MultimodalRunner::generate(
     const std::vector<MultimodalInput>& inputs,
     const GenerationConfig& config,
@@ -275,89 +305,7 @@ Error MultimodalRunner::generate(
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_result.error());
   uint64_t cur_token = prefill_result.get();
 
-  stats_->first_token_ms = time_in_ms();
-  stats_->prompt_eval_end_ms = time_in_ms();
-  stats_->num_prompt_tokens = pos_;
-
-  auto decode_result = tokenizer_->decode(cur_token, cur_token);
-  if (!decode_result.ok()) {
-    ET_LOG(
-        Error,
-        "Tokenizers error code %d",
-        static_cast<uint32_t>(decode_result.error()));
-    return Error::InvalidArgument;
-  }
-  wrapped_callback(std::move(*decode_result));
-
-  RUNNER_ET_LOG(
-      config.warming,
-      "RSS after multimodal input processing: %f MiB (0 if unsupported)",
-      get_rss_bytes() / 1024.0 / 1024.0);
-
-  // Resolve max_new_tokens based on config
-  int64_t max_context_len = metadata_.at(kMaxContextLen);
-  int32_t max_new_tokens = config.resolve_max_new_tokens(max_context_len, pos_);
-
-  ET_LOG(
-      Info,
-      "Max new tokens resolved: %d, pos_ %" PRId64 ", max_context_len %" PRId64,
-      max_new_tokens,
-      pos_,
-      max_context_len);
-
-  ET_CHECK_OR_RETURN_ERROR(
-      max_new_tokens > 0,
-      InvalidArgument,
-      "Max new tokens %d is less than or equal to 0",
-      max_new_tokens);
-
-  // Set ignore_eos based on config
-  text_token_generator_->set_ignore_eos(config.ignore_eos);
-
-  // Generate tokens using the text token generator
-  std::vector<uint64_t> prompt_tokens = {cur_token};
-  auto generate_result = text_token_generator_->generate(
-      /*tokens=*/prompt_tokens,
-      /*start_pos=*/pos_,
-      /*max_new_tokens=*/max_new_tokens -
-          1, // Subtract 1 because prefill already generated 1 token
-      /*temperature=*/config.temperature,
-      /*token_callback=*/wrapped_callback);
-  if (!generate_result.ok()) {
-    return generate_result.error();
-  }
-  int64_t num_generated_tokens = generate_result.get();
-
-  pos_ += num_generated_tokens;
-  // Update stats
-  stats_->num_generated_tokens = num_generated_tokens;
-  // Finalize stats and call callback
-  stats_->inference_end_ms = time_in_ms();
-
-#ifdef CUDA_AVAILABLE
-  cuda_memory_tracker_->log_sample("after_generate");
-  stats_->gpu_free_after_generate_bytes =
-      cuda_memory_tracker_->last_free_bytes();
-  // update peak in case it changed after generation
-  stats_->gpu_peak_usage_mb = cuda_memory_tracker_->peak_usage_mb();
-#endif
-
-  if (!config.warming) {
-    printf("\n");
-  }
-
-  if (config.warming) {
-    ET_LOG(Info, "Warmup run finished!");
-  } else {
-    // Do not print report during warmup
-    print_report(*stats_);
-  }
-
-  if (stats_callback) {
-    stats_callback(*stats_);
-  }
-
-  return Error::Ok;
+  return decode_from_token(cur_token, config, wrapped_callback, stats_callback);
 }
 
 } // namespace executorch::extension::llm

extension/llm/runner/multimodal_runner.h
Lines changed: 7 additions & 0 deletions

@@ -196,6 +196,13 @@ class ET_EXPERIMENTAL MultimodalRunner : public IRunner {
   // Internal state
   std::optional<uint64_t> prefill_next_token_;
   int64_t pos_;
+
+ private:
+  ::executorch::runtime::Error decode_from_token(
+      uint64_t cur_token,
+      const GenerationConfig& config,
+      std::function<void(const std::string&)> wrapped_callback,
+      std::function<void(const Stats&)> stats_callback);
 };
 
 } // namespace llm
