Skip to content

Commit 6bd35bf

Browse files
committed
perf: optimize ltx2 parallelism and implement granular tpu profiling
- Switched to DP (ici_data_parallelism: -1) in ltx2 config to bypass ICI communication overhead during inference. - Added `jax.named_scope` around connectors and VAE blocks for accurate xprof trace attribution. - Added synchronous `perf_counter` wrappers in the pipeline to measure true stage latencies. - Implemented a 3-pass (warmup, run, profile) generation loop in `generate_ltx2.py` to isolate JIT compilation time and improve profiling accuracy.
1 parent ae22683 commit 6bd35bf

3 files changed

Lines changed: 217 additions & 111 deletions

File tree

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,21 @@ flash_block_sizes: {
7979
flash_min_seq_length: 4096
8080
dcn_context_parallelism: 1
8181
dcn_tensor_parallelism: 1
82-
ici_data_parallelism: 1
82+
# -1 auto-shards the axis. For inference, DP (-1) is recommended over Context Parallelism.
83+
# DP processes independent batch items per core, requiring ZERO cross-core communication.
84+
# Context Parallelism splits the sequence length, causing heavy All-Gather ICI overhead.
85+
ici_data_parallelism: -1
8386
ici_fsdp_parallelism: 1
84-
ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded
87+
ici_context_parallelism: 1
8588
ici_tensor_parallelism: 1
8689
enable_profiler: False
8790

8891
# ML Diagnostics settings
8992
enable_ml_diagnostics: True
9093
profiler_gcs_path: "gs://mehdy/profiler/ml_diagnostics"
9194
enable_ondemand_xprof: True
95+
skip_first_n_steps_for_profiler: 0
96+
profiler_steps: 5
9297

9398
replicate_vae: False
9499

src/maxdiffusion/generate_ltx2.py

Lines changed: 84 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,9 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
116116
max_logging.log("Could not retrieve Git commit hash.")
117117

118118
checkpoint_loader = LTX2Checkpointer(config=config)
119+
load_time = 0.0
119120
if pipeline is None:
121+
t0_load = time.perf_counter()
120122
# Use the config flag to determine if the upsampler should be loaded
121123
run_latent_upsampler = getattr(config, "run_latent_upsampler", False)
122124
pipeline, _, _ = checkpoint_loader.load_checkpoint(load_upsampler=run_latent_upsampler)
@@ -145,6 +147,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
145147
scan_layers=config.scan_layers,
146148
dtype=config.weights_dtype,
147149
)
150+
load_time = time.perf_counter() - t0_load
148151

149152
pipeline.enable_vae_slicing()
150153
pipeline.enable_vae_tiling()
@@ -162,12 +165,6 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
162165
f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}"
163166
)
164167

165-
out = call_pipeline(config, pipeline, prompt, negative_prompt)
166-
167-
# out should have .frames and .audio
168-
videos = out.frames if hasattr(out, "frames") else out[0]
169-
audios = out.audio if hasattr(out, "audio") else None
170-
171168
max_logging.log("===================== Model details =======================")
172169
max_logging.log(f"model name: {getattr(config, 'model_name', 'ltx-video')}")
173170
max_logging.log(f"model path: {config.pretrained_model_name_or_path}")
@@ -179,11 +176,48 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
179176
max_logging.log(f"per_device_batch_size: {config.per_device_batch_size}")
180177
max_logging.log("============================================================")
181178

179+
original_enable_profiler = config.get_keys().get("enable_profiler", False)
180+
original_enable_mld = config.get_keys().get("enable_ml_diagnostics", False)
181+
original_num_steps = config.get_keys().get("num_inference_steps", 40)
182+
183+
# ---------------------------------------------------------
184+
# Run 1: Warmup Compilation (Original steps, NO profiling)
185+
# ---------------------------------------------------------
186+
config.get_keys()["enable_profiler"] = False
187+
config.get_keys()["enable_ml_diagnostics"] = False
188+
189+
max_logging.log(f"🚀 Starting warmup compilation pass ({original_num_steps} steps)...")
190+
_ = call_pipeline(config, pipeline, prompt, negative_prompt)
191+
182192
compile_time = time.perf_counter() - s0
183193
max_logging.log(f"compile_time: {compile_time}")
184194
if writer and jax.process_index() == 0:
185195
writer.add_scalar("inference/compile_time", compile_time, global_step=0)
186196

197+
# ---------------------------------------------------------
198+
# Run 2: Actual Generation (Original steps, NO profiling)
199+
# ---------------------------------------------------------
200+
201+
s0 = time.perf_counter()
202+
max_logging.log("🚀 Starting actual full-length generation pass...")
203+
out = call_pipeline(config, pipeline, prompt, negative_prompt)
204+
generation_time = time.perf_counter() - s0
205+
max_logging.log(f"generation_time: {generation_time}")
206+
if writer and jax.process_index() == 0:
207+
writer.add_scalar("inference/generation_time", generation_time, global_step=0)
208+
num_devices = jax.device_count()
209+
num_videos = num_devices * config.per_device_batch_size
210+
if num_videos > 0:
211+
generation_time_per_video = generation_time / num_videos
212+
writer.add_scalar("inference/generation_time_per_video", generation_time_per_video, global_step=0)
213+
max_logging.log(f"generation time per video: {generation_time_per_video}")
214+
else:
215+
max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.")
216+
217+
# out should have .frames and .audio
218+
videos = out.frames if hasattr(out, "frames") else out[0]
219+
audios = out.audio if hasattr(out, "audio") else None
220+
187221
saved_video_path = []
188222
audio_sample_rate = (
189223
getattr(pipeline.vocoder.config, "output_sampling_rate", 24000) if hasattr(pipeline, "vocoder") else 24000
@@ -210,29 +244,51 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
210244
if config.output_dir.startswith("gs://"):
211245
upload_video_to_gcs(os.path.join(config.output_dir, config.run_name), video_path)
212246

213-
s0 = time.perf_counter()
214-
call_pipeline(config, pipeline, prompt, negative_prompt)
215-
generation_time = time.perf_counter() - s0
216-
max_logging.log(f"generation_time: {generation_time}")
217-
if writer and jax.process_index() == 0:
218-
writer.add_scalar("inference/generation_time", generation_time, global_step=0)
219-
num_devices = jax.device_count()
220-
num_videos = num_devices * config.per_device_batch_size
221-
if num_videos > 0:
222-
generation_time_per_video = generation_time / num_videos
223-
writer.add_scalar("inference/generation_time_per_video", generation_time_per_video, global_step=0)
224-
max_logging.log(f"generation time per video: {generation_time_per_video}")
225-
else:
226-
max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.")
247+
max_logging.log(
248+
f"\n{'=' * 50}\n"
249+
f" TIMING SUMMARY\n"
250+
f"{'=' * 50}\n"
251+
f" Load (checkpoint): {load_time:>7.1f}s\n"
252+
f" Compile: {compile_time:>7.1f}s\n"
253+
f" {'─' * 40}\n"
254+
f" Inference: {generation_time:>7.1f}s\n"
255+
f"{'=' * 50}"
256+
)
227257

228-
s0 = time.perf_counter()
229-
if max_utils.profiler_enabled(config):
230-
with max_utils.Profiler(config):
231-
call_pipeline(config, pipeline, prompt, negative_prompt)
232-
generation_time_with_profiler = time.perf_counter() - s0
233-
max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}")
234-
if writer and jax.process_index() == 0:
235-
writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0)
258+
# Free memory before profiling
259+
del out
260+
del videos
261+
del audios
262+
263+
# ---------------------------------------------------------
264+
# Run 3: Profiling Run (Only if profiling was originally enabled)
265+
# ---------------------------------------------------------
266+
if original_enable_profiler or original_enable_mld:
267+
skip_first_n_steps_for_profiler = config.get_keys().get("skip_first_n_steps_for_profiler", 0)
268+
if skip_first_n_steps_for_profiler != 0:
269+
max_logging.log(
270+
"\n⚠️ WARNING: 'skip_first_n_steps_for_profiler' is ignored because 'scan_diffusion_loop' is enabled! The profiler will capture all steps in this profile run.\n"
271+
)
272+
273+
profiling_steps = config.get_keys().get("profiler_steps", 5)
274+
275+
config.get_keys()["enable_profiler"] = False
276+
config.get_keys()["enable_ml_diagnostics"] = False
277+
config.get_keys()["num_inference_steps"] = profiling_steps
278+
279+
max_logging.log(f"🚀 Warmup for profiling pass ({profiling_steps} steps)...")
280+
_ = call_pipeline(config, pipeline, prompt, negative_prompt)
281+
282+
config.get_keys()["enable_profiler"] = original_enable_profiler
283+
config.get_keys()["enable_ml_diagnostics"] = original_enable_mld
284+
285+
max_logging.log(f"🚀 Starting Profiling run ({profiling_steps} steps)...")
286+
profiler = max_utils.Profiler(config, session_name=f"denoise_profile_{profiling_steps}_steps")
287+
profiler.start()
288+
289+
_ = call_pipeline(config, pipeline, prompt, negative_prompt)
290+
291+
profiler.stop()
236292

237293
return saved_video_path
238294

0 commit comments

Comments
 (0)