Skip to content

Commit 195e52c

Browse files
committed
maybe unflatten heads.
1 parent 965f60e commit 195e52c

1 file changed

Lines changed: 24 additions & 1 deletion

File tree

src/diffusers/models/attention_dispatch.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,25 @@ def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: in
589589
return attn_mask
590590

591591

592+
def _maybe_unflatten_attention_heads(out: torch.Tensor, reference_q: torch.Tensor) -> torch.Tensor:
593+
"""
594+
Flash Attention 3 (and some hub builds) may return tensors where the head and head-dim axes are packed together.
595+
Use the original query to restore the canonical [B, S, H, D] shape expected by the rest of the codebase.
596+
"""
597+
if reference_q.ndim != 4 or out.ndim != 3:
598+
return out
599+
600+
if out.shape[0] != reference_q.shape[0] or out.shape[1] != reference_q.shape[1]:
601+
return out
602+
603+
num_heads, head_dim = reference_q.shape[-2:]
604+
expected_width = num_heads * head_dim
605+
if out.shape[-1] != expected_width:
606+
return out
607+
608+
return out.reshape(reference_q.shape[0], reference_q.shape[1], num_heads, head_dim)
609+
610+
592611
def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
593612
return q_idx >= kv_idx
594613

@@ -1533,6 +1552,7 @@ def _flash_attention_3(
15331552
softmax_scale=scale,
15341553
causal=is_causal,
15351554
)
1555+
out = _maybe_unflatten_attention_heads(out, query)
15361556
return (out, lse) if return_lse else out
15371557

15381558

@@ -1577,7 +1597,9 @@ def _flash_attention_3_hub(
15771597
)
15781598
# When `return_attn_probs` is True, the above returns a tuple of
15791599
# actual outputs and lse.
1580-
return (out[0], out[1]) if return_attn_probs else out
1600+
if return_attn_probs:
1601+
return (_maybe_unflatten_attention_heads(out[0], query), out[1])
1602+
return _maybe_unflatten_attention_heads(out, query)
15811603

15821604

15831605
@_AttentionBackendRegistry.register(
@@ -1630,6 +1652,7 @@ def _flash_varlen_attention_3_hub(
16301652
causal=is_causal,
16311653
)
16321654
out = out.unflatten(0, (batch_size, -1))
1655+
out = _maybe_unflatten_attention_heads(out, query)
16331656

16341657
return (out, lse) if return_lse else out
16351658

0 commit comments

Comments
 (0)