@@ -116,7 +116,9 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
116116 max_logging .log ("Could not retrieve Git commit hash." )
117117
118118 checkpoint_loader = LTX2Checkpointer (config = config )
119+ load_time = 0.0
119120 if pipeline is None :
121+ t0_load = time .perf_counter ()
120122 # Use the config flag to determine if the upsampler should be loaded
121123 run_latent_upsampler = getattr (config , "run_latent_upsampler" , False )
122124 pipeline , _ , _ = checkpoint_loader .load_checkpoint (load_upsampler = run_latent_upsampler )
@@ -145,6 +147,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
145147 scan_layers = config .scan_layers ,
146148 dtype = config .weights_dtype ,
147149 )
150+ load_time = time .perf_counter () - t0_load
148151
149152 pipeline .enable_vae_slicing ()
150153 pipeline .enable_vae_tiling ()
@@ -162,12 +165,6 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
162165 f"Num steps: { config .num_inference_steps } , height: { config .height } , width: { config .width } , frames: { config .num_frames } "
163166 )
164167
165- out = call_pipeline (config , pipeline , prompt , negative_prompt )
166-
167- # out should have .frames and .audio
168- videos = out .frames if hasattr (out , "frames" ) else out [0 ]
169- audios = out .audio if hasattr (out , "audio" ) else None
170-
171168 max_logging .log ("===================== Model details =======================" )
172169 max_logging .log (f"model name: { getattr (config , 'model_name' , 'ltx-video' )} " )
173170 max_logging .log (f"model path: { config .pretrained_model_name_or_path } " )
@@ -179,11 +176,48 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
179176 max_logging .log (f"per_device_batch_size: { config .per_device_batch_size } " )
180177 max_logging .log ("============================================================" )
181178
179+ original_enable_profiler = config .get_keys ().get ("enable_profiler" , False )
180+ original_enable_mld = config .get_keys ().get ("enable_ml_diagnostics" , False )
181+ original_num_steps = config .get_keys ().get ("num_inference_steps" , 40 )
182+
183+ # ---------------------------------------------------------
184+ # Run 1: Warmup Compilation (Original steps, NO profiling)
185+ # ---------------------------------------------------------
186+ config .get_keys ()["enable_profiler" ] = False
187+ config .get_keys ()["enable_ml_diagnostics" ] = False
188+
189+ max_logging .log (f"🚀 Starting warmup compilation pass ({ original_num_steps } steps)..." )
190+ _ = call_pipeline (config , pipeline , prompt , negative_prompt )
191+
182192 compile_time = time .perf_counter () - s0
183193 max_logging .log (f"compile_time: { compile_time } " )
184194 if writer and jax .process_index () == 0 :
185195 writer .add_scalar ("inference/compile_time" , compile_time , global_step = 0 )
186196
197+ # ---------------------------------------------------------
198+ # Run 2: Actual Generation (Original steps, NO profiling)
199+ # ---------------------------------------------------------
200+
201+ s0 = time .perf_counter ()
202+ max_logging .log ("🚀 Starting actual full-length generation pass..." )
203+ out = call_pipeline (config , pipeline , prompt , negative_prompt )
204+ generation_time = time .perf_counter () - s0
205+ max_logging .log (f"generation_time: { generation_time } " )
206+ if writer and jax .process_index () == 0 :
207+ writer .add_scalar ("inference/generation_time" , generation_time , global_step = 0 )
208+ num_devices = jax .device_count ()
209+ num_videos = num_devices * config .per_device_batch_size
210+ if num_videos > 0 :
211+ generation_time_per_video = generation_time / num_videos
212+ writer .add_scalar ("inference/generation_time_per_video" , generation_time_per_video , global_step = 0 )
213+ max_logging .log (f"generation time per video: { generation_time_per_video } " )
214+ else :
215+ max_logging .log ("Warning: Number of videos is zero, cannot calculate generation_time_per_video." )
216+
217+ # out should have .frames and .audio
218+ videos = out .frames if hasattr (out , "frames" ) else out [0 ]
219+ audios = out .audio if hasattr (out , "audio" ) else None
220+
187221 saved_video_path = []
188222 audio_sample_rate = (
189223 getattr (pipeline .vocoder .config , "output_sampling_rate" , 24000 ) if hasattr (pipeline , "vocoder" ) else 24000
@@ -210,29 +244,51 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
210244 if config .output_dir .startswith ("gs://" ):
211245 upload_video_to_gcs (os .path .join (config .output_dir , config .run_name ), video_path )
212246
213- s0 = time .perf_counter ()
214- call_pipeline (config , pipeline , prompt , negative_prompt )
215- generation_time = time .perf_counter () - s0
216- max_logging .log (f"generation_time: { generation_time } " )
217- if writer and jax .process_index () == 0 :
218- writer .add_scalar ("inference/generation_time" , generation_time , global_step = 0 )
219- num_devices = jax .device_count ()
220- num_videos = num_devices * config .per_device_batch_size
221- if num_videos > 0 :
222- generation_time_per_video = generation_time / num_videos
223- writer .add_scalar ("inference/generation_time_per_video" , generation_time_per_video , global_step = 0 )
224- max_logging .log (f"generation time per video: { generation_time_per_video } " )
225- else :
226- max_logging .log ("Warning: Number of videos is zero, cannot calculate generation_time_per_video." )
247+ max_logging .log (
248+ f"\n { '=' * 50 } \n "
249+ f" TIMING SUMMARY\n "
250+ f"{ '=' * 50 } \n "
251+ f" Load (checkpoint): { load_time :>7.1f} s\n "
252+ f" Compile: { compile_time :>7.1f} s\n "
253+ f" { '─' * 40 } \n "
254+ f" Inference: { generation_time :>7.1f} s\n "
255+ f"{ '=' * 50 } "
256+ )
227257
228- s0 = time .perf_counter ()
229- if max_utils .profiler_enabled (config ):
230- with max_utils .Profiler (config ):
231- call_pipeline (config , pipeline , prompt , negative_prompt )
232- generation_time_with_profiler = time .perf_counter () - s0
233- max_logging .log (f"generation_time_with_profiler: { generation_time_with_profiler } " )
234- if writer and jax .process_index () == 0 :
235- writer .add_scalar ("inference/generation_time_with_profiler" , generation_time_with_profiler , global_step = 0 )
258+ # Free memory before profiling
259+ del out
260+ del videos
261+ del audios
262+
263+ # ---------------------------------------------------------
264+ # Run 3: Profiling Run (Only if profiling was originally enabled)
265+ # ---------------------------------------------------------
266+ if original_enable_profiler or original_enable_mld :
267+ skip_first_n_steps_for_profiler = config .get_keys ().get ("skip_first_n_steps_for_profiler" , 0 )
268+ if skip_first_n_steps_for_profiler != 0 :
269+ max_logging .log (
270+ "\n ⚠️ WARNING: 'skip_first_n_steps_for_profiler' is ignored because 'scan_diffusion_loop' is enabled! The profiler will capture all steps in this profile run.\n "
271+ )
272+
273+ profiling_steps = config .get_keys ().get ("profiler_steps" , 5 )
274+
275+ config .get_keys ()["enable_profiler" ] = False
276+ config .get_keys ()["enable_ml_diagnostics" ] = False
277+ config .get_keys ()["num_inference_steps" ] = profiling_steps
278+
279+ max_logging .log (f"🚀 Warmup for profiling pass ({ profiling_steps } steps)..." )
280+ _ = call_pipeline (config , pipeline , prompt , negative_prompt )
281+
282+ config .get_keys ()["enable_profiler" ] = original_enable_profiler
283+ config .get_keys ()["enable_ml_diagnostics" ] = original_enable_mld
284+
285+ max_logging .log (f"🚀 Starting Profiling run ({ profiling_steps } steps)..." )
286+ profiler = max_utils .Profiler (config , session_name = f"denoise_profile_{ profiling_steps } _steps" )
287+ profiler .start ()
288+
289+ _ = call_pipeline (config , pipeline , prompt , negative_prompt )
290+
291+ profiler .stop ()
236292
237293 return saved_video_path
238294
0 commit comments