Skip to content

Commit 18f5d63

Browse files
committed
debug
1 parent 8e53bc8 commit 18f5d63

3 files changed

Lines changed: 40 additions & 13 deletions

File tree

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -521,13 +521,11 @@ def __call__(
521521
attn_output = self.attention_op.apply_attention(query=query, key=key, value=value, attention_mask=attention_mask)
522522

523523
if perturbation_mask is not None:
524-
print("DEBUG: Applying perturbation mask")
525524
# value is [B, S, InnerDim]
526525
# attn_output is [B, S, InnerDim]
527526
attn_output = value + perturbation_mask * (attn_output - value)
528527

529528
if getattr(self, "to_gate_logits", None) is not None:
530-
print("DEBUG: Applying gated attention")
531529
gate_logits = self.to_gate_logits(hidden_states)
532530
b, s, _ = attn_output.shape
533531
attn_output = attn_output.reshape(b, s, self.heads, self.dim_head)

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,6 @@ def __init__(
746746

747747
# 2. Prompt embeddings
748748
if self.use_prompt_embeddings:
749-
print("DEBUG: Initializing caption projection (LTX-2.0 path)")
750749
self.caption_projection = NNXPixArtAlphaTextProjection(
751750
rngs=rngs,
752751
in_features=self.caption_channels,
@@ -766,7 +765,6 @@ def __init__(
766765
self.audio_caption_projection = None
767766

768767
if self.cross_attn_mod:
769-
print("DEBUG: Initializing prompt_adaln (LTX-2.3 path)")
770768
self.prompt_adaln = LTX2AdaLayerNormSingle(
771769
rngs=rngs,
772770
embedding_dim=inner_dim,
@@ -1105,7 +1103,6 @@ def __call__(
11051103
audio_embedded_timestep = audio_embedded_timestep.reshape(batch_size, -1, audio_embedded_timestep.shape[-1])
11061104

11071105
if self.cross_attn_mod and sigma is not None:
1108-
print("DEBUG: Executing prompt_adaln (LTX-2.3 path)")
11091106
audio_sigma = audio_sigma if audio_sigma is not None else sigma
11101107
temb_prompt, _ = self.prompt_adaln(
11111108
sigma.flatten(),
@@ -1122,7 +1119,6 @@ def __call__(
11221119
temb_prompt_audio = None
11231120

11241121
if use_cross_timestep:
1125-
print("DEBUG: Using cross timestep (LTX-2.3 path)")
11261122
assert sigma is not None and audio_sigma is not None, "sigma and audio_sigma must be provided when use_cross_timestep is True"
11271123
video_ca_timestep = audio_sigma.flatten()
11281124
audio_ca_timestep = sigma.flatten()

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,6 +1398,14 @@ def __call__(
13981398
prompt_attention_mask_jax = jnp.concatenate([negative_prompt_attention_mask_jax, prompt_attention_mask_jax], axis=0)
13991399
latents_jax = jnp.concatenate([latents_jax] * 2, axis=0)
14001400
audio_latents_jax = jnp.concatenate([audio_latents_jax] * 2, axis=0)
1401+
def _print_stats_gemma(name, tensor):
1402+
print(
1403+
f"DEBUG {name} shape: {tensor.shape}, mean: {jnp.round(jnp.mean(tensor), 6)}, min: {jnp.round(jnp.min(tensor), 4)}, max: {jnp.round(jnp.max(tensor), 4)}, std: {jnp.round(jnp.std(tensor), 4)}"
1404+
)
1405+
if do_cfg and do_stg:
1406+
_print_stats_gemma("text_encoder_output_flattened", prompt_embeds_jax[:2])
1407+
else:
1408+
_print_stats_gemma("text_encoder_output_flattened", prompt_embeds_jax)
14011409

14021410
if hasattr(self, "mesh") and self.mesh is not None:
14031411
data_sharding_3d = NamedSharding(self.mesh, P())
@@ -1444,19 +1452,31 @@ def __call__(
14441452
audio_embeds_sharded = jax.device_put(audio_embeds, spec)
14451453
def _print_stats(name, tensor):
14461454
print(
1447-
f"DEBUG {name} shape: {tensor.shape}, mean: {jnp.round(jnp.mean(tensor), 4)}, min: {jnp.round(jnp.min(tensor), 4)}, max: {jnp.round(jnp.max(tensor), 4)}, std: {jnp.round(jnp.std(tensor), 4)}"
1455+
f"DEBUG {name} shape: {tensor.shape}, mean: {jnp.round(jnp.mean(tensor), 6)}, min: {jnp.round(jnp.min(tensor), 4)}, max: {jnp.round(jnp.max(tensor), 4)}, std: {jnp.round(jnp.std(tensor), 4)}"
14481456
)
14491457
print(f"WEIGHT DEBUG: block 0 to_q kernel mean: {float(self.transformer.transformer_blocks.attn1.to_q.kernel.value[0].mean()):.6f}")
1450-
_print_stats("video_embeds", video_embeds)
1451-
_print_stats("audio_embeds", audio_embeds)
1458+
if do_cfg and do_stg:
1459+
_print_stats("video_text_embedding", video_embeds[:2])
1460+
_print_stats("audio_text_embedding", audio_embeds[:2])
1461+
else:
1462+
_print_stats("video_text_embedding", video_embeds)
1463+
_print_stats("audio_text_embedding", audio_embeds)
14521464

14531465
timesteps_jax = jnp.array(timesteps, dtype=jnp.float32)
14541466

14551467
diffusion_loop_start = time.time()
14561468
scan_diffusion_loop = getattr(self.config, "scan_diffusion_loop", True)
14571469

1458-
_print_stats("latents_jax_before_loop", latents_jax)
1459-
_print_stats("audio_latents_jax_before_loop", audio_latents_jax)
1470+
if do_cfg and do_stg:
1471+
_print_stats("latents_jax_before_loop", latents_jax[:batch_size])
1472+
_print_stats("audio_latents_jax_before_loop", audio_latents_jax[:batch_size])
1473+
elif do_cfg:
1474+
_print_stats("latents_jax_before_loop", latents_jax[:batch_size])
1475+
_print_stats("audio_latents_jax_before_loop", audio_latents_jax[:batch_size])
1476+
else:
1477+
_print_stats("latents_jax_before_loop", latents_jax)
1478+
_print_stats("audio_latents_jax_before_loop", audio_latents_jax)
1479+
14601480
if scan_diffusion_loop:
14611481
latents_jax, audio_latents_jax = run_diffusion_loop(
14621482
graphdef,
@@ -1948,6 +1968,21 @@ def scan_body(carry, inputs):
19481968

19491969
is_first_step = (t == timesteps_jax[0])
19501970

1971+
def print_raw_stats():
1972+
print_stats_jit("noise_pred_video_raw", noise_pred)
1973+
print_stats_jit("noise_pred_audio_raw", noise_pred_audio)
1974+
if do_cfg:
1975+
uncond_v = noise_pred[:batch_size]
1976+
cond_v = noise_pred[batch_size : 2 * batch_size]
1977+
uncond_a = noise_pred_audio[:batch_size]
1978+
cond_a = noise_pred_audio[batch_size : 2 * batch_size]
1979+
print_stats_jit("noise_pred_video_raw_uncond", uncond_v)
1980+
print_stats_jit("noise_pred_video_raw_cond", cond_v)
1981+
print_stats_jit("noise_pred_audio_raw_uncond", uncond_a)
1982+
print_stats_jit("noise_pred_audio_raw_cond", cond_a)
1983+
1984+
jax.lax.cond(is_first_step, print_raw_stats, lambda: None)
1985+
19511986

19521987
# Extract latents_step based on stacking strategy
19531988
if do_cfg and do_stg:
@@ -1981,7 +2016,6 @@ def scan_body(carry, inputs):
19812016
x0_combined = rescale_noise_cfg(x0_combined, x0_text, guidance_rescale=guidance_rescale)
19822017

19832018
noise_pred = convert_to_vel(latents_step, x0_combined, sigma_t)
1984-
jax.lax.cond(is_first_step, lambda: print_stats_jit("noise_pred_video_after_guidance", noise_pred), lambda: None)
19852019

19862020
# Audio guidance
19872021
noise_pred_audio_uncond, noise_pred_audio_text, noise_pred_audio_perturb, noise_pred_audio_isolated = jnp.split(noise_pred_audio, 4, axis=0)
@@ -2001,7 +2035,6 @@ def scan_body(carry, inputs):
20012035
x0_audio_combined = rescale_noise_cfg(x0_audio_combined, x0_audio_text, guidance_rescale=audio_guidance_rescale)
20022036

20032037
noise_pred_audio = convert_to_vel(audio_latents_step, x0_audio_combined, sigma_t)
2004-
jax.lax.cond(is_first_step, lambda: print_stats_jit("noise_pred_audio_after_guidance", noise_pred_audio), lambda: None)
20052038

20062039
# ... (Standard CFG paths can be added here, but for brevity and since LTX2.3 runs with STG this handles the core logic)
20072040
elif do_cfg:

0 commit comments

Comments
 (0)