@@ -387,7 +387,10 @@ class Attention2DAttention(AttentionBackend):
387387 -----------
388388 Ranks are arranged in a 2-D logical mesh of shape ``[row_size, col_size]``
389389 (total parallelism degree = ``P = row_size * col_size``). Each rank holds a
390- ``[B, S/P, H, D]`` shard of Q, K, and V.
390+ ``[B, S_q/P, H_q, D]`` shard of Q and ``[B, S_kv/P, H_kv, D]`` shards of K and V.
391+ For self-attention ``S_q = S_kv`` and ``H_q = H_kv``; for GQA ``H_kv < H_q``; for
392+ cross-attention ``S_kv`` may differ from ``S_q``. K/V must be sequence-sharded
393+ across the same mesh as Q (not replicated on every rank).
391394
392395 Example for ``row_size=2, col_size=3`` (6 ranks total)::
393396
@@ -401,19 +404,22 @@ class Attention2DAttention(AttentionBackend):
401404 Ranks in the same **column** share a ``col_process_group`` and all-gather K/V.
402405
403406 Architecture:
404- Input: [B, S/P, H, D] (sequence sharded across P = row_size × col_size ranks)
405- Step 1: Q all-gather within row group: [B, S/P, H, D] → [B, S/col_size, H, D]
406- Step 2: K/V fused all-gather within col group [B, S/P, H, D] → [B, S/row_size, H, D]
407- (K and V packed into [2, B, S/P, H, D] before the gather,
407+ Input: Q [B, S_q/P, H_q, D], K/V [B, S_kv/P, H_kv, D]
408+ (sequence sharded across P = row_size × col_size ranks)
409+ Step 1: Q all-gather within row group:
410+ [B, S_q/P, H_q, D] → [B, S_q/row_size, H_q, D]
411+ Step 2: K/V fused all-gather within col group:
412+ [B, S_kv/P, H_kv, D] → [B, S_kv/col_size, H_kv, D]
413+ (K and V packed into [2, B, S_kv/P, H_kv, D] before the gather,
408414 halving NCCL launch overhead vs. two separate collectives)
409415 Step 3: Local attention with inner backend:
410- Q [B, S/col_size, H , D] × K,V [B, S/row_size, H , D]
411- → output [B, S/col_size, H , D] + LSE [B, H, S/col_size ]
416+ Q [B, S_q/row_size, H_q , D] × K,V [B, S_kv/col_size, H_kv , D]
417+ → output [B, S_q/row_size, H_q , D] + LSE [B, H_q, S_q/row_size ]
412418 Step 4: Reduce-scatter output within row group, split into:
413419 all_to_all_single to exchange partial outputs and LSEs, then
414420 LSE-weighted combine via flash_attn_combine
415- → [B, S /P, H , D] (fully reduced, matching input layout)
416- Output: [B, S /P, H , D]
421+ → [B, S_q /P, H_q , D] (fully reduced, matching input Q layout)
422+ Output: [B, S_q /P, H_q , D]
417423
418424 Supported inner backends
419425 ------------------------
@@ -432,6 +438,10 @@ class Attention2DAttention(AttentionBackend):
432438 Constraints
433439 -----------
434440 * Only ``PredefinedAttentionMask.FULL`` (or ``None``) is supported.
441+ * Global ``S_q`` and ``S_kv`` must each be divisible by ``P = row_size × col_size``
442+ so every rank holds an equal local shard.
443+ * Cross-attention requires K/V to be sequence-sharded across the mesh (same as Q),
444+ not replicated on every rank.
435445 * ``flash_attn_combine`` (JIT CUDA kernel) must be importable at
436446 construction time; the constructor raises ``ImportError`` otherwise.
437447 * The ``_combine`` step is wrapped in ``@torch.compiler.disable`` because
@@ -478,6 +488,7 @@ def __init__(
478488 )
479489 self .head_dim = inner_backend .head_dim
480490 self .num_heads = inner_backend .num_heads
491+ self .num_kv_heads = getattr (inner_backend , "num_kv_heads" , self .num_heads )
481492 self ._inner_layout = inner_backend .preferred_layout
482493 if self ._inner_layout not in (AttentionTensorLayout .NHD , AttentionTensorLayout .HND ):
483494 raise NotImplementedError (
@@ -494,44 +505,66 @@ def forward(
494505 """
495506 Forward pass with Attention2D sequence parallelism.
496507
497- q/ k/v: [B, S /P, H , D] each .
508+ q: [B, S_q/P, H_q, D]. k/v: [B, S_kv /P, H_kv , D].
498509 """
499- B , shard_seq , H , D = q .shape
510+ B , shard_seq_q , H_q , D = q .shape
511+ _ , shard_seq_kv , H_kv , D_kv = k .shape
500512 attention_mask = kwargs .get ("attention_mask" , None )
501513
514+ if D_kv != D :
515+ raise ValueError (
516+ f"Attention2DAttention: q head_dim ({ D } ) must match k head_dim ({ D_kv } )."
517+ )
518+ if v .shape != k .shape :
519+ raise ValueError (
520+ f"Attention2DAttention: k and v shapes must match, got k={ k .shape } , v={ v .shape } ."
521+ )
522+ if H_q != self .num_heads :
523+ raise ValueError (
524+ f"Attention2DAttention: q num_heads ({ H_q } ) must match "
525+ f"inner backend num_heads ({ self .num_heads } )."
526+ )
527+ if H_kv != self .num_kv_heads :
528+ raise ValueError (
529+ f"Attention2DAttention: k num_kv_heads ({ H_kv } ) must match "
530+ f"inner backend num_kv_heads ({ self .num_kv_heads } )."
531+ )
532+
502533 if attention_mask is not None and attention_mask != PredefinedAttentionMask .FULL :
503534 raise ValueError (
504535 f"Attention2DAttention only supports FULL attention mask, got { attention_mask } ."
505536 )
506537
507538 if self .row_group_size > 1 :
508539 # All-gather q within row_process_group using a single flat buffer.
509- # [B, S/P, H, D] → [row_group_size, B, S/P, H, D] → [B, S/col_group_size, H, D]
510- q_recv = q .new_empty (self .row_group_size , B , shard_seq , H , D )
540+ # [B, S_q/P, H_q, D] → [row_group_size, B, S_q/P, H_q, D]
541+ # → [B, S_q/row_size, H_q, D]
542+ q_recv = q .new_empty (self .row_group_size , B , shard_seq_q , H_q , D )
511543 torch .distributed .all_gather_into_tensor (
512544 q_recv .view (- 1 ), q .contiguous ().view (- 1 ), group = self .row_process_group
513545 )
514- q = q_recv .permute (1 , 0 , 2 , 3 , 4 ).reshape (B , self .row_group_size * shard_seq , H , D )
546+ q = q_recv .permute (1 , 0 , 2 , 3 , 4 ).reshape (B , self .row_group_size * shard_seq_q , H_q , D )
515547
516548 if self .col_group_size > 1 :
517549 # Fuse K and V into a single all-gather to reduce NCCL launch overhead.
518- # [2, B, S/P, H, D] → [col_group_size, 2, B, S/P, H, D] → split back to K, V
519- kv_send = k .new_empty (2 , B , shard_seq , H , D )
550+ # [2, B, S_kv/P, H_kv, D] → [col_group_size, 2, B, S_kv/P, H_kv, D]
551+ # → [B, S_kv/col_size, H_kv, D]
552+ kv_send = k .new_empty (2 , B , shard_seq_kv , H_kv , D )
520553 kv_send [0 ].copy_ (k )
521554 kv_send [1 ].copy_ (v )
522- kv_recv = k .new_empty (self .col_group_size , 2 , B , shard_seq , H , D )
555+ kv_recv = k .new_empty (self .col_group_size , 2 , B , shard_seq_kv , H_kv , D )
523556 torch .distributed .all_gather_into_tensor (
524557 kv_recv .view (- 1 ), kv_send .view (- 1 ), group = self .col_process_group
525558 )
526559 k = (
527560 kv_recv [:, 0 ]
528561 .permute (1 , 0 , 2 , 3 , 4 )
529- .reshape (B , self .col_group_size * shard_seq , H , D )
562+ .reshape (B , self .col_group_size * shard_seq_kv , H_kv , D )
530563 )
531564 v = (
532565 kv_recv [:, 1 ]
533566 .permute (1 , 0 , 2 , 3 , 4 )
534- .reshape (B , self .col_group_size * shard_seq , H , D )
567+ .reshape (B , self .col_group_size * shard_seq_kv , H_kv , D )
535568 )
536569
537570 seq_len = q .shape [1 ]
0 commit comments