perf: optimize LTX2 inference latency and implement granular TPU profiling#389
Open
Conversation
- Switched to DP (`ici_data_parallelism: -1`) in the ltx2 config to bypass ICI communication overhead during inference.
- Added `jax.named_scope` around connectors and VAE blocks for accurate xprof trace attribution.
- Added synchronous `perf_counter` wrappers in the pipeline to measure true stage latencies.
- Implemented a 3-pass (warmup, run, profile) generation loop in `generate_ltx2.py` to isolate JIT compilation time and improve profiling accuracy.
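The 3-pass pattern from the last bullet can be sketched as follows; the `generate` function and shapes are illustrative stand-ins, not the actual `generate_ltx2.py` code:

```python
import tempfile
import time

import jax
import jax.numpy as jnp


@jax.jit
def generate(latents):
    # Toy stand-in for the real denoising/VAE pipeline.
    return jnp.sin(latents) @ jnp.ones((32, 32)) / 32.0


latents = jnp.ones((4, 32))

# Pass 1 (warmup): triggers JIT compilation; excluded from measurements.
jax.block_until_ready(generate(latents))

# Pass 2 (run): steady-state latency, free of compilation overhead.
t0 = time.perf_counter()
out = jax.block_until_ready(generate(latents))
run_s = time.perf_counter() - t0

# Pass 3 (profile): capture an xprof trace of the already-compiled computation.
with tempfile.TemporaryDirectory() as logdir:
    with jax.profiler.trace(logdir):
        jax.block_until_ready(generate(latents))
```

Separating the passes this way keeps compilation time out of both the latency numbers and the trace.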
Collaborator
@mbohlool Could you add a table with the latency gain (single video and amortized throughput) of this change relative to the baseline (main)? Thanks!
Optimize LTX2 inference latency and implement granular TPU profiling
Description
This PR introduces critical performance optimizations and comprehensive profiling infrastructure for the LTX2 video generation pipeline on TPU hardware.
Key Changes
1. Inference Parallelism Optimization (`ltx2_video.yml`)
   Switched from ICI Context Parallelism (`ici_context_parallelism: 1`) to ICI Data Parallelism (`ici_data_parallelism: -1`). This eliminates the All-Gather ICI bottlenecks caused by sequence sharding.
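A hedged sketch of why data-parallel sharding avoids the All-Gather ICI traffic that sequence (context) sharding incurs; the mesh setup and shapes are illustrative, not taken from the actual LTX2 config code:

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"  # emulate devices on CPU

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

tokens = jnp.ones((8, 1024, 16))  # (batch, sequence, features)

# ici_data_parallelism: shard the batch axis. Each device holds complete
# sequences, so attention needs no cross-device communication.
dp = jax.device_put(tokens, NamedSharding(mesh, P("data", None, None)))

# ici_context_parallelism: shard the sequence axis. Attending over the full
# sequence then forces All-Gathers across the interconnect (ICI on TPU).
cp = jax.device_put(tokens, NamedSharding(mesh, P(None, "data", None)))
```

With `ici_data_parallelism: -1`, the data axis absorbs all remaining devices, so inference batches replicate the model and split the batch instead of splitting sequences.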
2. Granular XLA Profiling Annotations (`ltx2_pipeline.py`)
   Injected `jax.named_scope` wrappers around all major TPU-bound compute blocks (Connectors, Video VAE, Audio VAE, Vocoder). These scopes surface in XLA traces (xprof), enabling accurate FLOPs tracking and roofline analysis for individual components outside of the main denoising loop.
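The annotation pattern looks like the following sketch; the block contents are toy stand-ins for the real Connector/VAE compute:

```python
import jax
import jax.numpy as jnp


@jax.jit
def decode(latents):
    # Each named_scope shows up as a labeled region in the xprof trace,
    # so per-component time and FLOPs can be attributed accurately.
    with jax.named_scope("connectors"):
        h = latents @ jnp.ones((16, 16))
    with jax.named_scope("video_vae"):
        video = jnp.tanh(h)
    with jax.named_scope("vocoder"):
        audio = jnp.sin(h).sum(axis=-1)
    return video, audio
```

Because `named_scope` is a context manager, it nests naturally and costs nothing at runtime beyond trace metadata.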
3. Execution Timing & Benchmarking (`generate_ltx2.py`, `ltx2_pipeline.py`)
   Added synchronous `jax.block_until_ready()` wrappers at the boundaries of major pipeline stages to accurately measure execution time without asynchronous JAX dispatch artifacts. Set `skip_first_n_steps_for_profiler: 0` to completely isolate true runtime execution latency from initial JIT compilation overhead.
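A sketch (not the actual pipeline code) of the `block_until_ready` timing boundary: the clock only stops once all asynchronously dispatched device work has finished.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def stage(x):
    # Stand-in for one pipeline stage (e.g. a VAE decode).
    return jnp.sin(x) @ jnp.ones((64, 64)) / 64.0


def timed(fn, *args):
    # Without block_until_ready, perf_counter would only measure dispatch
    # time, since JAX returns control before the device finishes.
    t0 = time.perf_counter()
    out = jax.block_until_ready(fn(*args))
    return out, time.perf_counter() - t0


x = jnp.ones((8, 64))
_, warmup_s = timed(stage, x)  # first call: includes JIT compilation
out, run_s = timed(stage, x)   # steady-state stage latency
```

The first measurement is dominated by compilation, which is why the generation loop separates a warmup pass from the measured runs.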