Skip to content

Commit 6d003f0

Browse files
Merge pull request #389 from AI-Hypercomputer:mehdy_perf
PiperOrigin-RevId: 910762117
2 parents bb3b0c6 + 9d73dd7 commit 6d003f0

3 files changed

Lines changed: 259 additions & 110 deletions

File tree

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ enable_profiler: False
8989
enable_ml_diagnostics: True
9090
profiler_gcs_path: "gs://mehdy/profiler/ml_diagnostics"
9191
enable_ondemand_xprof: True
92+
skip_first_n_steps_for_profiler: 0
93+
profiler_steps: 5
9294

9395
replicate_vae: False
9496

src/maxdiffusion/generate_ltx2.py

Lines changed: 101 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,9 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
116116
max_logging.log("Could not retrieve Git commit hash.")
117117

118118
checkpoint_loader = LTX2Checkpointer(config=config)
119+
load_time = 0.0
119120
if pipeline is None:
121+
t0_load = time.perf_counter()
120122
# Use the config flag to determine if the upsampler should be loaded
121123
run_latent_upsampler = getattr(config, "run_latent_upsampler", False)
122124
pipeline, _, _ = checkpoint_loader.load_checkpoint(load_upsampler=run_latent_upsampler)
@@ -145,6 +147,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
145147
scan_layers=config.scan_layers,
146148
dtype=config.weights_dtype,
147149
)
150+
load_time = time.perf_counter() - t0_load
148151

149152
pipeline.enable_vae_slicing()
150153
pipeline.enable_vae_tiling()
@@ -162,12 +165,6 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
162165
f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}"
163166
)
164167

165-
out = call_pipeline(config, pipeline, prompt, negative_prompt)
166-
167-
# out should have .frames and .audio
168-
videos = out.frames if hasattr(out, "frames") else out[0]
169-
audios = out.audio if hasattr(out, "audio") else None
170-
171168
max_logging.log("===================== Model details =======================")
172169
max_logging.log(f"model name: {getattr(config, 'model_name', 'ltx-video')}")
173170
max_logging.log(f"model path: {config.pretrained_model_name_or_path}")
@@ -179,11 +176,48 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
179176
max_logging.log(f"per_device_batch_size: {config.per_device_batch_size}")
180177
max_logging.log("============================================================")
181178

179+
original_enable_profiler = config.get_keys().get("enable_profiler", False)
180+
original_enable_mld = config.get_keys().get("enable_ml_diagnostics", False)
181+
original_num_steps = config.get_keys().get("num_inference_steps", 40)
182+
183+
# ---------------------------------------------------------
184+
# Run 1: Warmup Compilation (Original steps, NO profiling)
185+
# ---------------------------------------------------------
186+
config.get_keys()["enable_profiler"] = False
187+
config.get_keys()["enable_ml_diagnostics"] = False
188+
189+
max_logging.log(f"🚀 Starting warmup compilation pass ({original_num_steps} steps)...")
190+
_ = call_pipeline(config, pipeline, prompt, negative_prompt)
191+
182192
compile_time = time.perf_counter() - s0
183193
max_logging.log(f"compile_time: {compile_time}")
184194
if writer and jax.process_index() == 0:
185195
writer.add_scalar("inference/compile_time", compile_time, global_step=0)
186196

197+
# ---------------------------------------------------------
198+
# Run 2: Actual Generation (Original steps, NO profiling)
199+
# ---------------------------------------------------------
200+
201+
s0 = time.perf_counter()
202+
max_logging.log("🚀 Starting actual full-length generation pass...")
203+
out = call_pipeline(config, pipeline, prompt, negative_prompt)
204+
generation_time = time.perf_counter() - s0
205+
max_logging.log(f"generation_time: {generation_time}")
206+
if writer and jax.process_index() == 0:
207+
writer.add_scalar("inference/generation_time", generation_time, global_step=0)
208+
num_devices = jax.device_count()
209+
num_videos = num_devices * config.per_device_batch_size
210+
if num_videos > 0:
211+
generation_time_per_video = generation_time / num_videos
212+
writer.add_scalar("inference/generation_time_per_video", generation_time_per_video, global_step=0)
213+
max_logging.log(f"generation time per video: {generation_time_per_video}")
214+
else:
215+
max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.")
216+
217+
# out should have .frames and .audio
218+
videos = out.frames if hasattr(out, "frames") else out[0]
219+
audios = out.audio if hasattr(out, "audio") else None
220+
187221
saved_video_path = []
188222
audio_sample_rate = (
189223
getattr(pipeline.vocoder.config, "output_sampling_rate", 24000) if hasattr(pipeline, "vocoder") else 24000
@@ -210,29 +244,68 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
210244
if config.output_dir.startswith("gs://"):
211245
upload_video_to_gcs(os.path.join(config.output_dir, config.run_name), video_path)
212246

213-
s0 = time.perf_counter()
214-
call_pipeline(config, pipeline, prompt, negative_prompt)
215-
generation_time = time.perf_counter() - s0
216-
max_logging.log(f"generation_time: {generation_time}")
217-
if writer and jax.process_index() == 0:
218-
writer.add_scalar("inference/generation_time", generation_time, global_step=0)
219-
num_devices = jax.device_count()
220-
num_videos = num_devices * config.per_device_batch_size
221-
if num_videos > 0:
222-
generation_time_per_video = generation_time / num_videos
223-
writer.add_scalar("inference/generation_time_per_video", generation_time_per_video, global_step=0)
224-
max_logging.log(f"generation time per video: {generation_time_per_video}")
225-
else:
226-
max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.")
247+
timing_str = (
248+
f"\n{'=' * 50}\n"
249+
f" TIMING SUMMARY\n"
250+
f"{'=' * 50}\n"
251+
f" Load (checkpoint): {load_time:>7.1f}s\n"
252+
f" Compile: {compile_time:>7.1f}s\n"
253+
f" {'─' * 40}\n"
254+
f" Inference: {generation_time:>7.1f}s\n"
255+
)
256+
if hasattr(out, "timings") and out.timings:
257+
timing_str += (
258+
f" Text Encoding: {out.timings.get('Text Encoding', 0.0):>7.1f}s\n"
259+
f" Preparation: {out.timings.get('Preparation', 0.0):>7.1f}s\n"
260+
f" Connectors: {out.timings.get('Connectors', 0.0):>7.1f}s\n"
261+
f" Denoising: {out.timings.get('Denoising', 0.0):>7.1f}s\n"
262+
)
263+
if out.timings.get("Latent Upsampler", 0.0) > 0.0:
264+
timing_str += f" Latent Upsampler: {out.timings.get('Latent Upsampler', 0.0):>7.1f}s\n"
265+
timing_str += (
266+
f" Latent Processing: {out.timings.get('Latent Processing', 0.0):>7.1f}s\n"
267+
f" Video VAE: {out.timings.get('Video VAE', 0.0):>7.1f}s\n"
268+
f" Video Post: {out.timings.get('Video Post', 0.0):>7.1f}s\n"
269+
f" Audio VAE: {out.timings.get('Audio VAE', 0.0):>7.1f}s\n"
270+
f" Vocoder: {out.timings.get('Vocoder', 0.0):>7.1f}s\n"
271+
)
272+
timing_str += f"{'=' * 50}"
273+
max_logging.log(timing_str)
227274

228-
s0 = time.perf_counter()
229-
if max_utils.profiler_enabled(config):
230-
with max_utils.Profiler(config):
231-
call_pipeline(config, pipeline, prompt, negative_prompt)
232-
generation_time_with_profiler = time.perf_counter() - s0
233-
max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}")
234-
if writer and jax.process_index() == 0:
235-
writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0)
275+
# Free memory before profiling
276+
del out
277+
del videos
278+
del audios
279+
280+
# ---------------------------------------------------------
281+
# Run 3: Profiling Run (only if the profiler or ML diagnostics was originally enabled)
282+
# ---------------------------------------------------------
283+
if original_enable_profiler or original_enable_mld:
284+
skip_first_n_steps_for_profiler = config.get_keys().get("skip_first_n_steps_for_profiler", 0)
285+
if skip_first_n_steps_for_profiler != 0:
286+
max_logging.log(
287+
"\n⚠️ WARNING: 'skip_first_n_steps_for_profiler' is ignored because 'scan_diffusion_loop' is enabled! The profiler will capture all steps in this profile run.\n"
288+
)
289+
290+
profiling_steps = config.get_keys().get("profiler_steps", 5)
291+
292+
config.get_keys()["enable_profiler"] = False
293+
config.get_keys()["enable_ml_diagnostics"] = False
294+
config.get_keys()["num_inference_steps"] = profiling_steps
295+
296+
max_logging.log(f"🚀 Warmup for profiling pass ({profiling_steps} steps)...")
297+
_ = call_pipeline(config, pipeline, prompt, negative_prompt)
298+
299+
config.get_keys()["enable_profiler"] = original_enable_profiler
300+
config.get_keys()["enable_ml_diagnostics"] = original_enable_mld
301+
302+
max_logging.log(f"🚀 Starting Profiling run ({profiling_steps} steps)...")
303+
profiler = max_utils.Profiler(config, session_name=f"denoise_profile_{profiling_steps}_steps")
304+
profiler.start()
305+
306+
_ = call_pipeline(config, pipeline, prompt, negative_prompt)
307+
308+
profiler.stop()
236309

237310
return saved_video_path
238311

0 commit comments

Comments
 (0)