@@ -126,44 +126,11 @@ Result<uint64_t> MultimodalRunner::prefill(
   return last_token;
 }

-Error MultimodalRunner::generate(
-    const std::string& prompt,
+Error MultimodalRunner::decode_from_token(
+    uint64_t cur_token,
     const GenerationConfig& config,
-    std::function<void(const std::string&)> token_callback,
+    std::function<void(const std::string&)> wrapped_callback,
     std::function<void(const Stats&)> stats_callback) {
-  if (!prompt.empty()) {
-    std::vector<MultimodalInput> inputs;
-    inputs.emplace_back(MultimodalInput(prompt));
-    return generate(inputs, config, token_callback, stats_callback);
-  }
-
-  // Empty prompt: consume prefill_next_token_ and go straight to decode
-  ET_CHECK_OR_RETURN_ERROR(
-      prefill_next_token_.has_value(),
-      InvalidState,
-      "Empty prompt requires a prior prefill() call");
-
-  if (!is_loaded()) {
-    ET_CHECK_OK_OR_RETURN_ERROR(load());
-  }
-
-  // Wrap the token_callback with print function
-  std::function<void(const std::string&)> wrapped_callback =
-      [token_callback, config](const std::string& piece) {
-        if (!config.warming) {
-          safe_printf(piece.c_str());
-          fflush(stdout);
-        }
-        if (token_callback) {
-          token_callback(piece);
-        }
-      };
-
-  stats_->inference_start_ms = time_in_ms();
-
-  uint64_t cur_token = prefill_next_token_.value();
-  prefill_next_token_.reset();
-
   stats_->first_token_ms = time_in_ms();
   stats_->prompt_eval_end_ms = time_in_ms();
   stats_->num_prompt_tokens = pos_;
@@ -178,10 +145,22 @@ Error MultimodalRunner::generate(
   }
   wrapped_callback(std::move(*decode_result));

+  RUNNER_ET_LOG(
+      config.warming,
+      "RSS after multimodal input processing: %f MiB (0 if unsupported)",
+      get_rss_bytes() / 1024.0 / 1024.0);
+
   // Resolve max_new_tokens based on config
   int64_t max_context_len = metadata_.at(kMaxContextLen);
   int32_t max_new_tokens = config.resolve_max_new_tokens(max_context_len, pos_);

+  ET_LOG(
+      Info,
+      "Max new tokens resolved: %d, pos_ %" PRId64 ", max_context_len %" PRId64,
+      max_new_tokens,
+      pos_,
+      max_context_len);
+
   ET_CHECK_OR_RETURN_ERROR(
       max_new_tokens > 0,
       InvalidArgument,
@@ -194,12 +173,12 @@ Error MultimodalRunner::generate(
   // Generate tokens using the text token generator
   std::vector<uint64_t> prompt_tokens = {cur_token};
   auto generate_result = text_token_generator_->generate(
-      prompt_tokens,
-      pos_,
-      max_new_tokens -
+      /*tokens=*/prompt_tokens,
+      /*start_pos=*/pos_,
+      /*max_new_tokens=*/max_new_tokens -
           1, // Subtract 1 because prefill already generated 1 token
-      config.temperature,
-      wrapped_callback);
+      /*temperature=*/config.temperature,
+      /*token_callback=*/wrapped_callback);
   if (!generate_result.ok()) {
     return generate_result.error();
   }
@@ -211,22 +190,73 @@ Error MultimodalRunner::generate(
   // Finalize stats and call callback
   stats_->inference_end_ms = time_in_ms();

+#ifdef CUDA_AVAILABLE
+  cuda_memory_tracker_->log_sample("after_generate");
+  stats_->gpu_free_after_generate_bytes =
+      cuda_memory_tracker_->last_free_bytes();
+  // update peak in case it changed after generation
+  stats_->gpu_peak_usage_mb = cuda_memory_tracker_->peak_usage_mb();
+#endif
+
   if (!config.warming) {
     printf("\n");
   }
+
   if (config.warming) {
     ET_LOG(Info, "Warmup run finished!");
   } else {
     // Do not print report during warmup
     print_report(*stats_);
   }
+
   if (stats_callback) {
     stats_callback(*stats_);
   }

   return Error::Ok;
 }

+Error MultimodalRunner::generate(
+    const std::string& prompt,
+    const GenerationConfig& config,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
+  if (!prompt.empty()) {
+    std::vector<MultimodalInput> inputs;
+    inputs.emplace_back(MultimodalInput(prompt));
+    return generate(inputs, config, token_callback, stats_callback);
+  }
+
+  // Empty prompt: consume prefill_next_token_ and go straight to decode
+  ET_CHECK_OR_RETURN_ERROR(
+      prefill_next_token_.has_value(),
+      InvalidState,
+      "Empty prompt requires a prior prefill() call");
+
+  if (!is_loaded()) {
+    ET_CHECK_OK_OR_RETURN_ERROR(load());
+  }
+
+  // Wrap the token_callback with print function
+  std::function<void(const std::string&)> wrapped_callback =
+      [token_callback, config](const std::string& piece) {
+        if (!config.warming) {
+          safe_printf(piece.c_str());
+          fflush(stdout);
+        }
+        if (token_callback) {
+          token_callback(piece);
+        }
+      };
+
+  stats_->inference_start_ms = time_in_ms();
+
+  uint64_t cur_token = prefill_next_token_.value();
+  prefill_next_token_.reset();
+
+  return decode_from_token(cur_token, config, wrapped_callback, stats_callback);
+}
+
 Error MultimodalRunner::generate(
     const std::vector<MultimodalInput>& inputs,
     const GenerationConfig& config,
@@ -275,89 +305,7 @@ Error MultimodalRunner::generate(
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_result.error());
   uint64_t cur_token = prefill_result.get();

-  stats_->first_token_ms = time_in_ms();
-  stats_->prompt_eval_end_ms = time_in_ms();
-  stats_->num_prompt_tokens = pos_;
-
-  auto decode_result = tokenizer_->decode(cur_token, cur_token);
-  if (!decode_result.ok()) {
-    ET_LOG(
-        Error,
-        "Tokenizers error code %d",
-        static_cast<uint32_t>(decode_result.error()));
-    return Error::InvalidArgument;
-  }
-  wrapped_callback(std::move(*decode_result));
-
-  RUNNER_ET_LOG(
-      config.warming,
-      "RSS after multimodal input processing: %f MiB (0 if unsupported)",
-      get_rss_bytes() / 1024.0 / 1024.0);
-
-  // Resolve max_new_tokens based on config
-  int64_t max_context_len = metadata_.at(kMaxContextLen);
-  int32_t max_new_tokens = config.resolve_max_new_tokens(max_context_len, pos_);
-
-  ET_LOG(
-      Info,
-      "Max new tokens resolved: %d, pos_ %" PRId64 ", max_context_len %" PRId64,
-      max_new_tokens,
-      pos_,
-      max_context_len);
-
-  ET_CHECK_OR_RETURN_ERROR(
-      max_new_tokens > 0,
-      InvalidArgument,
-      "Max new tokens %d is less than or equal to 0",
-      max_new_tokens);
-
-  // Set ignore_eos based on config
-  text_token_generator_->set_ignore_eos(config.ignore_eos);
-
-  // Generate tokens using the text token generator
-  std::vector<uint64_t> prompt_tokens = {cur_token};
-  auto generate_result = text_token_generator_->generate(
-      /*tokens=*/prompt_tokens,
-      /*start_pos=*/pos_,
-      /*max_new_tokens=*/max_new_tokens -
-          1, // Subtract 1 because prefill already generated 1 token
-      /*temperature=*/config.temperature,
-      /*token_callback=*/wrapped_callback);
-  if (!generate_result.ok()) {
-    return generate_result.error();
-  }
-  int64_t num_generated_tokens = generate_result.get();
-
-  pos_ += num_generated_tokens;
-  // Update stats
-  stats_->num_generated_tokens = num_generated_tokens;
-  // Finalize stats and call callback
-  stats_->inference_end_ms = time_in_ms();
-
-#ifdef CUDA_AVAILABLE
-  cuda_memory_tracker_->log_sample("after_generate");
-  stats_->gpu_free_after_generate_bytes =
-      cuda_memory_tracker_->last_free_bytes();
-  // update peak in case it changed after generation
-  stats_->gpu_peak_usage_mb = cuda_memory_tracker_->peak_usage_mb();
-#endif
-
-  if (!config.warming) {
-    printf("\n");
-  }
-
-  if (config.warming) {
-    ET_LOG(Info, "Warmup run finished!");
-  } else {
-    // Do not print report during warmup
-    print_report(*stats_);
-  }
-
-  if (stats_callback) {
-    stats_callback(*stats_);
-  }
-
-  return Error::Ok;
+  return decode_from_token(cur_token, config, wrapped_callback, stats_callback);
 }

 } // namespace executorch::extension::llm
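
For readers skimming the diff: the net effect is mostly a refactor. The decode loop that was duplicated across both generate() overloads now lives in a single decode_from_token() helper that each overload delegates to, with some logging and CUDA memory sampling added along the shared path. Below is a minimal, self-contained C++ sketch of that delegation pattern. It is an illustration only, assuming nothing beyond the shapes visible in the diff: Runner, Input, prefill_only(), and the faked prefill() are hypothetical stand-ins, not the ExecuTorch API.

// --- Illustrative sketch (not part of the diff) ---
#include <cstdint>
#include <functional>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Stand-in types: the real MultimodalRunner, MultimodalInput,
// GenerationConfig, and Error are richer; this mirrors only the structure.
struct Input {
  std::string text;
};

class Runner {
 public:
  // Text-only overload: a non-empty prompt is wrapped and forwarded to the
  // multimodal overload; an empty prompt resumes from a prior prefill().
  int generate(const std::string& prompt,
               std::function<void(const std::string&)> token_callback) {
    if (!prompt.empty()) {
      return generate(std::vector<Input>{Input{prompt}}, token_callback);
    }
    if (!prefill_next_token_.has_value()) {
      std::cerr << "Empty prompt requires a prior prefill() call\n";
      return 1;  // stand-in for Error::InvalidState
    }
    uint64_t cur_token = *prefill_next_token_;
    prefill_next_token_.reset();
    return decode_from_token(cur_token, token_callback);
  }

  // Multimodal overload: prefill the inputs, then reuse the shared decode path.
  int generate(const std::vector<Input>& inputs,
               std::function<void(const std::string&)> token_callback) {
    uint64_t cur_token = prefill(inputs);
    return decode_from_token(cur_token, token_callback);
  }

  // Caches the first sampled token so a later generate("") can resume decode.
  void prefill_only(const std::vector<Input>& inputs) {
    prefill_next_token_ = prefill(inputs);
  }

 private:
  // Fake prefill: pretends the model consumed the inputs and sampled a token.
  uint64_t prefill(const std::vector<Input>& inputs) {
    return 42 + inputs.size();
  }

  // The single decode path both overloads now call into (one step here).
  int decode_from_token(uint64_t cur_token,
                        std::function<void(const std::string&)> token_callback) {
    token_callback("token:" + std::to_string(cur_token));
    return 0;  // stand-in for Error::Ok
  }

  std::optional<uint64_t> prefill_next_token_;
};

int main() {
  Runner r;
  r.prefill_only({Input{"<image>"}, Input{"describe this"}});
  // Empty prompt: consumes the cached prefill token and decodes from it.
  r.generate("", [](const std::string& piece) { std::cout << piece << "\n"; });
}

The design point the diff reflects is that only callback wrapping and prefill-token bookkeeping differ between the entry points, so everything from the first-token timestamp onward belongs in the shared helper.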