@@ -1189,8 +1189,8 @@ def scan_fn(carry, block_mask_and_id):
11891189 audio_rotary_emb = audio_rotary_emb ,
11901190 ca_video_rotary_emb = video_cross_attn_rotary_emb ,
11911191 ca_audio_rotary_emb = audio_cross_attn_rotary_emb ,
1192- a2v_cross_attention_mask = encoder_attention_mask ,
1193- v2a_cross_attention_mask = audio_encoder_attention_mask ,
1192+ a2v_cross_attention_mask = None ,
1193+ v2a_cross_attention_mask = None ,
11941194 perturbation_mask = mask ,
11951195 modality_mask = modality_mask ,
11961196 )
@@ -1235,8 +1235,8 @@ def scan_fn(carry, block_mask_and_id):
12351235 ca_audio_rotary_emb = audio_cross_attn_rotary_emb ,
12361236 encoder_attention_mask = encoder_attention_mask ,
12371237 audio_encoder_attention_mask = audio_encoder_attention_mask ,
1238- a2v_cross_attention_mask = encoder_attention_mask ,
1239- v2a_cross_attention_mask = audio_encoder_attention_mask ,
1238+ a2v_cross_attention_mask = None ,
1239+ v2a_cross_attention_mask = None ,
12401240 perturbation_mask = mask ,
12411241 )
12421242
0 commit comments