@@ -833,9 +833,6 @@ def _print_stats(name, tensor):
833833 print (
834834 f"DEBUG { name } shape: { tensor .shape } , mean: { jnp .round (jnp .mean (tensor ), 4 )} , min: { jnp .round (jnp .min (tensor ), 4 )} , max: { jnp .round (jnp .max (tensor ), 4 )} , std: { jnp .round (jnp .std (tensor ), 4 )} "
835835 )
836- _print_stats ("text_encoder_output_layer_0" , prompt_embeds_list [0 ])
837- _print_stats ("text_encoder_output_layer_last" , prompt_embeds_list [- 1 ])
838-
839836 prompt_embeds = prompt_embeds_list
840837 del text_encoder_hidden_states # Free PyTorch tensor memory
841838
@@ -1949,8 +1946,7 @@ def scan_body(carry, inputs):
19491946 )
19501947
19511948 is_first_step = (t == timesteps_jax [0 ])
1952- jax .lax .cond (is_first_step , lambda : print_stats_jit ("noise_pred" , noise_pred ), lambda : None )
1953- jax .lax .cond (is_first_step , lambda : print_stats_jit ("noise_pred_audio" , noise_pred_audio ), lambda : None )
1949+
19541950
19551951 # Extract latents_step based on stacking strategy
19561952 if do_cfg and do_stg :
@@ -1984,6 +1980,7 @@ def scan_body(carry, inputs):
19841980 x0_combined = rescale_noise_cfg (x0_combined , x0_text , guidance_rescale = guidance_rescale )
19851981
19861982 noise_pred = convert_to_vel (latents_step , x0_combined , sigma_t )
1983+ jax .lax .cond (is_first_step , lambda : print_stats_jit ("noise_pred_video_after_guidance" , noise_pred ), lambda : None )
19871984
19881985 # Audio guidance
19891986 noise_pred_audio_uncond , noise_pred_audio_text , noise_pred_audio_perturb , noise_pred_audio_isolated = jnp .split (noise_pred_audio , 4 , axis = 0 )
@@ -2003,6 +2000,7 @@ def scan_body(carry, inputs):
20032000 x0_audio_combined = rescale_noise_cfg (x0_audio_combined , x0_audio_text , guidance_rescale = audio_guidance_rescale )
20042001
20052002 noise_pred_audio = convert_to_vel (audio_latents_step , x0_audio_combined , sigma_t )
2003+ jax .lax .cond (is_first_step , lambda : print_stats_jit ("noise_pred_audio_after_guidance" , noise_pred_audio ), lambda : None )
20062004
20072005 # ... (Standard CFG paths can be added here, but for brevity and since LTX2.3 runs with STG this handles the core logic)
20082006 elif do_cfg :
0 commit comments