Skip to content

Commit 4a3b5a5

Browse files
committed
debug
1 parent 5bab4e5 commit 4a3b5a5

1 file changed

Lines changed: 14 additions & 0 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1450,6 +1450,20 @@ def _print_stats(name, tensor):
14501450
audio_embeds_sharded = jax.device_put(audio_embeds, spec)
14511451
print(f"WEIGHT DEBUG: block 0 to_q kernel mean: {float(self.transformer.transformer_blocks.attn1.to_q.kernel.value[0].mean()):.6f}")
14521452

1453+
# Video-to-Audio Attention weights
1454+
v2a = self.transformer.transformer_blocks.video_to_audio_attn
1455+
print(f"WEIGHT DEBUG: block 0 v2a to_q mean: {float(v2a.to_q.kernel.value[0].mean()):.6f}, std: {float(v2a.to_q.kernel.value[0].std()):.6f}")
1456+
print(f"WEIGHT DEBUG: block 0 v2a to_k mean: {float(v2a.to_k.kernel.value[0].mean()):.6f}, std: {float(v2a.to_k.kernel.value[0].std()):.6f}")
1457+
print(f"WEIGHT DEBUG: block 0 v2a to_v mean: {float(v2a.to_v.kernel.value[0].mean()):.6f}, std: {float(v2a.to_v.kernel.value[0].std()):.6f}")
1458+
print(f"WEIGHT DEBUG: block 0 v2a to_out mean: {float(v2a.to_out.kernel.value[0].mean()):.6f}, std: {float(v2a.to_out.kernel.value[0].std()):.6f}")
1459+
1460+
# Audio-to-Video Attention weights
1461+
a2v = self.transformer.transformer_blocks.audio_to_video_attn
1462+
print(f"WEIGHT DEBUG: block 0 a2v to_q mean: {float(a2v.to_q.kernel.value[0].mean()):.6f}, std: {float(a2v.to_q.kernel.value[0].std()):.6f}")
1463+
print(f"WEIGHT DEBUG: block 0 a2v to_k mean: {float(a2v.to_k.kernel.value[0].mean()):.6f}, std: {float(a2v.to_k.kernel.value[0].std()):.6f}")
1464+
print(f"WEIGHT DEBUG: block 0 a2v to_v mean: {float(a2v.to_v.kernel.value[0].mean()):.6f}, std: {float(a2v.to_v.kernel.value[0].std()):.6f}")
1465+
print(f"WEIGHT DEBUG: block 0 a2v to_out mean: {float(a2v.to_out.kernel.value[0].mean()):.6f}, std: {float(a2v.to_out.kernel.value[0].std()):.6f}")
1466+
14531467
timesteps_jax = jnp.array(timesteps, dtype=jnp.float32)
14541468

14551469
diffusion_loop_start = time.time()

0 commit comments

Comments
 (0)