Skip to content

Commit 06fed53

Browse files
committed
fix dual stream cross attention masking bug
1 parent fc9c405 commit 06fed53

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,8 +1189,8 @@ def scan_fn(carry, block_mask_and_id):
 1189 1189           audio_rotary_emb=audio_rotary_emb,
 1190 1190           ca_video_rotary_emb=video_cross_attn_rotary_emb,
 1191 1191           ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
 1192      -         a2v_cross_attention_mask=encoder_attention_mask,
 1193      -         v2a_cross_attention_mask=audio_encoder_attention_mask,
      1192 +         a2v_cross_attention_mask=None,
      1193 +         v2a_cross_attention_mask=None,
 1194 1194           perturbation_mask=mask,
 1195 1195           modality_mask=modality_mask,
 1196 1196         )
@@ -1235,8 +1235,8 @@ def scan_fn(carry, block_mask_and_id):
 1235 1235           ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
 1236 1236           encoder_attention_mask=encoder_attention_mask,
 1237 1237           audio_encoder_attention_mask=audio_encoder_attention_mask,
 1238      -         a2v_cross_attention_mask=encoder_attention_mask,
 1239      -         v2a_cross_attention_mask=audio_encoder_attention_mask,
      1238 +         a2v_cross_attention_mask=None,
      1239 +         v2a_cross_attention_mask=None,
 1240 1240           perturbation_mask=mask,
 1241 1241         )
 1242 1242

0 commit comments

Comments
 (0)