Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/maxdiffusion/configs/ltx2_video.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,21 @@ flash_block_sizes: {
flash_min_seq_length: 4096
dcn_context_parallelism: 1
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
# -1 auto-shards this axis. For inference, data parallelism (-1) is recommended over
# context parallelism: DP assigns independent batch items to each core, requiring no
# cross-core communication, whereas context parallelism splits the sequence length and
# incurs heavy all-gather overhead on the ICI links.
ici_data_parallelism: -1
ici_fsdp_parallelism: 1
ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_context_parallelism: 1
ici_tensor_parallelism: 1
enable_profiler: False

# ML Diagnostics settings
enable_ml_diagnostics: True
profiler_gcs_path: "gs://mehdy/profiler/ml_diagnostics"
enable_ondemand_xprof: True
skip_first_n_steps_for_profiler: 0
profiler_steps: 5

replicate_vae: False

Expand Down
112 changes: 84 additions & 28 deletions src/maxdiffusion/generate_ltx2.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
max_logging.log("Could not retrieve Git commit hash.")

checkpoint_loader = LTX2Checkpointer(config=config)
load_time = 0.0
if pipeline is None:
t0_load = time.perf_counter()
# Use the config flag to determine if the upsampler should be loaded
run_latent_upsampler = getattr(config, "run_latent_upsampler", False)
pipeline, _, _ = checkpoint_loader.load_checkpoint(load_upsampler=run_latent_upsampler)
Expand Down Expand Up @@ -145,6 +147,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
scan_layers=config.scan_layers,
dtype=config.weights_dtype,
)
load_time = time.perf_counter() - t0_load

pipeline.enable_vae_slicing()
pipeline.enable_vae_tiling()
Expand All @@ -162,12 +165,6 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}"
)

out = call_pipeline(config, pipeline, prompt, negative_prompt)

# out should have .frames and .audio
videos = out.frames if hasattr(out, "frames") else out[0]
audios = out.audio if hasattr(out, "audio") else None

max_logging.log("===================== Model details =======================")
max_logging.log(f"model name: {getattr(config, 'model_name', 'ltx-video')}")
max_logging.log(f"model path: {config.pretrained_model_name_or_path}")
Expand All @@ -179,11 +176,48 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
max_logging.log(f"per_device_batch_size: {config.per_device_batch_size}")
max_logging.log("============================================================")

original_enable_profiler = config.get_keys().get("enable_profiler", False)
original_enable_mld = config.get_keys().get("enable_ml_diagnostics", False)
original_num_steps = config.get_keys().get("num_inference_steps", 40)

# ---------------------------------------------------------
# Run 1: Warmup Compilation (Original steps, NO profiling)
# ---------------------------------------------------------
config.get_keys()["enable_profiler"] = False
config.get_keys()["enable_ml_diagnostics"] = False

max_logging.log(f"🚀 Starting warmup compilation pass ({original_num_steps} steps)...")
_ = call_pipeline(config, pipeline, prompt, negative_prompt)

compile_time = time.perf_counter() - s0
max_logging.log(f"compile_time: {compile_time}")
if writer and jax.process_index() == 0:
writer.add_scalar("inference/compile_time", compile_time, global_step=0)

# ---------------------------------------------------------
# Run 2: Actual Generation (Original steps, NO profiling)
# ---------------------------------------------------------

s0 = time.perf_counter()
max_logging.log("🚀 Starting actual full-length generation pass...")
out = call_pipeline(config, pipeline, prompt, negative_prompt)
generation_time = time.perf_counter() - s0
max_logging.log(f"generation_time: {generation_time}")
if writer and jax.process_index() == 0:
writer.add_scalar("inference/generation_time", generation_time, global_step=0)
num_devices = jax.device_count()
num_videos = num_devices * config.per_device_batch_size
if num_videos > 0:
generation_time_per_video = generation_time / num_videos
writer.add_scalar("inference/generation_time_per_video", generation_time_per_video, global_step=0)
max_logging.log(f"generation time per video: {generation_time_per_video}")
else:
max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.")

# out should have .frames and .audio
videos = out.frames if hasattr(out, "frames") else out[0]
audios = out.audio if hasattr(out, "audio") else None

saved_video_path = []
audio_sample_rate = (
getattr(pipeline.vocoder.config, "output_sampling_rate", 24000) if hasattr(pipeline, "vocoder") else 24000
Expand All @@ -210,29 +244,51 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
if config.output_dir.startswith("gs://"):
upload_video_to_gcs(os.path.join(config.output_dir, config.run_name), video_path)

s0 = time.perf_counter()
call_pipeline(config, pipeline, prompt, negative_prompt)
generation_time = time.perf_counter() - s0
max_logging.log(f"generation_time: {generation_time}")
if writer and jax.process_index() == 0:
writer.add_scalar("inference/generation_time", generation_time, global_step=0)
num_devices = jax.device_count()
num_videos = num_devices * config.per_device_batch_size
if num_videos > 0:
generation_time_per_video = generation_time / num_videos
writer.add_scalar("inference/generation_time_per_video", generation_time_per_video, global_step=0)
max_logging.log(f"generation time per video: {generation_time_per_video}")
else:
max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.")
max_logging.log(
f"\n{'=' * 50}\n"
f" TIMING SUMMARY\n"
f"{'=' * 50}\n"
f" Load (checkpoint): {load_time:>7.1f}s\n"
f" Compile: {compile_time:>7.1f}s\n"
f" {'─' * 40}\n"
f" Inference: {generation_time:>7.1f}s\n"
f"{'=' * 50}"
)

s0 = time.perf_counter()
if max_utils.profiler_enabled(config):
with max_utils.Profiler(config):
call_pipeline(config, pipeline, prompt, negative_prompt)
generation_time_with_profiler = time.perf_counter() - s0
max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}")
if writer and jax.process_index() == 0:
writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0)
# Free memory before profiling
del out
del videos
del audios

# ---------------------------------------------------------
# Run 3: Profiling Run (Only if profiling was originally enabled)
# ---------------------------------------------------------
if original_enable_profiler or original_enable_mld:
skip_first_n_steps_for_profiler = config.get_keys().get("skip_first_n_steps_for_profiler", 0)
if skip_first_n_steps_for_profiler != 0:
max_logging.log(
"\n⚠️ WARNING: 'skip_first_n_steps_for_profiler' is ignored because 'scan_diffusion_loop' is enabled! The profiler will capture all steps in this profile run.\n"
)

profiling_steps = config.get_keys().get("profiler_steps", 5)

config.get_keys()["enable_profiler"] = False
config.get_keys()["enable_ml_diagnostics"] = False
config.get_keys()["num_inference_steps"] = profiling_steps

max_logging.log(f"🚀 Warmup for profiling pass ({profiling_steps} steps)...")
_ = call_pipeline(config, pipeline, prompt, negative_prompt)

config.get_keys()["enable_profiler"] = original_enable_profiler
config.get_keys()["enable_ml_diagnostics"] = original_enable_mld

max_logging.log(f"🚀 Starting Profiling run ({profiling_steps} steps)...")
profiler = max_utils.Profiler(config, session_name=f"denoise_profile_{profiling_steps}_steps")
profiler.start()

_ = call_pipeline(config, pipeline, prompt, negative_prompt)

profiler.stop()

return saved_video_path

Expand Down
Loading
Loading