@@ -116,7 +116,9 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
116116 max_logging .log ("Could not retrieve Git commit hash." )
117117
118118 checkpoint_loader = LTX2Checkpointer (config = config )
119+ load_time = 0.0
119120 if pipeline is None :
121+ t0_load = time .perf_counter ()
120122 # Use the config flag to determine if the upsampler should be loaded
121123 run_latent_upsampler = getattr (config , "run_latent_upsampler" , False )
122124 pipeline , _ , _ = checkpoint_loader .load_checkpoint (load_upsampler = run_latent_upsampler )
@@ -145,6 +147,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
145147 scan_layers = config .scan_layers ,
146148 dtype = config .weights_dtype ,
147149 )
150+ load_time = time .perf_counter () - t0_load
148151
149152 pipeline .enable_vae_slicing ()
150153 pipeline .enable_vae_tiling ()
@@ -162,12 +165,6 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
162165 f"Num steps: { config .num_inference_steps } , height: { config .height } , width: { config .width } , frames: { config .num_frames } "
163166 )
164167
165- out = call_pipeline (config , pipeline , prompt , negative_prompt )
166-
167- # out should have .frames and .audio
168- videos = out .frames if hasattr (out , "frames" ) else out [0 ]
169- audios = out .audio if hasattr (out , "audio" ) else None
170-
171168 max_logging .log ("===================== Model details =======================" )
172169 max_logging .log (f"model name: { getattr (config , 'model_name' , 'ltx-video' )} " )
173170 max_logging .log (f"model path: { config .pretrained_model_name_or_path } " )
@@ -179,11 +176,48 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
179176 max_logging .log (f"per_device_batch_size: { config .per_device_batch_size } " )
180177 max_logging .log ("============================================================" )
181178
179+ original_enable_profiler = config .get_keys ().get ("enable_profiler" , False )
180+ original_enable_mld = config .get_keys ().get ("enable_ml_diagnostics" , False )
181+ original_num_steps = config .get_keys ().get ("num_inference_steps" , 40 )
182+
183+ # ---------------------------------------------------------
184+ # Run 1: Warmup Compilation (Original steps, NO profiling)
185+ # ---------------------------------------------------------
186+ config .get_keys ()["enable_profiler" ] = False
187+ config .get_keys ()["enable_ml_diagnostics" ] = False
188+
189+ max_logging .log (f"🚀 Starting warmup compilation pass ({ original_num_steps } steps)..." )
190+ _ = call_pipeline (config , pipeline , prompt , negative_prompt )
191+
182192 compile_time = time .perf_counter () - s0
183193 max_logging .log (f"compile_time: { compile_time } " )
184194 if writer and jax .process_index () == 0 :
185195 writer .add_scalar ("inference/compile_time" , compile_time , global_step = 0 )
186196
197+ # ---------------------------------------------------------
198+ # Run 2: Actual Generation (Original steps, NO profiling)
199+ # ---------------------------------------------------------
200+
201+ s0 = time .perf_counter ()
202+ max_logging .log ("🚀 Starting actual full-length generation pass..." )
203+ out = call_pipeline (config , pipeline , prompt , negative_prompt )
204+ generation_time = time .perf_counter () - s0
205+ max_logging .log (f"generation_time: { generation_time } " )
206+ if writer and jax .process_index () == 0 :
207+ writer .add_scalar ("inference/generation_time" , generation_time , global_step = 0 )
208+ num_devices = jax .device_count ()
209+ num_videos = num_devices * config .per_device_batch_size
210+ if num_videos > 0 :
211+ generation_time_per_video = generation_time / num_videos
212+ writer .add_scalar ("inference/generation_time_per_video" , generation_time_per_video , global_step = 0 )
213+ max_logging .log (f"generation time per video: { generation_time_per_video } " )
214+ else :
215+ max_logging .log ("Warning: Number of videos is zero, cannot calculate generation_time_per_video." )
216+
217+ # out should have .frames and .audio
218+ videos = out .frames if hasattr (out , "frames" ) else out [0 ]
219+ audios = out .audio if hasattr (out , "audio" ) else None
220+
187221 saved_video_path = []
188222 audio_sample_rate = (
189223 getattr (pipeline .vocoder .config , "output_sampling_rate" , 24000 ) if hasattr (pipeline , "vocoder" ) else 24000
@@ -210,29 +244,68 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
210244 if config .output_dir .startswith ("gs://" ):
211245 upload_video_to_gcs (os .path .join (config .output_dir , config .run_name ), video_path )
212246
213- s0 = time .perf_counter ()
214- call_pipeline (config , pipeline , prompt , negative_prompt )
215- generation_time = time .perf_counter () - s0
216- max_logging .log (f"generation_time: { generation_time } " )
217- if writer and jax .process_index () == 0 :
218- writer .add_scalar ("inference/generation_time" , generation_time , global_step = 0 )
219- num_devices = jax .device_count ()
220- num_videos = num_devices * config .per_device_batch_size
221- if num_videos > 0 :
222- generation_time_per_video = generation_time / num_videos
223- writer .add_scalar ("inference/generation_time_per_video" , generation_time_per_video , global_step = 0 )
224- max_logging .log (f"generation time per video: { generation_time_per_video } " )
225- else :
226- max_logging .log ("Warning: Number of videos is zero, cannot calculate generation_time_per_video." )
247+ timing_str = (
248+ f"\n { '=' * 50 } \n "
249+ f" TIMING SUMMARY\n "
250+ f"{ '=' * 50 } \n "
251+ f" Load (checkpoint): { load_time :>7.1f} s\n "
252+ f" Compile: { compile_time :>7.1f} s\n "
253+ f" { '─' * 40 } \n "
254+ f" Inference: { generation_time :>7.1f} s\n "
255+ )
256+ if hasattr (out , "timings" ) and out .timings :
257+ timing_str += (
258+ f" Text Encoding: { out .timings .get ('Text Encoding' , 0.0 ):>7.1f} s\n "
259+ f" Preparation: { out .timings .get ('Preparation' , 0.0 ):>7.1f} s\n "
260+ f" Connectors: { out .timings .get ('Connectors' , 0.0 ):>7.1f} s\n "
261+ f" Denoising: { out .timings .get ('Denoising' , 0.0 ):>7.1f} s\n "
262+ )
263+ if out .timings .get ("Latent Upsampler" , 0.0 ) > 0.0 :
264+ timing_str += f" Latent Upsampler: { out .timings .get ('Latent Upsampler' , 0.0 ):>7.1f} s\n "
265+ timing_str += (
266+ f" Latent Processing: { out .timings .get ('Latent Processing' , 0.0 ):>7.1f} s\n "
267+ f" Video VAE: { out .timings .get ('Video VAE' , 0.0 ):>7.1f} s\n "
268+ f" Video Post: { out .timings .get ('Video Post' , 0.0 ):>7.1f} s\n "
269+ f" Audio VAE: { out .timings .get ('Audio VAE' , 0.0 ):>7.1f} s\n "
270+ f" Vocoder: { out .timings .get ('Vocoder' , 0.0 ):>7.1f} s\n "
271+ )
272+ timing_str += f"{ '=' * 50 } "
273+ max_logging .log (timing_str )
227274
228- s0 = time .perf_counter ()
229- if max_utils .profiler_enabled (config ):
230- with max_utils .Profiler (config ):
231- call_pipeline (config , pipeline , prompt , negative_prompt )
232- generation_time_with_profiler = time .perf_counter () - s0
233- max_logging .log (f"generation_time_with_profiler: { generation_time_with_profiler } " )
234- if writer and jax .process_index () == 0 :
235- writer .add_scalar ("inference/generation_time_with_profiler" , generation_time_with_profiler , global_step = 0 )
275+ # Free memory before profiling
276+ del out
277+ del videos
278+ del audios
279+
280+ # ---------------------------------------------------------
281+ # Run 3: Profiling Run (Only if profiling was originally enabled)
282+ # ---------------------------------------------------------
283+ if original_enable_profiler or original_enable_mld :
284+ skip_first_n_steps_for_profiler = config .get_keys ().get ("skip_first_n_steps_for_profiler" , 0 )
285+ if skip_first_n_steps_for_profiler != 0 :
286+ max_logging .log (
287+ "\n ⚠️ WARNING: 'skip_first_n_steps_for_profiler' is ignored because 'scan_diffusion_loop' is enabled! The profiler will capture all steps in this profile run.\n "
288+ )
289+
290+ profiling_steps = config .get_keys ().get ("profiler_steps" , 5 )
291+
292+ config .get_keys ()["enable_profiler" ] = False
293+ config .get_keys ()["enable_ml_diagnostics" ] = False
294+ config .get_keys ()["num_inference_steps" ] = profiling_steps
295+
296+ max_logging .log (f"🚀 Warmup for profiling pass ({ profiling_steps } steps)..." )
297+ _ = call_pipeline (config , pipeline , prompt , negative_prompt )
298+
299+ config .get_keys ()["enable_profiler" ] = original_enable_profiler
300+ config .get_keys ()["enable_ml_diagnostics" ] = original_enable_mld
301+
302+ max_logging .log (f"🚀 Starting Profiling run ({ profiling_steps } steps)..." )
303+ profiler = max_utils .Profiler (config , session_name = f"denoise_profile_{ profiling_steps } _steps" )
304+ profiler .start ()
305+
306+ _ = call_pipeline (config , pipeline , prompt , negative_prompt )
307+
308+ profiler .stop ()
236309
237310 return saved_video_path
238311
0 commit comments