@@ -589,6 +589,25 @@ def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: in
589589 return attn_mask
590590
591591
def _maybe_unflatten_attention_heads(out: torch.Tensor, reference_q: torch.Tensor) -> torch.Tensor:
    """
    Restore a separate head axis on an attention output whose heads were packed into the last dimension.

    Flash Attention 3 (and some hub builds) may return tensors where the head and head-dim axes are packed
    together. The original query is used to recover the canonical `[B, S, H, D]` shape expected by the rest of
    the codebase.

    Args:
        out: Attention output to inspect.
        reference_q: Query tensor given to the attention call; its last two dimensions are read as
            `(num_heads, head_dim)`.

    Returns:
        `out` with its last dimension split into `(num_heads, head_dim)` when all of the following hold:
        `reference_q` is 4-D, `out` is 3-D, their first two dimensions match, and the last dimension of `out`
        equals `num_heads * head_dim`. In every other case `out` is returned unchanged.
    """
    # Only a 3-D output paired with a 4-D query can be a packed `[B, S, H * D]` result.
    if reference_q.ndim != 4 or out.ndim != 3:
        return out

    # The two leading axes must already agree; otherwise the layout is unknown and is left alone.
    if out.shape[:2] != reference_q.shape[:2]:
        return out

    num_heads, head_dim = reference_q.shape[-2:]
    if out.shape[-1] != num_heads * head_dim:
        return out

    # The leading axes were verified above, so only the packed last axis needs to be split.
    return out.unflatten(-1, (num_heads, head_dim))
609+
610+
def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
    """
    Causal mask predicate: true where the query index is at or after the key/value index.

    `batch_idx` and `head_idx` are accepted but not used; the result depends only on `q_idx` and `kv_idx`.
    """
    return q_idx >= kv_idx
594613
@@ -1533,6 +1552,7 @@ def _flash_attention_3(
15331552 softmax_scale = scale ,
15341553 causal = is_causal ,
15351554 )
1555+ out = _maybe_unflatten_attention_heads (out , query )
15361556 return (out , lse ) if return_lse else out
15371557
15381558
@@ -1577,7 +1597,9 @@ def _flash_attention_3_hub(
15771597 )
15781598 # When `return_attn_probs` is True, the above returns a tuple of
15791599 # actual outputs and lse.
1580- return (out [0 ], out [1 ]) if return_attn_probs else out
1600+ if return_attn_probs :
1601+ return (_maybe_unflatten_attention_heads (out [0 ], query ), out [1 ])
1602+ return _maybe_unflatten_attention_heads (out , query )
15811603
15821604
15831605@_AttentionBackendRegistry .register (
@@ -1630,6 +1652,7 @@ def _flash_varlen_attention_3_hub(
16301652 causal = is_causal ,
16311653 )
16321654 out = out .unflatten (0 , (batch_size , - 1 ))
1655+ out = _maybe_unflatten_attention_heads (out , query )
16331656
16341657 return (out , lse ) if return_lse else out
16351658
0 commit comments