Skip to content

Commit 5bab4e5

Browse files
committed
debug
1 parent 06fed53 commit 5bab4e5

2 files changed

Lines changed: 28 additions & 52 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,9 +389,15 @@ def __call__(
389389
a2v_cross_attention_mask: Optional[jax.Array] = None,
390390
v2a_cross_attention_mask: Optional[jax.Array] = None,
391391
perturbation_mask: Optional[jax.Array] = None,
392+
layer_id: int = 0,
392393
) -> Tuple[jax.Array, jax.Array]:
393394
batch_size = hidden_states.shape[0]
394395

396+
is_layer_0 = (layer_id == 0)
397+
def _print_stats_layer(name, tensor):
398+
jax.debug.print("DEBUG [BLOCK 0] {name} shape: {shape}, mean: {mean}, min: {min}, max: {max}, std: {std}",
399+
name=name, shape=tensor.shape, mean=jnp.round(jnp.mean(tensor), 6), min=jnp.round(jnp.min(tensor), 4), max=jnp.round(jnp.max(tensor), 4), std=jnp.round(jnp.std(tensor), 4))
400+
395401
axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
396402
hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names)
397403
axis_names_audio = nn.logical_to_mesh_axes(("activation_batch", None, "activation_embed"))
@@ -438,6 +444,7 @@ def __call__(
438444
hidden_states = hidden_states + attn_hidden_states * gate_msa
439445

440446
if self.use_audio and audio_hidden_states is not None:
447+
jax.lax.cond(is_layer_0, lambda: _print_stats_layer("audio_in", audio_hidden_states), lambda: None)
441448
# Calculate Audio AdaLN values
442449
norm_audio_hidden_states = self.audio_norm1(audio_hidden_states)
443450

@@ -459,6 +466,7 @@ def __call__(
459466
audio_gate_q = audio_ada_values[:, :, 8, :]
460467

461468
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa
469+
jax.lax.cond(is_layer_0, lambda: _print_stats_layer("audio_norm1_out", norm_audio_hidden_states), lambda: None)
462470

463471
with jax.named_scope("Audio Self-Attention"):
464472
attn_audio_hidden_states = self.audio_attn1(
@@ -467,7 +475,9 @@ def __call__(
467475
rotary_emb=audio_rotary_emb,
468476
perturbation_mask=perturbation_mask,
469477
)
478+
jax.lax.cond(is_layer_0, lambda: _print_stats_layer("audio_attn1_out", attn_audio_hidden_states), lambda: None)
470479
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa
480+
jax.lax.cond(is_layer_0, lambda: _print_stats_layer("audio_attn1_residual", audio_hidden_states), lambda: None)
471481

472482
# 2. Video and Audio Cross-Attention with the text embeddings
473483
norm_hidden_states = self.norm2(hidden_states)
@@ -496,6 +506,7 @@ def __call__(
496506
norm_audio_hidden_states = self.audio_norm2(audio_hidden_states)
497507
if getattr(self, "cross_attn_mod", False):
498508
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_q) + audio_shift_q
509+
jax.lax.cond(is_layer_0, lambda: _print_stats_layer("audio_norm2_out", norm_audio_hidden_states), lambda: None)
499510

500511
if getattr(self, "cross_attn_mod", False) and temb_prompt_audio is not None:
501512
audio_prompt_table_reshaped = jnp.expand_dims(self.audio_prompt_scale_shift_table, axis=(0, 1))
@@ -513,11 +524,14 @@ def __call__(
513524
)
514525
if getattr(self, "cross_attn_mod", False):
515526
attn_audio_hidden_states = attn_audio_hidden_states * audio_gate_q
527+
jax.lax.cond(is_layer_0, lambda: _print_stats_layer("audio_attn2_out", attn_audio_hidden_states), lambda: None)
516528
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states
529+
jax.lax.cond(is_layer_0, lambda: _print_stats_layer("audio_attn2_residual", audio_hidden_states), lambda: None)
517530

518531
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
519532
norm_hidden_states = self.audio_to_video_norm(hidden_states)
520533
norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states)
534+
jax.lax.cond(is_layer_0, lambda: _print_stats_layer("audio_v2a_norm_out", norm_audio_hidden_states), lambda: None)
521535

522536
# Calculate Cross-Attention Modulation values
523537
# Video
@@ -582,9 +596,11 @@ def __call__(
582596
k_rotary_emb=ca_video_rotary_emb,
583597
attention_mask=v2a_cross_attention_mask,
584598
)
599+
jax.lax.cond(is_layer_0, lambda: _print_stats_layer("audio_v2a_attn_out", v2a_attn_hidden_states), lambda: None)
585600
if modality_mask is not None:
586601
v2a_attn_hidden_states = v2a_attn_hidden_states * modality_mask
587602
audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
603+
jax.lax.cond(is_layer_0, lambda: _print_stats_layer("audio_v2a_residual", audio_hidden_states), lambda: None)
588604

589605
# 4. Feedforward
590606
norm_hidden_states = self.norm3(hidden_states)
@@ -595,8 +611,11 @@ def __call__(
595611
if self.use_audio and audio_hidden_states is not None:
596612
norm_audio_hidden_states = self.audio_norm3(audio_hidden_states)
597613
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_mlp) + audio_shift_mlp
614+
jax.lax.cond(is_layer_0, lambda: _print_stats_layer("audio_norm3_out", norm_audio_hidden_states), lambda: None)
598615
audio_ff_output = self.audio_ff(norm_audio_hidden_states)
616+
jax.lax.cond(is_layer_0, lambda: _print_stats_layer("audio_ff_out", audio_ff_output), lambda: None)
599617
audio_hidden_states = audio_hidden_states + audio_ff_output * audio_gate_mlp
618+
jax.lax.cond(is_layer_0, lambda: _print_stats_layer("audio_block_out", audio_hidden_states), lambda: None)
600619

601620
return hidden_states, audio_hidden_states
602621

@@ -1193,6 +1212,7 @@ def scan_fn(carry, block_mask_and_id):
11931212
v2a_cross_attention_mask=None,
11941213
perturbation_mask=mask,
11951214
modality_mask=modality_mask,
1215+
layer_id=layer_id,
11961216
)
11971217
return (
11981218
hidden_states_out.astype(hidden_states.dtype),
@@ -1238,6 +1258,7 @@ def scan_fn(carry, block_mask_and_id):
12381258
a2v_cross_attention_mask=None,
12391259
v2a_cross_attention_mask=None,
12401260
perturbation_mask=mask,
1261+
layer_id=i,
12411262
)
12421263

12431264
# 6. Output layers

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 7 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,26 +1398,7 @@ def __call__(
13981398
prompt_attention_mask_jax = jnp.concatenate([negative_prompt_attention_mask_jax, prompt_attention_mask_jax], axis=0)
13991399
latents_jax = jnp.concatenate([latents_jax] * 2, axis=0)
14001400
audio_latents_jax = jnp.concatenate([audio_latents_jax] * 2, axis=0)
1401-
def _print_stats_gemma(name, tensor):
1402-
if isinstance(tensor, list):
1403-
for idx, t in enumerate(tensor):
1404-
print(
1405-
f"DEBUG {name}_{idx} shape: {t.shape}, mean: {jnp.round(jnp.mean(t), 6)}, min: {jnp.round(jnp.min(t), 4)}, max: {jnp.round(jnp.max(t), 4)}, std: {jnp.round(jnp.std(t), 4)}"
1406-
)
1407-
else:
1408-
print(
1409-
f"DEBUG {name} shape: {tensor.shape}, mean: {jnp.round(jnp.mean(tensor), 6)}, min: {jnp.round(jnp.min(tensor), 4)}, max: {jnp.round(jnp.max(tensor), 4)}, std: {jnp.round(jnp.std(tensor), 4)}"
1410-
)
14111401

1412-
if do_cfg and do_stg:
1413-
if isinstance(prompt_embeds_jax, list):
1414-
prompt_embeds_to_print = [x[:2] for x in prompt_embeds_jax]
1415-
else:
1416-
prompt_embeds_to_print = prompt_embeds_jax[:2]
1417-
else:
1418-
prompt_embeds_to_print = prompt_embeds_jax
1419-
1420-
_print_stats_gemma("text_encoder_output_flattened", prompt_embeds_to_print)
14211402

14221403
if hasattr(self, "mesh") and self.mesh is not None:
14231404
data_sharding_3d = NamedSharding(self.mesh, P())
@@ -1454,6 +1435,11 @@ def _print_stats_gemma(name, tensor):
14541435
jax.block_until_ready(video_embeds)
14551436
max_logging.log(f"⏱️ Connectors Time: {time.time() - connectors_start:.4f} seconds")
14561437

1438+
def _print_stats(name, tensor):
1439+
print(
1440+
f"DEBUG {name} shape: {tensor.shape}, mean: {jnp.round(jnp.mean(tensor), 6)}, min: {jnp.round(jnp.min(tensor), 4)}, max: {jnp.round(jnp.max(tensor), 4)}, std: {jnp.round(jnp.std(tensor), 4)}"
1441+
)
1442+
14571443
video_embeds_sharded = video_embeds
14581444
audio_embeds_sharded = audio_embeds
14591445

@@ -1462,32 +1448,14 @@ def _print_stats_gemma(name, tensor):
14621448
spec = NamedSharding(self.mesh, P(*activation_axes))
14631449
video_embeds_sharded = jax.device_put(video_embeds, spec)
14641450
audio_embeds_sharded = jax.device_put(audio_embeds, spec)
1465-
def _print_stats(name, tensor):
1466-
print(
1467-
f"DEBUG {name} shape: {tensor.shape}, mean: {jnp.round(jnp.mean(tensor), 6)}, min: {jnp.round(jnp.min(tensor), 4)}, max: {jnp.round(jnp.max(tensor), 4)}, std: {jnp.round(jnp.std(tensor), 4)}"
1468-
)
14691451
print(f"WEIGHT DEBUG: block 0 to_q kernel mean: {float(self.transformer.transformer_blocks.attn1.to_q.kernel.value[0].mean()):.6f}")
1470-
if do_cfg and do_stg:
1471-
_print_stats("video_text_embedding", video_embeds[:2])
1472-
_print_stats("audio_text_embedding", audio_embeds[:2])
1473-
else:
1474-
_print_stats("video_text_embedding", video_embeds)
1475-
_print_stats("audio_text_embedding", audio_embeds)
14761452

14771453
timesteps_jax = jnp.array(timesteps, dtype=jnp.float32)
14781454

14791455
diffusion_loop_start = time.time()
14801456
scan_diffusion_loop = getattr(self.config, "scan_diffusion_loop", True)
14811457

1482-
if do_cfg and do_stg:
1483-
_print_stats("latents_jax_before_loop", latents_jax[:batch_size])
1484-
_print_stats("audio_latents_jax_before_loop", audio_latents_jax[:batch_size])
1485-
elif do_cfg:
1486-
_print_stats("latents_jax_before_loop", latents_jax[:batch_size])
1487-
_print_stats("audio_latents_jax_before_loop", audio_latents_jax[:batch_size])
1488-
else:
1489-
_print_stats("latents_jax_before_loop", latents_jax)
1490-
_print_stats("audio_latents_jax_before_loop", audio_latents_jax)
1458+
14911459

14921460
if scan_diffusion_loop:
14931461
latents_jax, audio_latents_jax = run_diffusion_loop(
@@ -1980,20 +1948,7 @@ def scan_body(carry, inputs):
19801948

19811949
is_first_step = (t == timesteps_jax[0])
19821950

1983-
def print_raw_stats():
1984-
print_stats_jit("noise_pred_video_raw", noise_pred)
1985-
print_stats_jit("noise_pred_audio_raw", noise_pred_audio)
1986-
if do_cfg:
1987-
uncond_v = noise_pred[:batch_size]
1988-
cond_v = noise_pred[batch_size : 2 * batch_size]
1989-
uncond_a = noise_pred_audio[:batch_size]
1990-
cond_a = noise_pred_audio[batch_size : 2 * batch_size]
1991-
print_stats_jit("noise_pred_video_raw_uncond", uncond_v)
1992-
print_stats_jit("noise_pred_video_raw_cond", cond_v)
1993-
print_stats_jit("noise_pred_audio_raw_uncond", uncond_a)
1994-
print_stats_jit("noise_pred_audio_raw_cond", cond_a)
1995-
1996-
jax.lax.cond(is_first_step, print_raw_stats, lambda: None)
1951+
19971952

19981953

19991954
# Extract latents_step based on stacking strategy

0 commit comments

Comments
 (0)