@@ -1398,6 +1398,14 @@ def __call__(
13981398 prompt_attention_mask_jax = jnp .concatenate ([negative_prompt_attention_mask_jax , prompt_attention_mask_jax ], axis = 0 )
13991399 latents_jax = jnp .concatenate ([latents_jax ] * 2 , axis = 0 )
14001400 audio_latents_jax = jnp .concatenate ([audio_latents_jax ] * 2 , axis = 0 )
1401+ def _print_stats_gemma (name , tensor ):
1402+ print (
1403+ f"DEBUG { name } shape: { tensor .shape } , mean: { jnp .round (jnp .mean (tensor ), 6 )} , min: { jnp .round (jnp .min (tensor ), 4 )} , max: { jnp .round (jnp .max (tensor ), 4 )} , std: { jnp .round (jnp .std (tensor ), 4 )} "
1404+ )
1405+ if do_cfg and do_stg :
1406+ _print_stats_gemma ("text_encoder_output_flattened" , prompt_embeds_jax [:2 ])
1407+ else :
1408+ _print_stats_gemma ("text_encoder_output_flattened" , prompt_embeds_jax )
14011409
14021410 if hasattr (self , "mesh" ) and self .mesh is not None :
14031411 data_sharding_3d = NamedSharding (self .mesh , P ())
@@ -1444,19 +1452,31 @@ def __call__(
14441452 audio_embeds_sharded = jax .device_put (audio_embeds , spec )
14451453 def _print_stats (name , tensor ):
14461454 print (
1447- f"DEBUG { name } shape: { tensor .shape } , mean: { jnp .round (jnp .mean (tensor ), 4 )} , min: { jnp .round (jnp .min (tensor ), 4 )} , max: { jnp .round (jnp .max (tensor ), 4 )} , std: { jnp .round (jnp .std (tensor ), 4 )} "
1455+ f"DEBUG { name } shape: { tensor .shape } , mean: { jnp .round (jnp .mean (tensor ), 6 )} , min: { jnp .round (jnp .min (tensor ), 4 )} , max: { jnp .round (jnp .max (tensor ), 4 )} , std: { jnp .round (jnp .std (tensor ), 4 )} "
14481456 )
14491457 print (f"WEIGHT DEBUG: block 0 to_q kernel mean: { float (self .transformer .transformer_blocks .attn1 .to_q .kernel .value [0 ].mean ()):.6f} " )
1450- _print_stats ("video_embeds" , video_embeds )
1451- _print_stats ("audio_embeds" , audio_embeds )
1458+ if do_cfg and do_stg :
1459+ _print_stats ("video_text_embedding" , video_embeds [:2 ])
1460+ _print_stats ("audio_text_embedding" , audio_embeds [:2 ])
1461+ else :
1462+ _print_stats ("video_text_embedding" , video_embeds )
1463+ _print_stats ("audio_text_embedding" , audio_embeds )
14521464
14531465 timesteps_jax = jnp .array (timesteps , dtype = jnp .float32 )
14541466
14551467 diffusion_loop_start = time .time ()
14561468 scan_diffusion_loop = getattr (self .config , "scan_diffusion_loop" , True )
14571469
1458- _print_stats ("latents_jax_before_loop" , latents_jax )
1459- _print_stats ("audio_latents_jax_before_loop" , audio_latents_jax )
1470+ if do_cfg and do_stg :
1471+ _print_stats ("latents_jax_before_loop" , latents_jax [:batch_size ])
1472+ _print_stats ("audio_latents_jax_before_loop" , audio_latents_jax [:batch_size ])
1473+ elif do_cfg :
1474+ _print_stats ("latents_jax_before_loop" , latents_jax [:batch_size ])
1475+ _print_stats ("audio_latents_jax_before_loop" , audio_latents_jax [:batch_size ])
1476+ else :
1477+ _print_stats ("latents_jax_before_loop" , latents_jax )
1478+ _print_stats ("audio_latents_jax_before_loop" , audio_latents_jax )
1479+
14601480 if scan_diffusion_loop :
14611481 latents_jax , audio_latents_jax = run_diffusion_loop (
14621482 graphdef ,
@@ -1948,6 +1968,21 @@ def scan_body(carry, inputs):
19481968
19491969 is_first_step = (t == timesteps_jax [0 ])
19501970
1971+ def print_raw_stats ():
1972+ print_stats_jit ("noise_pred_video_raw" , noise_pred )
1973+ print_stats_jit ("noise_pred_audio_raw" , noise_pred_audio )
1974+ if do_cfg :
1975+ uncond_v = noise_pred [:batch_size ]
1976+ cond_v = noise_pred [batch_size : 2 * batch_size ]
1977+ uncond_a = noise_pred_audio [:batch_size ]
1978+ cond_a = noise_pred_audio [batch_size : 2 * batch_size ]
1979+ print_stats_jit ("noise_pred_video_raw_uncond" , uncond_v )
1980+ print_stats_jit ("noise_pred_video_raw_cond" , cond_v )
1981+ print_stats_jit ("noise_pred_audio_raw_uncond" , uncond_a )
1982+ print_stats_jit ("noise_pred_audio_raw_cond" , cond_a )
1983+
1984+ jax .lax .cond (is_first_step , print_raw_stats , lambda : None )
1985+
19511986
19521987 # Extract latents_step based on stacking strategy
19531988 if do_cfg and do_stg :
@@ -1981,7 +2016,6 @@ def scan_body(carry, inputs):
19812016 x0_combined = rescale_noise_cfg (x0_combined , x0_text , guidance_rescale = guidance_rescale )
19822017
19832018 noise_pred = convert_to_vel (latents_step , x0_combined , sigma_t )
1984- jax .lax .cond (is_first_step , lambda : print_stats_jit ("noise_pred_video_after_guidance" , noise_pred ), lambda : None )
19852019
19862020 # Audio guidance
19872021 noise_pred_audio_uncond , noise_pred_audio_text , noise_pred_audio_perturb , noise_pred_audio_isolated = jnp .split (noise_pred_audio , 4 , axis = 0 )
@@ -2001,7 +2035,6 @@ def scan_body(carry, inputs):
20012035 x0_audio_combined = rescale_noise_cfg (x0_audio_combined , x0_audio_text , guidance_rescale = audio_guidance_rescale )
20022036
20032037 noise_pred_audio = convert_to_vel (audio_latents_step , x0_audio_combined , sigma_t )
2004- jax .lax .cond (is_first_step , lambda : print_stats_jit ("noise_pred_audio_after_guidance" , noise_pred_audio ), lambda : None )
20052038
20062039 # ... (Standard CFG paths can be added here, but for brevity and since LTX2.3 runs with STG this handles the core logic)
20072040 elif do_cfg :
0 commit comments