Skip to content

Commit ec14f54

Browse files
committed
debug
1 parent 459c82e commit ec14f54

2 files changed

Lines changed: 5 additions & 7 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,8 +1193,8 @@ def scan_fn(carry, block_mask_and_id):
11931193
audio_rotary_emb=audio_rotary_emb,
11941194
ca_video_rotary_emb=video_cross_attn_rotary_emb,
11951195
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
1196-
a2v_cross_attention_mask=None,
1197-
v2a_cross_attention_mask=None,
1196+
a2v_cross_attention_mask=encoder_attention_mask,
1197+
v2a_cross_attention_mask=audio_encoder_attention_mask,
11981198
perturbation_mask=mask,
11991199
modality_mask=modality_mask,
12001200
)

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -833,9 +833,6 @@ def _print_stats(name, tensor):
833833
print(
834834
f"DEBUG {name} shape: {tensor.shape}, mean: {jnp.round(jnp.mean(tensor), 4)}, min: {jnp.round(jnp.min(tensor), 4)}, max: {jnp.round(jnp.max(tensor), 4)}, std: {jnp.round(jnp.std(tensor), 4)}"
835835
)
836-
_print_stats("text_encoder_output_layer_0", prompt_embeds_list[0])
837-
_print_stats("text_encoder_output_layer_last", prompt_embeds_list[-1])
838-
839836
prompt_embeds = prompt_embeds_list
840837
del text_encoder_hidden_states # Free PyTorch tensor memory
841838

@@ -1949,8 +1946,7 @@ def scan_body(carry, inputs):
19491946
)
19501947

19511948
is_first_step = (t == timesteps_jax[0])
1952-
jax.lax.cond(is_first_step, lambda: print_stats_jit("noise_pred", noise_pred), lambda: None)
1953-
jax.lax.cond(is_first_step, lambda: print_stats_jit("noise_pred_audio", noise_pred_audio), lambda: None)
1949+
19541950

19551951
# Extract latents_step based on stacking strategy
19561952
if do_cfg and do_stg:
@@ -1984,6 +1980,7 @@ def scan_body(carry, inputs):
19841980
x0_combined = rescale_noise_cfg(x0_combined, x0_text, guidance_rescale=guidance_rescale)
19851981

19861982
noise_pred = convert_to_vel(latents_step, x0_combined, sigma_t)
1983+
jax.lax.cond(is_first_step, lambda: print_stats_jit("noise_pred_video_after_guidance", noise_pred), lambda: None)
19871984

19881985
# Audio guidance
19891986
noise_pred_audio_uncond, noise_pred_audio_text, noise_pred_audio_perturb, noise_pred_audio_isolated = jnp.split(noise_pred_audio, 4, axis=0)
@@ -2003,6 +2000,7 @@ def scan_body(carry, inputs):
20032000
x0_audio_combined = rescale_noise_cfg(x0_audio_combined, x0_audio_text, guidance_rescale=audio_guidance_rescale)
20042001

20052002
noise_pred_audio = convert_to_vel(audio_latents_step, x0_audio_combined, sigma_t)
2003+
jax.lax.cond(is_first_step, lambda: print_stats_jit("noise_pred_audio_after_guidance", noise_pred_audio), lambda: None)
20062004

20072005
# ... (Standard CFG paths can be added here, but for brevity and since LTX2.3 runs with STG this handles the core logic)
20082006
elif do_cfg:

0 commit comments

Comments
 (0)