@@ -1450,6 +1450,20 @@ def _print_stats(name, tensor):
14501450 audio_embeds_sharded = jax .device_put (audio_embeds , spec )
14511451 print (f"WEIGHT DEBUG: block 0 to_q kernel mean: { float (self .transformer .transformer_blocks .attn1 .to_q .kernel .value [0 ].mean ()):.6f} " )
14521452
1453+ # Video-to-Audio Attention weights
1454+ v2a = self .transformer .transformer_blocks .video_to_audio_attn
1455+ print (f"WEIGHT DEBUG: block 0 v2a to_q mean: { float (v2a .to_q .kernel .value [0 ].mean ()):.6f} , std: { float (v2a .to_q .kernel .value [0 ].std ()):.6f} " )
1456+ print (f"WEIGHT DEBUG: block 0 v2a to_k mean: { float (v2a .to_k .kernel .value [0 ].mean ()):.6f} , std: { float (v2a .to_k .kernel .value [0 ].std ()):.6f} " )
1457+ print (f"WEIGHT DEBUG: block 0 v2a to_v mean: { float (v2a .to_v .kernel .value [0 ].mean ()):.6f} , std: { float (v2a .to_v .kernel .value [0 ].std ()):.6f} " )
1458+ print (f"WEIGHT DEBUG: block 0 v2a to_out mean: { float (v2a .to_out .kernel .value [0 ].mean ()):.6f} , std: { float (v2a .to_out .kernel .value [0 ].std ()):.6f} " )
1459+
1460+ # Audio-to-Video Attention weights
1461+ a2v = self .transformer .transformer_blocks .audio_to_video_attn
1462+ print (f"WEIGHT DEBUG: block 0 a2v to_q mean: { float (a2v .to_q .kernel .value [0 ].mean ()):.6f} , std: { float (a2v .to_q .kernel .value [0 ].std ()):.6f} " )
1463+ print (f"WEIGHT DEBUG: block 0 a2v to_k mean: { float (a2v .to_k .kernel .value [0 ].mean ()):.6f} , std: { float (a2v .to_k .kernel .value [0 ].std ()):.6f} " )
1464+ print (f"WEIGHT DEBUG: block 0 a2v to_v mean: { float (a2v .to_v .kernel .value [0 ].mean ()):.6f} , std: { float (a2v .to_v .kernel .value [0 ].std ()):.6f} " )
1465+ print (f"WEIGHT DEBUG: block 0 a2v to_out mean: { float (a2v .to_out .kernel .value [0 ].mean ()):.6f} , std: { float (a2v .to_out .kernel .value [0 ].std ()):.6f} " )
1466+
14531467 timesteps_jax = jnp .array (timesteps , dtype = jnp .float32 )
14541468
14551469 diffusion_loop_start = time .time ()
0 commit comments