Skip to content

Commit 25ac1cc

Browse files
committed
remove unused function.
1 parent ae72f97 commit 25ac1cc

1 file changed

Lines changed: 0 additions & 19 deletions

File tree

src/diffusers/models/attention_dispatch.py

Lines changed: 0 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -589,25 +589,6 @@ def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: in
589589
return attn_mask
590590

591591

592-
def _maybe_unflatten_attention_heads(out: torch.Tensor, reference_q: torch.Tensor) -> torch.Tensor:
593-
"""
594-
Flash Attention 3 (and some hub builds) may return tensors where the head and head-dim axes are packed together.
595-
Use the original query to restore the canonical [B, S, H, D] shape expected by the rest of the codebase.
596-
"""
597-
if reference_q.ndim != 4 or out.ndim != 3:
598-
return out
599-
600-
if out.shape[0] != reference_q.shape[0] or out.shape[1] != reference_q.shape[1]:
601-
return out
602-
603-
num_heads, head_dim = reference_q.shape[-2:]
604-
expected_width = num_heads * head_dim
605-
if out.shape[-1] != expected_width:
606-
return out
607-
608-
return out.reshape(reference_q.shape[0], reference_q.shape[1], num_heads, head_dim)
609-
610-
611592
def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
612593
return q_idx >= kv_idx
613594

0 commit comments

Comments
 (0)