@@ -1399,13 +1399,25 @@ def __call__(
             latents_jax = jnp.concatenate([latents_jax] * 2, axis=0)
             audio_latents_jax = jnp.concatenate([audio_latents_jax] * 2, axis=0)
         def _print_stats_gemma(name, tensor):
-            print(
-                f"DEBUG {name} shape: {tensor.shape}, mean: {jnp.round(jnp.mean(tensor), 6)}, min: {jnp.round(jnp.min(tensor), 4)}, max: {jnp.round(jnp.max(tensor), 4)}, std: {jnp.round(jnp.std(tensor), 4)}"
-            )
+            if isinstance(tensor, list):
+                for idx, t in enumerate(tensor):
+                    print(
+                        f"DEBUG {name}_{idx} shape: {t.shape}, mean: {jnp.round(jnp.mean(t), 6)}, min: {jnp.round(jnp.min(t), 4)}, max: {jnp.round(jnp.max(t), 4)}, std: {jnp.round(jnp.std(t), 4)}"
+                    )
+            else:
+                print(
+                    f"DEBUG {name} shape: {tensor.shape}, mean: {jnp.round(jnp.mean(tensor), 6)}, min: {jnp.round(jnp.min(tensor), 4)}, max: {jnp.round(jnp.max(tensor), 4)}, std: {jnp.round(jnp.std(tensor), 4)}"
+                )
+
         if do_cfg and do_stg:
-            _print_stats_gemma("text_encoder_output_flattened", prompt_embeds_jax[:2])
+            if isinstance(prompt_embeds_jax, list):
+                prompt_embeds_to_print = [x[:2] for x in prompt_embeds_jax]
+            else:
+                prompt_embeds_to_print = prompt_embeds_jax[:2]
         else:
-            _print_stats_gemma("text_encoder_output_flattened", prompt_embeds_jax)
+            prompt_embeds_to_print = prompt_embeds_jax
+
+        _print_stats_gemma("text_encoder_output_flattened", prompt_embeds_to_print)

         if hasattr(self, "mesh") and self.mesh is not None:
             data_sharding_3d = NamedSharding(self.mesh, P())