Skip to content

Commit 6877724

Browse files
committed
logging added
1 parent e90f77d commit 6877724

1 file changed

Lines changed: 20 additions & 3 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 20 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1227,6 +1227,8 @@ def __call__(
12271227
)
12281228

12291229
# 2. Encode inputs (Text)
1230+
import time
1231+
text_enc_start = time.time()
12301232
prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
12311233
prompt,
12321234
negative_prompt,
@@ -1239,6 +1241,8 @@ def __call__(
12391241
max_sequence_length=max_sequence_length,
12401242
dtype=dtype,
12411243
)
1244+
jax.block_until_ready(prompt_embeds)
1245+
max_logging.log(f"⏱️ Text Encoder Time: {time.time() - text_enc_start:.4f} seconds")
12421246

12431247
# 3. Prepare latents
12441248
batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
@@ -1378,9 +1382,12 @@ def __call__(
13781382
with context_manager, axis_rules_context:
13791383
connectors_graphdef, connectors_state = nnx.split(self.connectors)
13801384

1385+
connectors_start = time.time()
13811386
video_embeds, audio_embeds, new_attention_mask = self._run_connectors(
13821387
connectors_graphdef, connectors_state, prompt_embeds_jax, prompt_attention_mask_jax.astype(jnp.bool_)
13831388
)
1389+
jax.block_until_ready(video_embeds)
1390+
max_logging.log(f"⏱️ Connectors Time: {time.time() - connectors_start:.4f} seconds")
13841391

13851392
video_embeds_sharded = video_embeds
13861393
audio_embeds_sharded = audio_embeds
@@ -1393,6 +1400,7 @@ def __call__(
13931400

13941401
timesteps_jax = jnp.array(timesteps, dtype=jnp.float32)
13951402

1403+
diffusion_loop_start = time.time()
13961404
scan_diffusion_loop = getattr(self.config, "scan_diffusion_loop", True)
13971405

13981406
if scan_diffusion_loop:
@@ -1535,7 +1543,11 @@ def convert_to_vel(lat, x0, sig):
15351543
latents_jax = latents_step
15361544
audio_latents_jax = audio_latents_step
15371545

1546+
jax.block_until_ready(latents_jax)
1547+
max_logging.log(f"⏱️ Diffusion Loop Time: {time.time() - diffusion_loop_start:.4f} seconds")
1548+
15381549
# 8. Decode Latents
1550+
decode_start = time.time()
15391551
if do_cfg and do_stg:
15401552
latents_jax = latents_jax[batch_size : 2 * batch_size]
15411553
audio_latents_jax = audio_latents_jax[batch_size : 2 * batch_size]
@@ -1629,6 +1641,7 @@ def convert_to_vel(lat, x0, sig):
16291641
self.transformer = nnx.merge(graphdef, state)
16301642
jax.clear_caches()
16311643

1644+
vae_start = time.time()
16321645
if getattr(self.vae.config, "timestep_conditioning", False):
16331646
noise = jax.random.normal(generator, latents.shape, dtype=latents.dtype)
16341647

@@ -1650,6 +1663,9 @@ def convert_to_vel(lat, x0, sig):
16501663
latents = latents.astype(self.vae.dtype)
16511664
video = self.vae.decode(latents, return_dict=False)[0]
16521665
# Post-process video (converts to numpy/PIL)
1666+
jax.block_until_ready(video)
1667+
max_logging.log(f"⏱️ Video VAE Decode Time: {time.time() - vae_start:.4f} seconds")
1668+
16531669
# VAE outputs (B, T, H, W, C), but video processor expects (B, C, T, H, W)
16541670
video_np = np.array(video).transpose(0, 4, 1, 2, 3)
16551671
video = self.video_processor.postprocess_video(torch.from_numpy(video_np), output_type=output_type)
@@ -1663,10 +1679,11 @@ def convert_to_vel(lat, x0, sig):
16631679
generated_mel_spectrograms = generated_mel_spectrograms.transpose(0, 3, 1, 2)
16641680

16651681
vocoder_start_time = time.time()
1666-
audio = self.vocoder(generated_mel_spectrograms)
1682+
# Explicitly JIT compile the vocoder at the call site to guarantee it doesn't run eagerly
1683+
jitted_vocoder = nnx.jit(lambda m, x: m(x))
1684+
audio = jitted_vocoder(self.vocoder, generated_mel_spectrograms)
16671685
jax.block_until_ready(audio)
1668-
vocoder_execution_time = time.time() - vocoder_start_time
1669-
max_logging.log(f"BWE Vocoder Execution Time: {vocoder_execution_time:.4f} seconds")
1686+
max_logging.log(f"⏱️ BWE Vocoder Execution Time: {time.time() - vocoder_start_time:.4f} seconds")
16701687

16711688
# Convert audio to numpy
16721689
audio = np.array(audio)

0 commit comments

Comments
 (0)