Skip to content

Commit 54dec4f

Browse files
authored
[None][feat] enable GQA and cross-attention for attn2d (#14961)
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
1 parent 835fd61 commit 54dec4f

6 files changed

Lines changed: 310 additions & 49 deletions

File tree

tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py

Lines changed: 52 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,10 @@ class Attention2DAttention(AttentionBackend):
387387
-----------
388388
Ranks are arranged in a 2-D logical mesh of shape ``[row_size, col_size]``
389389
(total parallelism degree = ``P = row_size * col_size``). Each rank holds a
390-
``[B, S/P, H, D]`` shard of Q, K, and V.
390+
``[B, S_q/P, H_q, D]`` shard of Q and ``[B, S_kv/P, H_kv, D]`` shards of K and V.
391+
For self-attention ``S_q = S_kv`` and ``H_q = H_kv``; for GQA ``H_kv < H_q``; for
392+
cross-attention ``S_kv`` may differ from ``S_q``. K/V must be sequence-sharded
393+
across the same mesh as Q (not replicated on every rank).
391394
392395
Example for ``row_size=2, col_size=3`` (6 ranks total)::
393396
@@ -401,19 +404,22 @@ class Attention2DAttention(AttentionBackend):
401404
Ranks in the same **column** share a ``col_process_group`` and all-gather K/V.
402405
403406
Architecture:
404-
Input: [B, S/P, H, D] (sequence sharded across P = row_size × col_size ranks)
405-
Step 1: Q all-gather within row group: [B, S/P, H, D] → [B, S/col_size, H, D]
406-
Step 2: K/V fused all-gather within col group [B, S/P, H, D] → [B, S/row_size, H, D]
407-
(K and V packed into [2, B, S/P, H, D] before the gather,
407+
Input: Q [B, S_q/P, H_q, D], K/V [B, S_kv/P, H_kv, D]
408+
(sequence sharded across P = row_size × col_size ranks)
409+
Step 1: Q all-gather within row group:
410+
[B, S_q/P, H_q, D] → [B, S_q/row_size, H_q, D]
411+
Step 2: K/V fused all-gather within col group:
412+
[B, S_kv/P, H_kv, D] → [B, S_kv/col_size, H_kv, D]
413+
(K and V packed into [2, B, S_kv/P, H_kv, D] before the gather,
408414
halving NCCL launch overhead vs. two separate collectives)
409415
Step 3: Local attention with inner backend:
410-
Q [B, S/col_size, H, D] × K,V [B, S/row_size, H, D]
411-
→ output [B, S/col_size, H, D] + LSE [B, H, S/col_size]
416+
Q [B, S_q/row_size, H_q, D] × K,V [B, S_kv/col_size, H_kv, D]
417+
→ output [B, S_q/row_size, H_q, D] + LSE [B, H_q, S_q/row_size]
412418
Step 4: Reduce-scatter output within row group, split into:
413419
all_to_all_single to exchange partial outputs and LSEs, then
414420
LSE-weighted combine via flash_attn_combine
415-
→ [B, S/P, H, D] (fully reduced, matching input layout)
416-
Output: [B, S/P, H, D]
421+
→ [B, S_q/P, H_q, D] (fully reduced, matching input Q layout)
422+
Output: [B, S_q/P, H_q, D]
417423
418424
Supported inner backends
419425
------------------------
@@ -432,6 +438,10 @@ class Attention2DAttention(AttentionBackend):
432438
Constraints
433439
-----------
434440
* Only ``PredefinedAttentionMask.FULL`` (or ``None``) is supported.
441+
* Global ``S_q`` and ``S_kv`` must each be divisible by ``P = row_size × col_size``
442+
so every rank holds an equal local shard.
443+
* Cross-attention requires K/V to be sequence-sharded across the mesh (same as Q),
444+
not replicated on every rank.
435445
* ``flash_attn_combine`` (JIT CUDA kernel) must be importable at
436446
construction time; the constructor raises ``ImportError`` otherwise.
437447
* The ``_combine`` step is wrapped in ``@torch.compiler.disable`` because
@@ -478,6 +488,7 @@ def __init__(
478488
)
479489
self.head_dim = inner_backend.head_dim
480490
self.num_heads = inner_backend.num_heads
491+
self.num_kv_heads = getattr(inner_backend, "num_kv_heads", self.num_heads)
481492
self._inner_layout = inner_backend.preferred_layout
482493
if self._inner_layout not in (AttentionTensorLayout.NHD, AttentionTensorLayout.HND):
483494
raise NotImplementedError(
@@ -494,44 +505,66 @@ def forward(
494505
"""
495506
Forward pass with Attention2D sequence parallelism.
496507
497-
q/k/v: [B, S/P, H, D] each.
508+
q: [B, S_q/P, H_q, D]. k/v: [B, S_kv/P, H_kv, D].
498509
"""
499-
B, shard_seq, H, D = q.shape
510+
B, shard_seq_q, H_q, D = q.shape
511+
_, shard_seq_kv, H_kv, D_kv = k.shape
500512
attention_mask = kwargs.get("attention_mask", None)
501513

514+
if D_kv != D:
515+
raise ValueError(
516+
f"Attention2DAttention: q head_dim ({D}) must match k head_dim ({D_kv})."
517+
)
518+
if v.shape != k.shape:
519+
raise ValueError(
520+
f"Attention2DAttention: k and v shapes must match, got k={k.shape}, v={v.shape}."
521+
)
522+
if H_q != self.num_heads:
523+
raise ValueError(
524+
f"Attention2DAttention: q num_heads ({H_q}) must match "
525+
f"inner backend num_heads ({self.num_heads})."
526+
)
527+
if H_kv != self.num_kv_heads:
528+
raise ValueError(
529+
f"Attention2DAttention: k num_kv_heads ({H_kv}) must match "
530+
f"inner backend num_kv_heads ({self.num_kv_heads})."
531+
)
532+
502533
if attention_mask is not None and attention_mask != PredefinedAttentionMask.FULL:
503534
raise ValueError(
504535
f"Attention2DAttention only supports FULL attention mask, got {attention_mask}."
505536
)
506537

507538
if self.row_group_size > 1:
508539
# All-gather q within row_process_group using a single flat buffer.
509-
# [B, S/P, H, D] → [row_group_size, B, S/P, H, D] → [B, S/col_group_size, H, D]
510-
q_recv = q.new_empty(self.row_group_size, B, shard_seq, H, D)
540+
# [B, S_q/P, H_q, D] → [row_group_size, B, S_q/P, H_q, D]
541+
# → [B, S_q/row_size, H_q, D]
542+
q_recv = q.new_empty(self.row_group_size, B, shard_seq_q, H_q, D)
511543
torch.distributed.all_gather_into_tensor(
512544
q_recv.view(-1), q.contiguous().view(-1), group=self.row_process_group
513545
)
514-
q = q_recv.permute(1, 0, 2, 3, 4).reshape(B, self.row_group_size * shard_seq, H, D)
546+
q = q_recv.permute(1, 0, 2, 3, 4).reshape(B, self.row_group_size * shard_seq_q, H_q, D)
515547

516548
if self.col_group_size > 1:
517549
# Fuse K and V into a single all-gather to reduce NCCL launch overhead.
518-
# [2, B, S/P, H, D] → [col_group_size, 2, B, S/P, H, D] → split back to K, V
519-
kv_send = k.new_empty(2, B, shard_seq, H, D)
550+
# [2, B, S_kv/P, H_kv, D] → [col_group_size, 2, B, S_kv/P, H_kv, D]
551+
# → [B, S_kv/col_size, H_kv, D]
552+
kv_send = k.new_empty(2, B, shard_seq_kv, H_kv, D)
520553
kv_send[0].copy_(k)
521554
kv_send[1].copy_(v)
522-
kv_recv = k.new_empty(self.col_group_size, 2, B, shard_seq, H, D)
555+
kv_recv = k.new_empty(self.col_group_size, 2, B, shard_seq_kv, H_kv, D)
523556
torch.distributed.all_gather_into_tensor(
524557
kv_recv.view(-1), kv_send.view(-1), group=self.col_process_group
525558
)
526559
k = (
527560
kv_recv[:, 0]
528561
.permute(1, 0, 2, 3, 4)
529-
.reshape(B, self.col_group_size * shard_seq, H, D)
562+
.reshape(B, self.col_group_size * shard_seq_kv, H_kv, D)
530563
)
531564
v = (
532565
kv_recv[:, 1]
533566
.permute(1, 0, 2, 3, 4)
534-
.reshape(B, self.col_group_size * shard_seq, H, D)
567+
.reshape(B, self.col_group_size * shard_seq_kv, H_kv, D)
535568
)
536569

537570
seq_len = q.shape[1]

tensorrt_llm/_torch/visual_gen/models/cosmos3/transformer_cosmos3.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -679,7 +679,7 @@ def __init__(self, model_config: DiffusionModelConfig):
679679
)
680680
tp_size = vgm.tp_size if vgm else 1
681681
ulysses_size = vgm.ulysses_size if vgm else 1
682-
cp_size = vgm.cp_size if vgm else 1
682+
ring_size = vgm.ring_size if vgm else 1
683683
head_divisibility_factor = tp_size * ulysses_size
684684

685685
if (ulysses_size > 1 or tp_size > 1) and (
@@ -692,10 +692,11 @@ def __init__(self, model_config: DiffusionModelConfig):
692692
f"TP * Ulysses size ({tp_size} * {ulysses_size})"
693693
)
694694

695-
if cp_size > 1:
696-
# Context parallelism is not compatible with Cosmos3 cross-attention: its forward()
697-
# TODO: Re-enable once Ring/Attn2D PRs with cross-attention support have landed.
698-
raise NotImplementedError("Context parallelism is not supported for Cosmos3. ")
695+
if ring_size > 1:
696+
# Ring parallelism is not compatible with Cosmos3 cross-attention.
697+
raise NotImplementedError(
698+
"Ring parallelism is not supported for Cosmos3 cross-attention."
699+
)
699700

700701
self.language_model = Cosmos3LanguageModel(model_config)
701702

tensorrt_llm/_torch/visual_gen/modules/attention.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -230,12 +230,10 @@ def __init__(
230230

231231
if enable_sequence_parallel and self.qkv_mode == QKVMode.SEPARATE_QKV and vgm is not None:
232232
ring_size = vgm.ring_size
233-
attn2d_size = vgm.attn2d_row_size * vgm.attn2d_col_size
234-
if ring_size > 1 or attn2d_size > 1:
233+
if ring_size > 1:
235234
raise ValueError(
236-
"SEPARATE_QKV cross-attention does not support Ring or Attention2D "
237-
"sequence parallelism; use enable_sequence_parallel=False or Ulysses-only "
238-
f"(ring_size={ring_size}, attn2d_size={attn2d_size})."
235+
"SEPARATE_QKV cross-attention does not support Ring sequence "
236+
"parallelism; use enable_sequence_parallel=False or Ulysses/Attention2D."
239237
)
240238

241239
self.attn = wrap_parallel_attention(

tests/unittest/_torch/visual_gen/multi_gpu/test_attn2d_attention.py

Lines changed: 123 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,24 @@ class _LSEVanillaAttention(nn.Module):
7070
values are available, as required by Attention2DAttention.
7171
"""
7272

73-
def __init__(self, num_heads: int, head_dim: int):
73+
def __init__(self, num_heads: int, head_dim: int, num_kv_heads: int | None = None):
7474
super().__init__()
7575
self.num_heads = num_heads
76+
self.num_kv_heads = num_kv_heads or num_heads
7677
self.head_dim = head_dim
7778
self.scale = 1.0 / math.sqrt(head_dim)
7879
self._preferred_layout = AttentionTensorLayout.NHD
7980

81+
def _expand_kv_heads(
82+
self, k_t: torch.Tensor, v_t: torch.Tensor
83+
) -> tuple[torch.Tensor, torch.Tensor]:
84+
if self.num_heads == self.num_kv_heads:
85+
return k_t, v_t
86+
repeat_factor = self.num_heads // self.num_kv_heads
87+
k_t = k_t.repeat_interleave(repeat_factor, dim=1)
88+
v_t = v_t.repeat_interleave(repeat_factor, dim=1)
89+
return k_t, v_t
90+
8091
@property
8192
def preferred_layout(self) -> AttentionTensorLayout:
8293
return self._preferred_layout
@@ -93,14 +104,18 @@ def forward(self, q, k, v, batch_size=None, seq_len=None, **kwargs):
93104
q_t = q.transpose(1, 2).float()
94105
k_t = k.transpose(1, 2).float()
95106
v_t = v.transpose(1, 2).float()
96-
out = F.scaled_dot_product_attention(q_t, k_t, v_t, scale=self.scale)
107+
k_t, v_t = self._expand_kv_heads(k_t, v_t)
108+
out = F.scaled_dot_product_attention(
109+
q_t, k_t, v_t, scale=self.scale, enable_gqa=self.num_heads != self.num_kv_heads
110+
)
97111
return out.to(q.dtype).transpose(1, 2).contiguous()
98112

99113
def forward_with_lse(self, q, k, v, batch_size=None, seq_len=None, **kwargs):
100114
"""Return (output [B, S, H, D], lse [B, H, S])."""
101-
q_t = q.transpose(1, 2).float() # [B, H, S_q, D]
102-
k_t = k.transpose(1, 2).float() # [B, H, S_k, D]
103-
v_t = v.transpose(1, 2).float() # [B, H, S_k, D]
115+
q_t = q.transpose(1, 2).float() # [B, H_q, S_q, D]
116+
k_t = k.transpose(1, 2).float() # [B, H_kv, S_k, D]
117+
v_t = v.transpose(1, 2).float() # [B, H_kv, S_k, D]
118+
k_t, v_t = self._expand_kv_heads(k_t, v_t)
104119
scores = torch.matmul(q_t, k_t.transpose(-2, -1)) * self.scale # [B, H, S_q, S_k]
105120
lse = torch.logsumexp(scores, dim=-1) # [B, H, S_q]
106121
attn = torch.softmax(scores, dim=-1)
@@ -385,6 +400,97 @@ def _logic_attn2d_asymmetric_mesh_4x1(rank, world_size):
385400
)
386401

387402

403+
def _logic_attn2d_gqa(rank, world_size):
404+
"""GQA (H_kv < H_q) with equal Q/KV sequence lengths on a 2x2 mesh."""
405+
row_size, col_size = 2, 2
406+
batch, num_heads, num_kv_heads, head_dim = 2, 8, 2, 64
407+
seq_per_rank = 8
408+
seq_full = seq_per_rank * world_size
409+
410+
device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
411+
row_pg, col_pg = _make_process_groups(rank, world_size, row_size, col_size)
412+
413+
inner = _LSEVanillaAttention(num_heads=num_heads, head_dim=head_dim, num_kv_heads=num_kv_heads)
414+
try:
415+
attn = Attention2DAttention(inner, row_pg, col_pg)
416+
except ImportError:
417+
pytest.skip("flash_attn_combine JIT kernels not available")
418+
419+
torch.manual_seed(42)
420+
q_full = torch.randn(batch, seq_full, num_heads, head_dim, device=device)
421+
k_full = torch.randn(batch, seq_full, num_kv_heads, head_dim, device=device)
422+
v_full = torch.randn(batch, seq_full, num_kv_heads, head_dim, device=device)
423+
424+
q_shard = q_full[:, rank * seq_per_rank : (rank + 1) * seq_per_rank].contiguous()
425+
k_shard = k_full[:, rank * seq_per_rank : (rank + 1) * seq_per_rank].contiguous()
426+
v_shard = v_full[:, rank * seq_per_rank : (rank + 1) * seq_per_rank].contiguous()
427+
428+
attn2d_output = attn(q_shard, k_shard, v_shard, batch_size=batch)
429+
430+
scale = 1.0 / math.sqrt(head_dim)
431+
q_std = q_full.transpose(1, 2).float()
432+
k_std = k_full.transpose(1, 2).float()
433+
v_std = v_full.transpose(1, 2).float()
434+
std_output = F.scaled_dot_product_attention(q_std, k_std, v_std, scale=scale, enable_gqa=True)
435+
std_output = std_output.transpose(1, 2).to(attn2d_output.dtype)
436+
437+
expected_shard = std_output[:, rank * seq_per_rank : (rank + 1) * seq_per_rank]
438+
torch.testing.assert_close(
439+
attn2d_output,
440+
expected_shard,
441+
rtol=1e-3,
442+
atol=1e-3,
443+
msg=f"Rank {rank}: Attention2D GQA output differs from standard attention",
444+
)
445+
446+
447+
def _logic_attn2d_cross_attention(rank, world_size):
448+
"""Cross-attention with different Q/KV lengths and GQA on a 2x2 mesh."""
449+
row_size, col_size = 2, 2
450+
batch, num_heads, num_kv_heads, head_dim = 2, 8, 2, 64
451+
seq_per_rank_q = 8
452+
seq_per_rank_kv = 4
453+
seq_full_q = seq_per_rank_q * world_size
454+
seq_full_kv = seq_per_rank_kv * world_size
455+
456+
device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
457+
row_pg, col_pg = _make_process_groups(rank, world_size, row_size, col_size)
458+
459+
inner = _LSEVanillaAttention(num_heads=num_heads, head_dim=head_dim, num_kv_heads=num_kv_heads)
460+
try:
461+
attn = Attention2DAttention(inner, row_pg, col_pg)
462+
except ImportError:
463+
pytest.skip("flash_attn_combine JIT kernels not available")
464+
465+
torch.manual_seed(42)
466+
q_full = torch.randn(batch, seq_full_q, num_heads, head_dim, device=device)
467+
k_full = torch.randn(batch, seq_full_kv, num_kv_heads, head_dim, device=device)
468+
v_full = torch.randn(batch, seq_full_kv, num_kv_heads, head_dim, device=device)
469+
470+
q_shard = q_full[:, rank * seq_per_rank_q : (rank + 1) * seq_per_rank_q].contiguous()
471+
k_shard = k_full[:, rank * seq_per_rank_kv : (rank + 1) * seq_per_rank_kv].contiguous()
472+
v_shard = v_full[:, rank * seq_per_rank_kv : (rank + 1) * seq_per_rank_kv].contiguous()
473+
474+
attn2d_output = attn(q_shard, k_shard, v_shard, batch_size=batch)
475+
assert attn2d_output.shape == q_shard.shape
476+
477+
scale = 1.0 / math.sqrt(head_dim)
478+
q_std = q_full.transpose(1, 2).float()
479+
k_std = k_full.transpose(1, 2).float()
480+
v_std = v_full.transpose(1, 2).float()
481+
std_output = F.scaled_dot_product_attention(q_std, k_std, v_std, scale=scale, enable_gqa=True)
482+
std_output = std_output.transpose(1, 2).to(attn2d_output.dtype)
483+
484+
expected_shard = std_output[:, rank * seq_per_rank_q : (rank + 1) * seq_per_rank_q]
485+
torch.testing.assert_close(
486+
attn2d_output,
487+
expected_shard,
488+
rtol=1e-3,
489+
atol=1e-3,
490+
msg=f"Rank {rank}: Attention2D cross-attention output differs from standard attention",
491+
)
492+
493+
388494
# =============================================================================
389495
# Test classes
390496
# =============================================================================
@@ -422,6 +528,18 @@ def test_attn2d_4x1_mesh(self):
422528
)
423529

424530

531+
class TestAttn2DAttentionGQAAndCrossAttention:
532+
"""Attention2DAttention with GQA and cross-attention."""
533+
534+
def test_attn2d_gqa(self):
535+
"""GQA with H_kv < H_q on a 2x2 mesh."""
536+
run_test_in_distributed(world_size=4, test_fn=_logic_attn2d_gqa, use_cuda=True)
537+
538+
def test_attn2d_cross_attention(self):
539+
"""Cross-attention with different Q/KV lengths and GQA on a 2x2 mesh."""
540+
run_test_in_distributed(world_size=4, test_fn=_logic_attn2d_cross_attention, use_cuda=True)
541+
542+
425543
def _logic_attn2d_fa4_vs_standard(rank, world_size):
426544
"""Attention2DAttention with FlashAttn4 inner backend matches standard SDPA (2x2 mesh).
427545

0 commit comments

Comments
 (0)