We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 6004e92 commit 5a448f3
Copy full SHA for 5a448f3
1 file changed
src/maxdiffusion/models/attention_flax.py
@@ -1193,6 +1193,11 @@ def __call__(
1193
1194
is_i2v_cross_attention = self.added_kv_proj_dim is not None and not is_self_attention
1195
1196
+ # For T2V self-attention and cross-attention, we skip passing the mask
1197
+ # to avoid overhead, as it should be all 1s for unpadded sequences.
1198
+ if not is_i2v_cross_attention:
1199
+ encoder_attention_mask = None
1200
+
1201
if not is_i2v_cross_attention:
1202
with jax.named_scope("query_proj"):
1203
query_proj = self.query(hidden_states)
0 commit comments