From 6bd35bfd78a0165de3721f8831d18fbadaf8f02a Mon Sep 17 00:00:00 2001 From: mbohlool Date: Thu, 23 Apr 2026 00:11:22 +0000 Subject: [PATCH] perf: optimize ltx2 parallelism and implement granular tpu profiling - Switched to DP (ici_data_parallelism: -1) in ltx2 config to bypass ICI communication overhead during inference. - Added `jax.named_scope` around connectors and VAE blocks for accurate xprof trace attribution. - Added synchronous `perf_counter` wrappers in the pipeline to measure true stage latencies. - Implemented a 3-pass (warmup, run, profile) generation loop in `generate_ltx2.py` to isolate JIT compilation time and enable better profiling. --- src/maxdiffusion/configs/ltx2_video.yml | 9 +- src/maxdiffusion/generate_ltx2.py | 112 +++++++--- .../pipelines/ltx2/ltx2_pipeline.py | 207 +++++++++++------- 3 files changed, 217 insertions(+), 111 deletions(-) diff --git a/src/maxdiffusion/configs/ltx2_video.yml b/src/maxdiffusion/configs/ltx2_video.yml index 9b676d4c..b3011c22 100644 --- a/src/maxdiffusion/configs/ltx2_video.yml +++ b/src/maxdiffusion/configs/ltx2_video.yml @@ -79,9 +79,12 @@ flash_block_sizes: { flash_min_seq_length: 4096 dcn_context_parallelism: 1 dcn_tensor_parallelism: 1 -ici_data_parallelism: 1 +# -1 auto-shards the axis. For inference, DP (-1) is recommended over Context Parallelism. +# DP processes independent batch items per core, requiring ZERO cross-core communication. +# Context Parallelism splits the sequence length, causing heavy All-Gather ICI overhead. 
+ici_data_parallelism: -1 ici_fsdp_parallelism: 1 -ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded +ici_context_parallelism: 1 ici_tensor_parallelism: 1 enable_profiler: False @@ -89,6 +92,8 @@ enable_profiler: False enable_ml_diagnostics: True profiler_gcs_path: "gs://mehdy/profiler/ml_diagnostics" enable_ondemand_xprof: True +skip_first_n_steps_for_profiler: 0 +profiler_steps: 5 replicate_vae: False diff --git a/src/maxdiffusion/generate_ltx2.py b/src/maxdiffusion/generate_ltx2.py index d4a356d9..350c663b 100644 --- a/src/maxdiffusion/generate_ltx2.py +++ b/src/maxdiffusion/generate_ltx2.py @@ -116,7 +116,9 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): max_logging.log("Could not retrieve Git commit hash.") checkpoint_loader = LTX2Checkpointer(config=config) + load_time = 0.0 if pipeline is None: + t0_load = time.perf_counter() # Use the config flag to determine if the upsampler should be loaded run_latent_upsampler = getattr(config, "run_latent_upsampler", False) pipeline, _, _ = checkpoint_loader.load_checkpoint(load_upsampler=run_latent_upsampler) @@ -145,6 +147,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): scan_layers=config.scan_layers, dtype=config.weights_dtype, ) + load_time = time.perf_counter() - t0_load pipeline.enable_vae_slicing() pipeline.enable_vae_tiling() @@ -162,12 +165,6 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}" ) - out = call_pipeline(config, pipeline, prompt, negative_prompt) - - # out should have .frames and .audio - videos = out.frames if hasattr(out, "frames") else out[0] - audios = out.audio if hasattr(out, "audio") else None - max_logging.log("===================== Model details =======================") max_logging.log(f"model name: {getattr(config, 'model_name', 'ltx-video')}") max_logging.log(f"model path: 
{config.pretrained_model_name_or_path}") @@ -179,11 +176,48 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): max_logging.log(f"per_device_batch_size: {config.per_device_batch_size}") max_logging.log("============================================================") + original_enable_profiler = config.get_keys().get("enable_profiler", False) + original_enable_mld = config.get_keys().get("enable_ml_diagnostics", False) + original_num_steps = config.get_keys().get("num_inference_steps", 40) + + # --------------------------------------------------------- + # Run 1: Warmup Compilation (Original steps, NO profiling) + # --------------------------------------------------------- + config.get_keys()["enable_profiler"] = False + config.get_keys()["enable_ml_diagnostics"] = False + + max_logging.log(f"šŸš€ Starting warmup compilation pass ({original_num_steps} steps)...") + _ = call_pipeline(config, pipeline, prompt, negative_prompt) + compile_time = time.perf_counter() - s0 max_logging.log(f"compile_time: {compile_time}") if writer and jax.process_index() == 0: writer.add_scalar("inference/compile_time", compile_time, global_step=0) + # --------------------------------------------------------- + # Run 2: Actual Generation (Original steps, NO profiling) + # --------------------------------------------------------- + + s0 = time.perf_counter() + max_logging.log("šŸš€ Starting actual full-length generation pass...") + out = call_pipeline(config, pipeline, prompt, negative_prompt) + generation_time = time.perf_counter() - s0 + max_logging.log(f"generation_time: {generation_time}") + if writer and jax.process_index() == 0: + writer.add_scalar("inference/generation_time", generation_time, global_step=0) + num_devices = jax.device_count() + num_videos = num_devices * config.per_device_batch_size + if num_videos > 0: + generation_time_per_video = generation_time / num_videos + writer.add_scalar("inference/generation_time_per_video", generation_time_per_video, 
global_step=0) + max_logging.log(f"generation time per video: {generation_time_per_video}") + else: + max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.") + + # out should have .frames and .audio + videos = out.frames if hasattr(out, "frames") else out[0] + audios = out.audio if hasattr(out, "audio") else None + saved_video_path = [] audio_sample_rate = ( getattr(pipeline.vocoder.config, "output_sampling_rate", 24000) if hasattr(pipeline, "vocoder") else 24000 @@ -210,29 +244,51 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): if config.output_dir.startswith("gs://"): upload_video_to_gcs(os.path.join(config.output_dir, config.run_name), video_path) - s0 = time.perf_counter() - call_pipeline(config, pipeline, prompt, negative_prompt) - generation_time = time.perf_counter() - s0 - max_logging.log(f"generation_time: {generation_time}") - if writer and jax.process_index() == 0: - writer.add_scalar("inference/generation_time", generation_time, global_step=0) - num_devices = jax.device_count() - num_videos = num_devices * config.per_device_batch_size - if num_videos > 0: - generation_time_per_video = generation_time / num_videos - writer.add_scalar("inference/generation_time_per_video", generation_time_per_video, global_step=0) - max_logging.log(f"generation time per video: {generation_time_per_video}") - else: - max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.") + max_logging.log( + f"\n{'=' * 50}\n" + f" TIMING SUMMARY\n" + f"{'=' * 50}\n" + f" Load (checkpoint): {load_time:>7.1f}s\n" + f" Compile: {compile_time:>7.1f}s\n" + f" {'─' * 40}\n" + f" Inference: {generation_time:>7.1f}s\n" + f"{'=' * 50}" + ) - s0 = time.perf_counter() - if max_utils.profiler_enabled(config): - with max_utils.Profiler(config): - call_pipeline(config, pipeline, prompt, negative_prompt) - generation_time_with_profiler = time.perf_counter() - s0 - 
max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}") - if writer and jax.process_index() == 0: - writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0) + # Free memory before profiling + del out + del videos + del audios + + # --------------------------------------------------------- + # Run 3: Profiling Run (Only if profiling was originally enabled) + # --------------------------------------------------------- + if original_enable_profiler or original_enable_mld: + skip_first_n_steps_for_profiler = config.get_keys().get("skip_first_n_steps_for_profiler", 0) + if skip_first_n_steps_for_profiler != 0: + max_logging.log( + "\nāš ļø WARNING: 'skip_first_n_steps_for_profiler' is ignored because 'scan_diffusion_loop' is enabled! The profiler will capture all steps in this profile run.\n" + ) + + profiling_steps = config.get_keys().get("profiler_steps", 5) + + config.get_keys()["enable_profiler"] = False + config.get_keys()["enable_ml_diagnostics"] = False + config.get_keys()["num_inference_steps"] = profiling_steps + + max_logging.log(f"šŸš€ Warmup for profiling pass ({profiling_steps} steps)...") + _ = call_pipeline(config, pipeline, prompt, negative_prompt) + + config.get_keys()["enable_profiler"] = original_enable_profiler + config.get_keys()["enable_ml_diagnostics"] = original_enable_mld + + max_logging.log(f"šŸš€ Starting Profiling run ({profiling_steps} steps)...") + profiler = max_utils.Profiler(config, session_name=f"denoise_profile_{profiling_steps}_steps") + profiler.start() + + _ = call_pipeline(config, pipeline, prompt, negative_prompt) + + profiler.stop() return saved_video_path diff --git a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py index a3ec9591..f18bc667 100644 --- a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py @@ -17,6 +17,7 @@ from typing import Optional, 
Any, List, Union from functools import partial +import time import numpy as np import torch import jax @@ -1204,12 +1205,14 @@ def __call__( output_type: str = "pil", return_dict: bool = True, ): + t0_init = time.perf_counter() # 1. Check inputs self.check_inputs( prompt, height, width, prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask ) # 2. Encode inputs (Text) + t0_encode = time.perf_counter() prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt( prompt, negative_prompt, @@ -1222,6 +1225,8 @@ def __call__( max_sequence_length=max_sequence_length, dtype=dtype, ) + encode_time = time.perf_counter() - t0_encode + max_logging.log(f"Text Encoding time (Gemma-3 on CPU): {encode_time:.2f}s") # 3. Prepare latents batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0] @@ -1339,9 +1344,14 @@ def __call__( with context_manager, axis_rules_context: connectors_graphdef, connectors_state = nnx.split(self.connectors) - video_embeds, audio_embeds, new_attention_mask = self._run_connectors( - connectors_graphdef, connectors_state, prompt_embeds_jax, prompt_attention_mask_jax.astype(jnp.bool_) - ) + t0_connectors = time.perf_counter() + with jax.named_scope("connectors_pass"): + video_embeds, audio_embeds, new_attention_mask = self._run_connectors( + connectors_graphdef, connectors_state, prompt_embeds_jax, prompt_attention_mask_jax.astype(jnp.bool_) + ) + video_embeds = video_embeds.block_until_ready() + connectors_time = time.perf_counter() - t0_connectors + max_logging.log(f"Connectors pass time: {connectors_time:.2f}s") video_embeds_sharded = video_embeds audio_embeds_sharded = audio_embeds @@ -1350,10 +1360,11 @@ def __call__( activation_axes = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed")) spec = NamedSharding(self.mesh, P(*activation_axes)) video_embeds_sharded = 
jax.device_put(video_embeds, spec) - audio_embeds_sharded = jax.device_put(audio_embeds, spec) + audio_embeds_sharded = audio_embeds timesteps_jax = jnp.array(timesteps, dtype=jnp.float32) + t0_denoise = time.perf_counter() scan_diffusion_loop = getattr(self.config, "scan_diffusion_loop", True) if scan_diffusion_loop: @@ -1438,6 +1449,10 @@ def __call__( latents_jax = latents_step audio_latents_jax = audio_latents_step + jax.block_until_ready(latents_jax) + denoise_time = time.perf_counter() - t0_denoise + max_logging.log(f"Denoising steps time: {denoise_time:.2f}s") + # 8. Decode Latents if guidance_scale > 1.0: latents_jax = latents_jax[batch_size:] @@ -1465,23 +1480,30 @@ def __call__( if getattr(self.config, "run_latent_upsampler", False) and self.latent_upsampler is not None: max_logging.log("šŸš€ Running Latent Upsampler pass...") - if self.latent_upsampler_params is not None: - nnx.update(self.latent_upsampler, self.latent_upsampler_params) - self.latent_upsampler_params = None + upsampler_t0 = time.perf_counter() + with jax.named_scope("upsampler_pass"): + if self.latent_upsampler_params is not None: + nnx.update(self.latent_upsampler, self.latent_upsampler_params) + self.latent_upsampler_params = None - graphdef, state = nnx.split(self.latent_upsampler) + graphdef, state = nnx.split(self.latent_upsampler) - latents_upsampled = self._run_upsampler(graphdef, state, latents) + latents_upsampled = self._run_upsampler(graphdef, state, latents) - adain_factor = getattr(self.config, "upsampler_adain_factor", 0.0) - if adain_factor > 0.0: - latents = adain_filter_latent(latents_upsampled, latents, adain_factor) - else: - latents = latents_upsampled + adain_factor = getattr(self.config, "upsampler_adain_factor", 0.0) + if adain_factor > 0.0: + latents = adain_filter_latent(latents_upsampled, latents, adain_factor) + else: + latents = latents_upsampled + + tone_map_compression = getattr(self.config, "upsampler_tone_map_compression_ratio", 0.0) + if 
tone_map_compression > 0.0: + latents = tone_map_latents(latents, tone_map_compression) - tone_map_compression = getattr(self.config, "upsampler_tone_map_compression_ratio", 0.0) - if tone_map_compression > 0.0: - latents = tone_map_latents(latents, tone_map_compression) + jax.block_until_ready(latents) + upsampler_time = time.perf_counter() - upsampler_t0 + + max_logging.log(f"Latent Upsampler time: {upsampler_time:.2f}s") # ======================================================================= # Denormalize and Unpack Audio (Order important: Denorm THEN Unpack) @@ -1518,41 +1540,60 @@ def __call__( except Exception as e: max_logging.log(f"[Tuning] Failed to apply sharding constraint: {e}") - if getattr(self.vae.config, "timestep_conditioning", False): - noise = jax.random.normal(generator, latents.shape, dtype=latents.dtype) + t0_video_vae = time.perf_counter() + with jax.named_scope("video_vae_decode"): + if getattr(self.vae.config, "timestep_conditioning", False): + noise = jax.random.normal(generator, latents.shape, dtype=latents.dtype) - if not isinstance(decode_timestep, list): - decode_timestep = [decode_timestep] * batch_size - if decode_noise_scale is None: - decode_noise_scale = decode_timestep - elif not isinstance(decode_noise_scale, list): - decode_noise_scale = [decode_noise_scale] * batch_size + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size - timestep = jnp.array(decode_timestep, dtype=latents.dtype) - decode_noise_scale = jnp.array(decode_noise_scale, dtype=latents.dtype)[:, None, None, None, None] + timestep = jnp.array(decode_timestep, dtype=latents.dtype) + decode_noise_scale = jnp.array(decode_noise_scale, dtype=latents.dtype)[:, None, None, None, None] - latents = (1 - decode_noise_scale) * latents + decode_noise_scale * 
noise + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + latents = latents.astype(self.vae.dtype) + video = self.vae.decode(latents, temb=timestep, return_dict=False)[0] + else: + latents = latents.astype(self.vae.dtype) + video = self.vae.decode(latents, return_dict=False)[0] + + video = video.block_until_ready() + video_vae_time = time.perf_counter() - t0_video_vae + max_logging.log(f"Video VAE decode time: {video_vae_time:.2f}s") - latents = latents.astype(self.vae.dtype) - video = self.vae.decode(latents, temb=timestep, return_dict=False)[0] - else: - latents = latents.astype(self.vae.dtype) - video = self.vae.decode(latents, return_dict=False)[0] # Post-process video (converts to numpy/PIL) # VAE outputs (B, T, H, W, C), but video processor expects (B, C, T, H, W) + t0_video_post = time.perf_counter() video_np = np.array(video).transpose(0, 4, 1, 2, 3) video = self.video_processor.postprocess_video(torch.from_numpy(video_np), output_type=output_type) + video_post_time = time.perf_counter() - t0_video_post + max_logging.log(f"Video Post-processing time (numpy+PIL): {video_post_time:.2f}s") # Decode Audio - audio_latents = audio_latents.astype(self.audio_vae.dtype) - generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] + t0_audio_vae = time.perf_counter() + with jax.named_scope("audio_vae_decode"): + audio_latents = audio_latents.astype(self.audio_vae.dtype) + generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] + generated_mel_spectrograms = generated_mel_spectrograms.block_until_ready() + audio_vae_time = time.perf_counter() - t0_audio_vae + max_logging.log(f"Audio VAE decode time: {audio_vae_time:.2f}s") # Audio VAE outputs (B, T, F, C), Vocoder expects (B, Channels, Time, MelBins) - generated_mel_spectrograms = generated_mel_spectrograms.transpose(0, 3, 1, 2) - audio = self.vocoder(generated_mel_spectrograms) + t0_vocoder = time.perf_counter() + with 
jax.named_scope("vocoder_pass"): + generated_mel_spectrograms = generated_mel_spectrograms.transpose(0, 3, 1, 2) + audio = self.vocoder(generated_mel_spectrograms) # Convert audio to numpy audio = np.array(audio) + vocoder_time = time.perf_counter() - t0_vocoder + max_logging.log(f"Vocoder & Audio numpy time: {vocoder_time:.2f}s") return LTX2PipelineOutput(frames=video, audio=audio) @@ -1666,51 +1707,55 @@ def scan_body(carry, t, model): # Expand timestep to batch size t_expanded = jnp.expand_dims(t, 0).repeat(latents.shape[0]) - noise_pred, noise_pred_audio = model( - hidden_states=latents_sharded, - encoder_hidden_states=video_embeds_sharded, - timestep=t_expanded, - encoder_attention_mask=new_attention_mask, - num_frames=latent_num_frames, - height=latent_height, - width=latent_width, - audio_hidden_states=audio_latents_sharded, - audio_encoder_hidden_states=audio_embeds_sharded, - audio_encoder_attention_mask=new_attention_mask, - fps=fps, - audio_num_frames=audio_num_frames, - return_dict=False, - ) - - if guidance_scale > 1.0: - noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - # Audio guidance - ( - noise_pred_audio_uncond, - noise_pred_audio_text, - ) = jnp.split(noise_pred_audio, 2, axis=0) - noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond) - - latents_step = latents[batch_size:] - audio_latents_step = audio_latents[batch_size:] - else: - latents_step = latents - audio_latents_step = audio_latents - - # Step scheduler - latents_step, _ = scheduler_step(s_state, noise_pred, t, latents_step, return_dict=False) - latents_step = latents_step.astype(latents.dtype) - - audio_latents_step, _ = scheduler_step(s_state, noise_pred_audio, t, audio_latents_step, return_dict=False) - audio_latents_step = audio_latents_step.astype(audio_latents.dtype) + with 
jax.named_scope("transformer_forward_pass"): + noise_pred, noise_pred_audio = model( + hidden_states=latents_sharded, + encoder_hidden_states=video_embeds_sharded, + timestep=t_expanded, + encoder_attention_mask=new_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + audio_hidden_states=audio_latents_sharded, + audio_encoder_hidden_states=audio_embeds_sharded, + audio_encoder_attention_mask=new_attention_mask, + fps=fps, + audio_num_frames=audio_num_frames, + return_dict=False, + ) - if guidance_scale > 1.0: - latents_next = jnp.concatenate([latents_step] * 2, axis=0) - audio_latents_next = jnp.concatenate([audio_latents_step] * 2, axis=0) - else: - latents_next = latents_step - audio_latents_next = audio_latents_step + with jax.named_scope("classifier_free_guidance"): + if guidance_scale > 1.0: + noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + # Audio guidance + ( + noise_pred_audio_uncond, + noise_pred_audio_text, + ) = jnp.split(noise_pred_audio, 2, axis=0) + noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond) + + latents_step = latents[batch_size:] + audio_latents_step = audio_latents[batch_size:] + else: + latents_step = latents + audio_latents_step = audio_latents + + with jax.named_scope("scheduler_step"): + # Step scheduler + latents_step, _ = scheduler_step(s_state, noise_pred, t, latents_step, return_dict=False) + latents_step = latents_step.astype(latents.dtype) + + audio_latents_step, _ = scheduler_step(s_state, noise_pred_audio, t, audio_latents_step, return_dict=False) + audio_latents_step = audio_latents_step.astype(audio_latents.dtype) + + with jax.named_scope("latent_concatenation"): + if guidance_scale > 1.0: + latents_next = jnp.concatenate([latents_step] * 2, axis=0) + audio_latents_next = jnp.concatenate([audio_latents_step] * 2, 
axis=0) + else: + latents_next = latents_step + audio_latents_next = audio_latents_step new_carry = (latents_next, audio_latents_next, s_state) return new_carry, None