[PyTorch][CP] Add THD format support for AllGather-based Context Parallelism#2829
sudhakarsingh27 wants to merge 46 commits into
… cu_seqlens
- Use per-step cu_seqlens_q_padded to select Q chunks instead of tensor slicing
- Use padded cu_seqlens_kv for K/V reordering (ensures divisibility)
- Add cu_seqlens_kv and cu_seqlens_kv_padded to AllGather function signature
- Compute per-step Q and KV cu_seqlens correctly from actual seqlens
- Support non-causal attention (all KV visible)
- Zero-initialize out/dq for THD to avoid garbage in padding regions
- Save per-step cu_seqlens in ctx for backward (avoid recomputation)
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
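The per-step offset arithmetic described in this commit can be illustrated with a small framework-free sketch. It is not the PR's code: the function name per_step_q_cu_seqlens is hypothetical, plain lists stand in for tensors, and it assumes the dual-chunk layout in which rank r holds chunks r and 2*cp_size-1-r of every padded sequence, stored back to back in the local Q tensor.

```python
def per_step_q_cu_seqlens(actual_seqlens, padded_seqlens, cp_size, rank):
    """Return [(cu_seqlens_q, cu_seqlens_q_padded)] for the two local Q chunks.

    Assumption: each padded sequence is split into 2*cp_size equal chunks and
    this rank stores chunk `rank` followed by chunk `2*cp_size - 1 - rank`.
    """
    num_chunks = 2 * cp_size
    chunk_ids = (rank, num_chunks - 1 - rank)
    steps = []
    for local_slot, chunk_id in enumerate(chunk_ids):
        cu_valid, cu_padded = [0], []
        local_start = 0  # start of this sequence in the local Q tensor
        for actual, padded in zip(actual_seqlens, padded_seqlens):
            assert padded % num_chunks == 0, "padded length must divide evenly"
            chunk = padded // num_chunks
            # Valid tokens of this chunk: whatever part lies before `actual`.
            valid = min(max(actual - chunk_id * chunk, 0), chunk)
            cu_valid.append(cu_valid[-1] + valid)
            # Offset of the chunk this step reads; non-zero for the second slot.
            cu_padded.append(local_start + local_slot * chunk)
            local_start += 2 * chunk
        cu_padded.append(local_start)
        steps.append((cu_valid, cu_padded))
    return steps


steps = per_step_q_cu_seqlens([10, 8], [12, 8], cp_size=2, rank=0)
```

With two sequences of actual lengths 10 and 8 (padded to 12 and 8), the second step's padded offsets begin at 3 rather than 0, which is why no slicing of Q is needed: the offsets alone tell the kernel where each chunk starts.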
Remove skip gates that blocked THD format with all_gather CP comm type. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…seqlens_q_padded The interleaved valid mask computation assumed cu_seqlens_q_padded starts at 0. With the CP offset-based approach, cu_seqlens_q_padded can start at a non-zero offset, causing a size mismatch. Use absolute positions from cu_seqlens_q_padded to build the valid mask instead. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
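A minimal sketch of the absolute-position idea, using plain Python lists in place of tensors. The helper name build_valid_mask is hypothetical; the point is only that each sequence's start is read from cu_seqlens_q_padded directly, so a non-zero first offset does not change the mask size.

```python
def build_valid_mask(cu_seqlens_valid, cu_seqlens_padded, total_tokens):
    """Mark valid tokens using absolute start offsets from cu_seqlens_padded."""
    mask = [False] * total_tokens
    batch = len(cu_seqlens_valid) - 1
    for b in range(batch):
        start = cu_seqlens_padded[b]  # absolute position, may be > 0 for b == 0
        length = cu_seqlens_valid[b + 1] - cu_seqlens_valid[b]
        for t in range(start, start + length):
            mask[t] = True
    return mask


# Padded offsets start at 3, not 0; the mask still covers all 10 tokens.
mask = build_valid_mask([0, 1, 3], [3, 8, 10], total_tokens=10)
valid_positions = [i for i, m in enumerate(mask) if m]
```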
1164a15 to
b4db9eb
Compare
for more information, see https://pre-commit.ci
| if qkv_format == "thd": | ||
| # [cp*t, h, d] -> reorder to contiguous per-sequence order -> [t_full, h, d] | ||
| chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) | ||
| k_ag = reorder_seq_chunks_after_a2a_before_attn_thd( |
reorder_seq_chunks_after_a2a_before_attn_thd and the other related method are no longer A2A-specific. Rename them to something like dualchunk_to_contiguous_order_thd and contiguous_to_dualchunk_order_thd.
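To make the two orderings concrete, here is a framework-free sketch of the token permutation between contiguous per-sequence order and the rank-concatenated dual-chunk order. The function names echo the ones suggested above but are illustrative only, and the chunk assignment (rank r holds chunks r and 2*cp_size-1-r of each sequence) is an assumption about the layout rather than something taken from the diff.

```python
def rank_sharded_permutation(cu_seqlens_padded, cp_size):
    """perm[i] = contiguous index of the token at slot i of the dual-chunk layout."""
    num_chunks = 2 * cp_size
    perm = []
    for rank in range(cp_size):
        for b in range(len(cu_seqlens_padded) - 1):
            start = cu_seqlens_padded[b]
            seqlen = cu_seqlens_padded[b + 1] - start
            assert seqlen % num_chunks == 0
            chunk = seqlen // num_chunks
            for chunk_id in (rank, num_chunks - 1 - rank):
                base = start + chunk_id * chunk
                perm.extend(range(base, base + chunk))
    return perm


def contiguous_to_dualchunk(tokens, perm):
    return [tokens[p] for p in perm]


def dualchunk_to_contiguous(tokens, perm):
    # Invert the permutation, then gather with the inverse.
    inverse = [0] * len(perm)
    for slot, contiguous_index in enumerate(perm):
        inverse[contiguous_index] = slot
    return [tokens[i] for i in inverse]


perm = rank_sharded_permutation([0, 4, 12], cp_size=2)
tokens = list("abcdefghijkl")
round_trip = dualchunk_to_contiguous(contiguous_to_dualchunk(tokens, perm), perm)
```

Neither direction depends on how the layout was produced (all-to-all or all-gather), which is the reviewer's argument for a backend-neutral name.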
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…formerEngine into cp_thd_swa_with_ag
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…formerEngine into cp_thd_swa_with_ag
Greptile Summary
This PR adds THD (variable-length sequence) format support to …
Confidence Score: 4/5
Safe to merge for non-THD workloads; THD+AllGather has unresolved P1 concerns from prior review rounds (dK/dV zeroing assumption, cache key missing CUDA device) that should be confirmed before enabling in production.
Several P1 findings from prior review rounds appear unresolved in the current diff (dK/dV kernel-zeroing assumption undocumented, _thd_reorder_perm_cache key missing device). The cp_stream.wait_stream fix was confirmed applied. New findings in this pass are P2 only (redundant out_f16 allocation, Python loop in max_logit masking). P1 ceiling applies.
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py: dK/dV backward zeroing assumption and permutation cache key require follow-up before production use of THD+AllGather.
Sequence Diagram
sequenceDiagram
participant Fwd as AllGather Forward
participant AG as gather_along_first_dim
participant Reorder as reorder_thd_to_contiguous
participant Stream0 as current_stream (step 0)
participant Stream1 as cp_stream (step 1)
participant cuDNN as fused_attn_fwd
participant Out as out [t,h,d] (zeros)
Fwd->>AG: AllGather K,V → k_ag [cp*t, h, d]
AG-->>Fwd: k_ag, v_ag (rank-concatenated)
Fwd->>Reorder: reorder k_ag,v_ag using cu_seqlens_kv_padded (P_inv)
Reorder-->>Fwd: k_ag, v_ag (contiguous sequence order)
Fwd->>Stream1: cp_stream.wait_stream(current_stream)
Note over Fwd: Pre-compute per-step cu_seqlens_q / cu_seqlens_q_padded / cu_seqlens_kv
par Step 0 on current_stream
Stream0->>cuDNN: fwd(q_full, k_ag, v_ag, cu_seqlens_q_padded_step0)
cuDNN-->>Stream0: out_per_step[0] [t,h,d]
and Step 1 on cp_stream
Stream1->>cuDNN: fwd(q_full, k_ag, v_ag, cu_seqlens_q_padded_step1)
cuDNN-->>Stream1: out_per_step[1] [t,h,d]
end
Stream0->>Out: copy valid ranges (step 0 offsets) → out[s0:s0+sz0]
Stream1->>Out: copy valid ranges (step 1 offsets) → out[s1:s1+sz1]
Fwd->>Fwd: current_stream.wait_stream(cp_stream)
Fwd->>Fwd: all_reduce max_logit across CP ranks
Reviews (7): Last reviewed commit: "Merge branch 'main' of github.com:NVIDIA..."
| # dK/dV: add full tensor (kernel zeros non-valid positions) | ||
| if i > 1: | ||
| flash_attn_streams[i - 1].wait_event(dkv_update_done) | ||
| dk.add_(dk_per_step[i - 1]) | ||
| dv.add_(dv_per_step[i - 1]) |
THD backward dK/dV relies on unverified cuDNN zeroing behavior
The comment says "kernel zeros non-valid positions", but this assumption is not documented in the cuDNN/TE spec. The A2A backward for THD uses tex.thd_grad_correction(dk, dk_, cu_seqlens_kv_padded, ...) specifically to handle the valid/padding boundary — a plain add_() was not considered sufficient there. If fused_attn_bwd leaves positions beyond cu_seqlens_kv_per_step[i] uninitialised in its output, both steps contribute garbage at non-overlapping KV ranges, which propagates through reduce_scatter_along_first_dim into the final dK/dV.
Before merging, either confirm (and document) that NVTE_F16_arbitrary_seqlen zeros non-valid dK/dV entries, or add explicit zeroing/use tex.thd_grad_correction if applicable to the contiguous-KV layout.
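One way to picture the "explicit zeroing" option is a masked accumulation that only touches valid positions. The sketch below is a plain-Python illustration with a hypothetical helper name; it is not tex.thd_grad_correction and makes no claim about what fused_attn_bwd actually writes into padding.

```python
def accumulate_valid_only(dk, dk_step, cu_seqlens_kv, cu_seqlens_kv_padded):
    """Add dk_step into dk at valid token positions only; padding is never read."""
    for b in range(len(cu_seqlens_kv) - 1):
        start = cu_seqlens_kv_padded[b]
        valid = cu_seqlens_kv[b + 1] - cu_seqlens_kv[b]
        for t in range(start, start + valid):
            dk[t] += dk_step[t]
    return dk


nan = float("nan")  # stands in for uninitialised memory in padding slots
dk = accumulate_valid_only(
    dk=[0.0] * 6,
    dk_step=[1.0, 1.0, nan, 2.0, nan, nan],
    cu_seqlens_kv=[0, 2, 3],
    cu_seqlens_kv_padded=[0, 3, 6],
)
```

A plain elementwise add over the same inputs would leave NaN in three of the six slots, which is the failure mode the comment above is worried about.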
| # [AG+THD] Is this needed? | ||
| visible_actual = [ | ||
| torch.minimum(actual_seqlens_kv, visible_padded_split) | ||
| for visible_padded_split in visible_padded | ||
| ] |
Unresolved development comment left in production code
# [AG+THD] Is this needed? reads like an open question from a debug session. The torch.minimum clamp is required: for sequences whose length is not a multiple of 2 * cp_size, padded_chunk_sizes_kv * (chunk_id + 1) can exceed actual_seqlens_kv[b], causing cu_seqlens_kv passed to the kernel to count padding as valid tokens. The comment should be resolved or removed.
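A small worked example of why the clamp matters, written with plain integers instead of tensors. The function name is hypothetical; the variable names follow the comment above.

```python
def visible_kv_lengths(actual_seqlens_kv, padded_chunk_sizes_kv, chunk_id):
    """Per-sequence count of KV tokens visible to a Q chunk, clamped to real tokens."""
    visible = []
    for actual, chunk in zip(actual_seqlens_kv, padded_chunk_sizes_kv):
        visible_padded = chunk * (chunk_id + 1)  # may run into padding
        visible.append(min(actual, visible_padded))
    return visible


# cp_size = 2, so 4 chunks. A sequence of 10 tokens is padded to 12 (chunk size 3).
last_chunk = visible_kv_lengths([10, 8], [3, 2], chunk_id=3)
first_chunk = visible_kv_lengths([10, 8], [3, 2], chunk_id=0)
unclamped = 3 * (3 + 1)
```

For the last chunk the unclamped value is 12, two more than the 10 real tokens, so without the minimum the kernel would be told that two padding tokens are valid.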
| if ctx.qkv_format == "thd": | ||
| cu_seqlens_kv_padded = ctx.cu_seqlens_kv_padded | ||
| thd_cu_seqlens_q_per_step = ctx.thd_cu_seqlens_q_per_step | ||
| cu_seqlens_q_padded_rank = cu_seqlens_q_padded * 2 |
…ific helpers The AllGather THD path was not extending KV visibility beyond the causal boundary when window_size had a right component > 0, meaning tokens right of the diagonal were invisible to the kernel. Fix by adding window_size[1] to visible_padded (clamped at actual seqlen) and max_seqlen_kv_. Also rename reorder helpers to backend-neutral names since AllGather now uses them too, and add a clarifying comment for non-causal KV cu_seqlens. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
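The visibility change in this commit amounts to one extra term before the clamp. The following is a sketch under stated assumptions: the function name is hypothetical, scalars replace tensors, and only a positive right window is handled (other window_size conventions are not modelled).

```python
def visible_kv_with_right_window(actual_seqlen_kv, padded_chunk_size, chunk_id, window_right):
    """KV tokens visible to a Q chunk when the window extends right of the diagonal."""
    causal_visible = padded_chunk_size * (chunk_id + 1)
    extended = causal_visible + max(window_right, 0)
    return min(actual_seqlen_kv, extended)  # never count past the real tokens


causal_only = visible_kv_with_right_window(100, 16, chunk_id=1, window_right=0)
with_window = visible_kv_with_right_window(100, 16, chunk_id=1, window_right=8)
clamped = visible_kv_with_right_window(36, 16, chunk_id=1, window_right=8)
```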
| pytest.skip("CP implementation with KV all-gather does not support THD format yet!") | ||
| pytest.skip( | ||
| "FlashAttention does not support THD padding; use FusedAttention for" | ||
| " THD+all_gather CP." |
Maybe swap the words a little bit so it doesn't sound like FlashAttention doesn't support THD, but rather that our CP implementation with it doesn't? (Also, THD implies padding in our terminology?)
| s = step_padded[b].item() | ||
| sz = (step_valid[b + 1] - step_valid[b]).item() | ||
| if sz > 0: | ||
| out[s : s + sz].copy_(out_per_step[i - 1][s : s + sz]) |
This "for" loop might be too costly. Could this logic be written in another way in Python, or simply in C++/CUDA? Do we have some thd kernels that already do this?
| s = step_padded[b].item() | ||
| sz = (step_valid[b + 1] - step_valid[b]).item() | ||
| if sz > 0: | ||
| dq[s : s + sz].copy_(dq_per_step[i - 1][s : s + sz]) |
Same here, regarding the "for" loop.
| pytest.skip("THD format does not support post_scale_bias yet!") | ||
| if qkv_format == "thd": | ||
| if cp_comm_type == "all_gather": | ||
| pytest.skip("CP implementation with KV all-gather does not support THD format yet!") |
A general comment - please run the CP file with "test_essential=False" offline because the essential tests may not cover everything.
…formerEngine into cp_thd_swa_with_ag
| elif qkv_format == "thd": | ||
| # Copy valid token ranges from this step's output. | ||
| # Each step writes at different positions (no overlap, no correction needed). | ||
| step_padded = thd_cu_seqlens_q_padded_per_step[i - 1] | ||
| step_valid = thd_cu_seqlens_q_per_step[i - 1] | ||
| batch_size = step_valid.shape[0] - 1 | ||
| for b in range(batch_size): | ||
| s = step_padded[b].item() | ||
| sz = (step_valid[b + 1] - step_valid[b]).item() | ||
| if sz > 0: | ||
| out[s : s + sz].copy_(out_per_step[i - 1][s : s + sz]) |
THD forward output copy runs on wrong stream for step 1
For step 1 (i == 2), the with torch.cuda.stream(flash_attn_streams[i - 1]) block streams the copy onto cp_stream. But out_per_step[1] is produced on cp_stream by the step-1 attention kernel, and the copy also runs on cp_stream, so there is no race between the kernel and the copy.
However, the if return_max_logit: block at line 3172–3173 runs outside the with block (i.e. on the default stream) and reads max_logit_per_step[i - 1] — which was produced on cp_stream for step 1. There is no current_stream.wait_stream(cp_stream) before this point (that sync only happens at line 3175, after the loop). As a result the torch.maximum kernel launched at line 3173 on the default stream can race against the step-1 attention kernel still running on cp_stream.
This is a pre-existing pattern, but it is newly exercised by THD+AllGather since this is the first path that has both return_max_logit=True and dual-stream THD step execution. Moving the max_logit merge inside the with torch.cuda.stream(flash_attn_streams[i - 1]): block would eliminate the race.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Add test parametrization for pad_between_seqs in flash attention tests. Update run_attention_with_cp.py to support the new parameter and fix batch boundary alignment in the non-CP FA3 path. Run tests in parallel when multiple GPUs are available. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Add deterministic CP test runs to L3 FA versions test. Support TE_PATH positional arg and fix GPU threshold for parallel test execution. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…raint The previous check disabled FA3 for deterministic mode whenever head_dim_qk > 128, which was overly conservative — FA3 forward supports deterministic execution at any head dim. The actual constraint from flash_api.cpp is that the backward pass does not support deterministic mode when max(head_size, head_size_v) >= 256. Narrow the gate to only disable FA3 during training (backward) and raise the threshold to >= 256, checking both head_dim_qk and head_dim_v to handle MLA configs with asymmetric head dimensions. Ref: https://github.com/Dao-AILab/flash-attention/blob/ac6f2eb5/hopper/flash_api.cpp#L1370 Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
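The narrowed gate described in this commit message reduces to a short predicate. This is a sketch of the stated rule only; the function name is hypothetical and the real check lives inside the backend selection logic.

```python
def fa3_allowed_in_deterministic_mode(is_training, head_dim_qk, head_dim_v):
    """FA3 forward is always fine; backward is excluded at head size >= 256."""
    if not is_training:
        return True  # inference never runs the backward pass
    # Check both dims so MLA configs with asymmetric head sizes are covered.
    return max(head_dim_qk, head_dim_v) < 256


inference_large = fa3_allowed_in_deterministic_mode(False, 512, 512)
training_mla = fa3_allowed_in_deterministic_mode(True, 192, 128)
training_large_v = fa3_allowed_in_deterministic_mode(True, 128, 256)
```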
The pad_between_seqs gate in get_attention_backend only disabled FlashAttention 2, letting FA4 leak through to the test-time fused-vs-flash comparison. On B200 runners that install flash-attn-4, this caused test_dpa_qkv_layout_thd to compare FusedAttention against an FA4 output whose padded positions contain garbage, producing 48 numerics failures in L3_pytorch_FA_versions_test--B200_1GPU. The log message already claimed FA4 would be disabled — this change makes the code match the message: set use_flash_attention_4 = False alongside use_flash_attention_2 when pad_between_seqs is True. FA3 continues to support pad_between_seqs via seqused_q/seqused_k. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…_attn_pad_bw_seqs
FA4 install brings in nvidia-cutlass-dsl, whose `import cutlass` adds cutlass/base_dsl/ to sys.path. That directory contains a utils/ package that shadows tests/pytorch/utils.py, breaking collection of test_attention_with_cp.py with: ImportError: cannot import name 'ModelConfig' from 'utils' Prepend $TE_PATH/tests/pytorch to PYTHONPATH so the local utils.py is always resolved first, regardless of what FA4 dependencies install. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
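The shadowing mechanism can be reproduced in isolation with the standard library. The sketch below builds two throwaway directories (their names are made up for the demo) and asks importlib's PathFinder which utils it would pick for a given search order; it does not touch the real cutlass or TE trees.

```python
import importlib.machinery
import os
import tempfile


def resolve_utils(search_path):
    """Return the file that `import utils` would load for this search order."""
    spec = importlib.machinery.PathFinder.find_spec("utils", search_path)
    return spec.origin


with tempfile.TemporaryDirectory() as root:
    third_party = os.path.join(root, "base_dsl")      # ships a utils/ package
    local_tests = os.path.join(root, "tests_pytorch")  # ships utils.py
    os.makedirs(os.path.join(third_party, "utils"))
    os.makedirs(local_tests)
    with open(os.path.join(third_party, "utils", "__init__.py"), "w") as f:
        f.write("")
    with open(os.path.join(local_tests, "utils.py"), "w") as f:
        f.write("ModelConfig = object\n")

    shadowed = resolve_utils([third_party, local_tests])
    fixed = resolve_utils([local_tests, third_party])
```

The first directory on the search path that contains a match wins, so putting the local tests directory first is enough to make the local utils.py resolve regardless of what else is installed.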
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…to cp_thd_swa_with_ag
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…_attn_pad_bw_seqs
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…to cp_thd_swa_with_ag
…s its a known cudnn issue Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…_attn_pad_bw_seqs
…ransformerEngine into flash_attn_pad_bw_seqs
…utation cache
The reorder_seq_chunks_{before,after}_a2a_*_thd functions used a Python
for-loop of torch.arange calls (2*cp_size*batch iterations) to build
index tensors, which dominated wall-clock time at high batch counts.
Replace with the existing thd_get_partitioned_indices CUDA kernel (one
call per CP rank) plus a permutation cache keyed on (cu_seqlens, cp_size).
This collapses thousands of tiny elementwise kernel launches into a
handful of kernel calls on first use, then a dict lookup thereafter.
Rename to reorder_thd_sequences_to_{rank_sharded,contiguous} since these
are used by both A2A and AllGather CP paths, not just A2A.
Measured speedups (cp=2, bf16, H100x2, 50 iters):
cp_thd_2 (B=16): a2a 10.6x, all_gather 4.5x
cp_thd_3 (B=8): a2a 5.3x, all_gather 3.3x
bariamis_8k (B=2): a2a 3.0x, all_gather 2.2x
bariamis_262k (B=2, S=262k): a2a 1.0x, all_gather 12.0x
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…nto cp_thd_swa_with_ag
The THD AllGather restructuring left cu_seqlens_kv_per_step[i] as None for the SBHD/BSHD path, causing a TypeError in fused_attn_fwd which expects a Tensor. Build it from batch_size and the KV slice length, matching what the original prepare_outputs helper provided. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
| v_ag = reorder_thd_sequences_to_contiguous( | ||
| v_ag, cu_seqlens_kv_padded, chunk_ids_for_kv_ag, cp_size | ||
| ) |
Missing cp_stream.wait_stream in THD AllGather forward
The non-THD path correctly calls cp_stream.wait_stream(torch.cuda.current_stream()) at line 3213 after preparing k_ag/v_ag, ensuring the cp_stream (used for step 1) sees the AllGathered and reordered tensors. The THD branch skips this synchronization, so when step 1 (i=1) launches on cp_stream, it may read partially-overwritten or stale k_ag/v_ag data that reorder_thd_sequences_to_contiguous hasn't finished writing on the current stream.
v_ag = reorder_thd_sequences_to_contiguous(
v_ag, cu_seqlens_kv_padded, chunk_ids_for_kv_ag, cp_size
)
+ cp_stream.wait_stream(torch.cuda.current_stream())
else:
The backward already handles this correctly: ctx.cp_stream.wait_stream(torch.cuda.current_stream()) at line 3782 is unconditional (outside the if ctx.qkv_format == "thd": block). The forward needs the same treatment.
Incorporated. Check again
The non-THD branch of AttnFuncWithCPAndKVAllGather.forward calls cp_stream.wait_stream(current_stream) after preparing k_ag/v_ag, but the THD branch (added by this PR) was missing that synchronization. Step 1 of the per-step attention loop launches on cp_stream and reads k_ag/v_ag, so without the wait it can race against the AllGather + reorder_thd_sequences_to_contiguous writes still in flight on the current stream.
The race is masked at low cp_size because the reorder finishes before cp_stream actually starts reading. At cp=8 with the largest THD config (bucket128k: B=3, S_max=131072) the reorder takes long enough to outlast cp_stream's launch, and the resulting partially written k_ag/v_ag feeds garbage values into cuDNN, surfacing asynchronously as cudaErrorIllegalInstruction.
Move the wait_stream call out of the non-THD else-branch so both paths get the synchronization unconditionally. Verified that bucket128k all_gather @ cp=8 now runs cleanly (360.72 ms/iter).
Identified by Greptile review (P1, discussion_r3170978192).
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…formerEngine into cp_thd_swa_with_ag
…_attn_pad_bw_seqs
…ransformerEngine into flash_attn_pad_bw_seqs
…nto cp_thd_swa_with_ag
| global _thd_reorder_perm_cache | ||
| key = (tuple(cu_seqlens.tolist()), cp_size) | ||
| if key not in _thd_reorder_perm_cache: | ||
| total_tokens = int(cu_seqlens[-1].item()) | ||
| P = torch.cat( | ||
| [ | ||
| tex.thd_get_partitioned_indices(cu_seqlens, total_tokens, cp_size, rank) | ||
| for rank in range(cp_size) | ||
| ] | ||
| ) | ||
| P_inv = torch.empty_like(P) | ||
| P_inv[P.long()] = torch.arange(total_tokens, dtype=P.dtype, device=P.device) | ||
| _thd_reorder_perm_cache[key] = (P, P_inv) | ||
| return _thd_reorder_perm_cache[key] |
Every other shape/rank cache in this file (e.g. _seq_chunk_ids_cache_for_reordering_before_attn, _seq_chunk_ids_cache_for_reordering_after_attn, _softmax_offset_chunk_ids_cache) includes the device in its key as (cp_size, device). _thd_reorder_perm_cache only uses (tuple(cu_seqlens.tolist()), cp_size).
In any context where two CUDA devices share the same cu_seqlens shape — multi-GPU unit tests with a single process, or pipeline-parallel+data-parallel setups — the first call populates the cache with tensors on device A; a subsequent call from device B hits the same key and gets those tensors back, then passes them to x.index_select(seq_dim, P_inv) where x lives on device B, producing a cross-device RuntimeError at runtime.
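A device-aware key is a small change. The sketch below uses a hypothetical cache and builder (plain Python, devices as strings) to show that adding the device to the key gives each device its own entry while repeated calls on the same device still hit the cache.

```python
_perm_cache = {}


def get_cached_perm(cu_seqlens, cp_size, device, build):
    """Look up a permutation, keyed on sequence layout, cp_size and device."""
    key = (tuple(cu_seqlens), cp_size, str(device))
    if key not in _perm_cache:
        _perm_cache[key] = build(cu_seqlens, cp_size, device)
    return _perm_cache[key]


build_calls = []


def fake_build(cu_seqlens, cp_size, device):
    # Stand-in for the real builder; records where the result would live.
    build_calls.append(device)
    return {"device": device, "perm": list(range(cu_seqlens[-1]))}


a = get_cached_perm([0, 4, 12], 2, "cuda:0", fake_build)
b = get_cached_perm([0, 4, 12], 2, "cuda:1", fake_build)
a_again = get_cached_perm([0, 4, 12], 2, "cuda:0", fake_build)
```

With the original two-element key, the second call would have returned the entry built for cuda:0.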
Description
Add THD (variable-length sequence) format support to AttnFuncWithCPAndKVAllGather. Previously, AllGather-based CP only supported fixed-length formats (bshd/sbhd). THD format packs variable-length sequences into a single [t, h, d] tensor tracked by cu_seqlens, which is needed for workloads with heterogeneous sequence lengths.
The key challenge is that AllGather CP splits Q across 2 steps (one per local chunk), but THD tensors can't be naively sliced like fixed-length formats. This PR uses an offset-based approach: the full Q tensor is passed to the cuDNN kernel each step, with per-step cu_seqlens_q_padded values directing the kernel to read the correct chunk. This avoids tensor slicing entirely and leverages cuDNN's back-padding convention (valid tokens at the beginning of each padded allocation).
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
- cu_seqlens_q_padded selects which chunk the kernel reads from the full Q tensor, instead of slicing Q per step
- reorder_seq_chunks_*_thd helpers (originally for A2A) to reorder all-gathered KV into contiguous per-sequence order
- cu_seqlens_q_padded in the valid-token mask (step 1's padded offsets don't start at 0)
Checklist: