@@ -389,9 +389,15 @@ def __call__(
389389 a2v_cross_attention_mask : Optional [jax .Array ] = None ,
390390 v2a_cross_attention_mask : Optional [jax .Array ] = None ,
391391 perturbation_mask : Optional [jax .Array ] = None ,
392+ layer_id : int = 0 ,
392393 ) -> Tuple [jax .Array , jax .Array ]:
393394 batch_size = hidden_states .shape [0 ]
394395
396+ is_layer_0 = (layer_id == 0 )
397+ def _print_stats_layer (name , tensor ):
398+ jax .debug .print ("DEBUG [BLOCK 0] {name} shape: {shape}, mean: {mean}, min: {min}, max: {max}, std: {std}" ,
399+ name = name , shape = tensor .shape , mean = jnp .round (jnp .mean (tensor ), 6 ), min = jnp .round (jnp .min (tensor ), 4 ), max = jnp .round (jnp .max (tensor ), 4 ), std = jnp .round (jnp .std (tensor ), 4 ))
400+
395401 axis_names = nn .logical_to_mesh_axes (("activation_batch" , "activation_length" , "activation_embed" ))
396402 hidden_states = jax .lax .with_sharding_constraint (hidden_states , axis_names )
397403 axis_names_audio = nn .logical_to_mesh_axes (("activation_batch" , None , "activation_embed" ))
@@ -438,6 +444,7 @@ def __call__(
438444 hidden_states = hidden_states + attn_hidden_states * gate_msa
439445
440446 if self .use_audio and audio_hidden_states is not None :
447+ jax .lax .cond (is_layer_0 , lambda : _print_stats_layer ("audio_in" , audio_hidden_states ), lambda : None )
441448 # Calculate Audio AdaLN values
442449 norm_audio_hidden_states = self .audio_norm1 (audio_hidden_states )
443450
@@ -459,6 +466,7 @@ def __call__(
459466 audio_gate_q = audio_ada_values [:, :, 8 , :]
460467
461468 norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa ) + audio_shift_msa
469+ jax .lax .cond (is_layer_0 , lambda : _print_stats_layer ("audio_norm1_out" , norm_audio_hidden_states ), lambda : None )
462470
463471 with jax .named_scope ("Audio Self-Attention" ):
464472 attn_audio_hidden_states = self .audio_attn1 (
@@ -467,7 +475,9 @@ def __call__(
467475 rotary_emb = audio_rotary_emb ,
468476 perturbation_mask = perturbation_mask ,
469477 )
478+ jax .lax .cond (is_layer_0 , lambda : _print_stats_layer ("audio_attn1_out" , attn_audio_hidden_states ), lambda : None )
470479 audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa
480+ jax .lax .cond (is_layer_0 , lambda : _print_stats_layer ("audio_attn1_residual" , audio_hidden_states ), lambda : None )
471481
472482 # 2. Video and Audio Cross-Attention with the text embeddings
473483 norm_hidden_states = self .norm2 (hidden_states )
@@ -496,6 +506,7 @@ def __call__(
496506 norm_audio_hidden_states = self .audio_norm2 (audio_hidden_states )
497507 if getattr (self , "cross_attn_mod" , False ):
498508 norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_q ) + audio_shift_q
509+ jax .lax .cond (is_layer_0 , lambda : _print_stats_layer ("audio_norm2_out" , norm_audio_hidden_states ), lambda : None )
499510
500511 if getattr (self , "cross_attn_mod" , False ) and temb_prompt_audio is not None :
501512 audio_prompt_table_reshaped = jnp .expand_dims (self .audio_prompt_scale_shift_table , axis = (0 , 1 ))
@@ -513,11 +524,14 @@ def __call__(
513524 )
514525 if getattr (self , "cross_attn_mod" , False ):
515526 attn_audio_hidden_states = attn_audio_hidden_states * audio_gate_q
527+ jax .lax .cond (is_layer_0 , lambda : _print_stats_layer ("audio_attn2_out" , attn_audio_hidden_states ), lambda : None )
516528 audio_hidden_states = audio_hidden_states + attn_audio_hidden_states
529+ jax .lax .cond (is_layer_0 , lambda : _print_stats_layer ("audio_attn2_residual" , audio_hidden_states ), lambda : None )
517530
518531 # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
519532 norm_hidden_states = self .audio_to_video_norm (hidden_states )
520533 norm_audio_hidden_states = self .video_to_audio_norm (audio_hidden_states )
534+ jax .lax .cond (is_layer_0 , lambda : _print_stats_layer ("audio_v2a_norm_out" , norm_audio_hidden_states ), lambda : None )
521535
522536 # Calculate Cross-Attention Modulation values
523537 # Video
@@ -582,9 +596,11 @@ def __call__(
582596 k_rotary_emb = ca_video_rotary_emb ,
583597 attention_mask = v2a_cross_attention_mask ,
584598 )
599+ jax .lax .cond (is_layer_0 , lambda : _print_stats_layer ("audio_v2a_attn_out" , v2a_attn_hidden_states ), lambda : None )
585600 if modality_mask is not None :
586601 v2a_attn_hidden_states = v2a_attn_hidden_states * modality_mask
587602 audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
603+ jax .lax .cond (is_layer_0 , lambda : _print_stats_layer ("audio_v2a_residual" , audio_hidden_states ), lambda : None )
588604
589605 # 4. Feedforward
590606 norm_hidden_states = self .norm3 (hidden_states )
@@ -595,8 +611,11 @@ def __call__(
595611 if self .use_audio and audio_hidden_states is not None :
596612 norm_audio_hidden_states = self .audio_norm3 (audio_hidden_states )
597613 norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_mlp ) + audio_shift_mlp
614+ jax .lax .cond (is_layer_0 , lambda : _print_stats_layer ("audio_norm3_out" , norm_audio_hidden_states ), lambda : None )
598615 audio_ff_output = self .audio_ff (norm_audio_hidden_states )
616+ jax .lax .cond (is_layer_0 , lambda : _print_stats_layer ("audio_ff_out" , audio_ff_output ), lambda : None )
599617 audio_hidden_states = audio_hidden_states + audio_ff_output * audio_gate_mlp
618+ jax .lax .cond (is_layer_0 , lambda : _print_stats_layer ("audio_block_out" , audio_hidden_states ), lambda : None )
600619
601620 return hidden_states , audio_hidden_states
602621
@@ -1193,6 +1212,7 @@ def scan_fn(carry, block_mask_and_id):
11931212 v2a_cross_attention_mask = None ,
11941213 perturbation_mask = mask ,
11951214 modality_mask = modality_mask ,
1215+ layer_id = layer_id ,
11961216 )
11971217 return (
11981218 hidden_states_out .astype (hidden_states .dtype ),
@@ -1238,6 +1258,7 @@ def scan_fn(carry, block_mask_and_id):
12381258 a2v_cross_attention_mask = None ,
12391259 v2a_cross_attention_mask = None ,
12401260 perturbation_mask = mask ,
1261+ layer_id = i ,
12411262 )
12421263
12431264 # 6. Output layers
0 commit comments