
Commit b9d6420

Fix AuraFlow attn processors applying norm_added_q to key projection (#13533)
Both AuraFlowAttnProcessor2_0 and FusedAuraFlowAttnProcessor2_0 were calling attn.norm_added_q on encoder_hidden_states_key_proj inside a branch guarded by a check on attn.norm_added_k. That applied the query normalization layer to the key projection, which is a copy-paste error. This fix is consistent with every other attention processor in this file that defines both norm_added_q and norm_added_k (e.g. FluxAttnProcessor, CogVideoXAttnProcessor, HunyuanAttnProcessor), all of which apply norm_added_k to the added key projection.
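The corrected pattern can be sketched in a minimal standalone module. This is a hypothetical example (LayerNorm is used as a stand-in for the processor's actual norm layers); only the attribute names norm_added_q / norm_added_k mirror the diffusers code:

```python
import torch
import torch.nn as nn

class AddedProjNorm(nn.Module):
    """Sketch of QK-norm applied to the added (encoder) projections."""

    def __init__(self, dim):
        super().__init__()
        # Separate norm layers for the added query and key projections.
        self.norm_added_q = nn.LayerNorm(dim)
        self.norm_added_k = nn.LayerNorm(dim)

    def forward(self, encoder_q_proj, encoder_k_proj):
        # Each norm must be applied to its matching projection.
        if self.norm_added_q is not None:
            encoder_q_proj = self.norm_added_q(encoder_q_proj)
        if self.norm_added_k is not None:
            # The bug was calling self.norm_added_q here.
            encoder_k_proj = self.norm_added_k(encoder_k_proj)
        return encoder_q_proj, encoder_k_proj
```

With distinct learned parameters in the two norm layers, using the wrong one silently produces incorrect key activations while keeping shapes valid, which is why the bug type-checks and runs without error.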
1 parent 3d30b7d commit b9d6420

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

src/diffusers/models/attention_processor.py

@@ -2140,7 +2140,7 @@ def __call__(
             if attn.norm_added_q is not None:
                 encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
             if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj)
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

             query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
@@ -2237,7 +2237,7 @@ def __call__(
             if attn.norm_added_q is not None:
                 encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
             if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj)
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

             query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
