@@ -1227,6 +1227,8 @@ def __call__(
12271227 )
12281228
12291229 # 2. Encode inputs (Text)
1230+ import time
1231+ text_enc_start = time .time ()
12301232 prompt_embeds , prompt_attention_mask , negative_prompt_embeds , negative_prompt_attention_mask = self .encode_prompt (
12311233 prompt ,
12321234 negative_prompt ,
@@ -1239,6 +1241,8 @@ def __call__(
12391241 max_sequence_length = max_sequence_length ,
12401242 dtype = dtype ,
12411243 )
1244+ jax .block_until_ready (prompt_embeds )
1245+ max_logging .log (f"⏱️ Text Encoder Time: { time .time () - text_enc_start :.4f} seconds" )
12421246
12431247 # 3. Prepare latents
12441248 batch_size = prompt_embeds [0 ].shape [0 ] if isinstance (prompt_embeds , list ) else prompt_embeds .shape [0 ]
@@ -1378,9 +1382,12 @@ def __call__(
13781382 with context_manager , axis_rules_context :
13791383 connectors_graphdef , connectors_state = nnx .split (self .connectors )
13801384
1385+ connectors_start = time .time ()
13811386 video_embeds , audio_embeds , new_attention_mask = self ._run_connectors (
13821387 connectors_graphdef , connectors_state , prompt_embeds_jax , prompt_attention_mask_jax .astype (jnp .bool_ )
13831388 )
1389+ jax .block_until_ready (video_embeds )
1390+ max_logging .log (f"⏱️ Connectors Time: { time .time () - connectors_start :.4f} seconds" )
13841391
13851392 video_embeds_sharded = video_embeds
13861393 audio_embeds_sharded = audio_embeds
@@ -1393,6 +1400,7 @@ def __call__(
13931400
13941401 timesteps_jax = jnp .array (timesteps , dtype = jnp .float32 )
13951402
1403+ diffusion_loop_start = time .time ()
13961404 scan_diffusion_loop = getattr (self .config , "scan_diffusion_loop" , True )
13971405
13981406 if scan_diffusion_loop :
@@ -1535,7 +1543,11 @@ def convert_to_vel(lat, x0, sig):
15351543 latents_jax = latents_step
15361544 audio_latents_jax = audio_latents_step
15371545
1546+ jax .block_until_ready (latents_jax )
1547+ max_logging .log (f"⏱️ Diffusion Loop Time: { time .time () - diffusion_loop_start :.4f} seconds" )
1548+
15381549 # 8. Decode Latents
1550+ decode_start = time .time ()
15391551 if do_cfg and do_stg :
15401552 latents_jax = latents_jax [batch_size : 2 * batch_size ]
15411553 audio_latents_jax = audio_latents_jax [batch_size : 2 * batch_size ]
@@ -1629,6 +1641,7 @@ def convert_to_vel(lat, x0, sig):
16291641 self .transformer = nnx .merge (graphdef , state )
16301642 jax .clear_caches ()
16311643
1644+ vae_start = time .time ()
16321645 if getattr (self .vae .config , "timestep_conditioning" , False ):
16331646 noise = jax .random .normal (generator , latents .shape , dtype = latents .dtype )
16341647
@@ -1650,6 +1663,9 @@ def convert_to_vel(lat, x0, sig):
16501663 latents = latents .astype (self .vae .dtype )
16511664 video = self .vae .decode (latents , return_dict = False )[0 ]
16521665 # Post-process video (converts to numpy/PIL)
1666+ jax .block_until_ready (video )
1667+ max_logging .log (f"⏱️ Video VAE Decode Time: { time .time () - vae_start :.4f} seconds" )
1668+
16531669 # VAE outputs (B, T, H, W, C), but video processor expects (B, C, T, H, W)
16541670 video_np = np .array (video ).transpose (0 , 4 , 1 , 2 , 3 )
16551671 video = self .video_processor .postprocess_video (torch .from_numpy (video_np ), output_type = output_type )
@@ -1663,10 +1679,11 @@ def convert_to_vel(lat, x0, sig):
16631679 generated_mel_spectrograms = generated_mel_spectrograms .transpose (0 , 3 , 1 , 2 )
16641680
16651681 vocoder_start_time = time .time ()
1666- audio = self .vocoder (generated_mel_spectrograms )
1682+ # Explicitly JIT compile the vocoder at the call site to guarantee it doesn't run eagerly
1683+ jitted_vocoder = nnx .jit (lambda m , x : m (x ))
1684+ audio = jitted_vocoder (self .vocoder , generated_mel_spectrograms )
16671685 jax .block_until_ready (audio )
1668- vocoder_execution_time = time .time () - vocoder_start_time
1669- max_logging .log (f"BWE Vocoder Execution Time: { vocoder_execution_time :.4f} seconds" )
1686+ max_logging .log (f"⏱️ BWE Vocoder Execution Time: { time .time () - vocoder_start_time :.4f} seconds" )
16701687
16711688 # Convert audio to numpy
16721689 audio = np .array (audio )
0 commit comments