@@ -255,6 +255,8 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
255255 f"Num steps: { config .num_inference_steps } , height: { config .height } , width: { config .width } , frames: { config .num_frames } "
256256 )
257257 videos = call_pipeline (config , pipeline , prompt , negative_prompt )
258+ if isinstance (videos , tuple ):
259+ videos = videos [0 ]
258260
259261 max_logging .log ("===================== Model details =======================" )
260262 max_logging .log (f"model name: { config .model_name } " )
@@ -278,7 +280,12 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
278280 upload_video_to_gcs (os .path .join (config .output_dir , config .run_name ), video_path )
279281
280282 s0 = time .perf_counter ()
281- videos = call_pipeline (config , pipeline , prompt , negative_prompt )
283+ outputs = call_pipeline (config , pipeline , prompt , negative_prompt )
284+ if isinstance (outputs , tuple ):
285+ videos , trace = outputs
286+ else :
287+ videos = outputs
288+ trace = {}
282289 generation_time = time .perf_counter () - s0
283290 max_logging .log (f"generation_time: { generation_time } " )
284291 if writer and jax .process_index () == 0 :
@@ -291,18 +298,39 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
291298 max_logging .log (f"generation time per video: { generation_time_per_video } " )
292299 else :
293300 max_logging .log ("Warning: Number of videos is zero, cannot calculate generation_time_per_video." )
294- max_logging .log (
295- f"\n { '=' * 50 } \n "
296- f" TIMING SUMMARY\n "
297- f"{ '=' * 50 } \n "
298- f" Load (checkpoint): { load_time :>7.1f} s\n "
299- f" Compile: { compile_time :>7.1f} s\n "
300- f" { '─' * 40 } \n "
301- f" Inference: { generation_time :>7.1f} s\n "
302- f"{ '=' * 50 } "
303- )
301+ summary = [
302+ f"\n { '=' * 50 } " ,
303+ " TIMING SUMMARY" ,
304+ f"{ '=' * 50 } " ,
305+ f" Load (checkpoint): { load_time :>7.1f} s" ,
306+ f" Compile: { compile_time :>7.1f} s" ,
307+ f" { '─' * 40 } " ,
308+ f" Inference: { generation_time :>7.1f} s" ,
309+ ]
310+ if trace :
311+ summary .extend ([
312+ f" Conditioning: { trace .get ('conditioning' , 0.0 ):>7.1f} s" ,
313+ f" Denoise Total: { trace .get ('denoise_total' , 0.0 ):>7.1f} s" ,
314+ f" VAE Decode: { trace .get ('vae_decode' , 0.0 ):>7.1f} s" ,
315+ ])
316+ summary .append (f"{ '=' * 50 } " )
317+ max_logging .log ("\n " .join (summary ))
304318
305- videos = call_pipeline (config , pipeline , prompt , negative_prompt )
319+ s0 = time .perf_counter ()
320+ if max_utils .profiler_enabled (config ):
321+ # Injecting user requested XLA tracing flags
322+ xla_flags = os .environ .get ("XLA_FLAGS" , "" )
323+ new_flags = "--xla_enable_mxu_trace=true --xla_jf_dump_llo_html=true --xla_tpu_enable_llo_profiling=true"
324+ os .environ ["XLA_FLAGS" ] = f"{ xla_flags } { new_flags } "
325+ max_logging .log (f"Injected XLA_FLAGS for profiling: { new_flags } " )
326+
327+ videos = call_pipeline (config , pipeline , prompt , negative_prompt )
328+ if isinstance (videos , tuple ):
329+ videos = videos [0 ]
330+ generation_time_with_profiler = time .perf_counter () - s0
331+ max_logging .log (f"generation_time_with_profiler: { generation_time_with_profiler } " )
332+ if writer and jax .process_index () == 0 :
333+ writer .add_scalar ("inference/generation_time_with_profiler" , generation_time_with_profiler , global_step = 0 )
306334
307335 return saved_video_path
308336