Commit 86ade9e

CP Tests batching using subprocess worker pool (#2993)
* Batch CP attention tests via a persistent NCCL pool

The existing test path spawns one torchrun per parametrized case, paying NCCL init + CUDA context + Python startup on every call. With ~hundreds of cases the launch overhead dominates wall time and was a primary driver of the L3 timeout that prior batching PRs worked around.

This change replaces the per-case subprocess with one long-lived torchrun per (world_size). NCCL is initialized once at session start and reused across cases. Pytest sends one JSON request per case over rank-0 stdin; the worker dispatches to run_dpa_with_cp(**kwargs), gathers (ok, error) from every rank, and writes one JSON response on rank-0 stdout.

run_attention_with_cp.py is left almost untouched; a new NVTE_CP_POOL_PG=1 env var gates the dist.init_process_group() and dist.destroy_process_group() calls so the function reuses the pool's main PG instead of creating its own. The per-case cp_comm_group (and a2a+p2p sub-groups) are explicitly destroyed at function exit to prevent communicator leakage across cases.

The PoolWorker class adds two pieces of error recovery that the prior subprocess-per-case design got for free: a select-based per-call timeout (default 600s, NVTE_CP_POOL_TIMEOUT_SEC) and auto-respawn on worker death or timeout. A test-level exception is reported as an AssertionError and the pool keeps running for the next case.

Two pool sizes are needed because cp_comm_type='a2a+p2p' requires world_size=4 and the others use world_size=2; you can't resize an active PG. Pools are spawned lazily so a 2-GPU-only run never pays the 4-GPU init.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Reset FP8 state and barrier between pool cases

Two resilience fixes carried over from the existing batching PR (sudhakars/cp_test_batching_pr) without which the pool will cascade-fail FP8 tests and silently propagate NCCL desync.

1. FP8GlobalStateManager.reset() between cases.
   FP8 quantizer state (recipe handles, autocast counters) lives in module-level globals. Reusing one Python process across cases otherwise carries that state forward. The prior batching PR landed an explicit fix for the same issue ("Fix FP8 cascade failures") after observing real test failures from this.

2. dist.barrier() after each case.
   If one rank's case errored before its last collective, the others can be stuck waiting on a comm that will never complete. The barrier here surfaces that immediately as a timeout in this case rather than letting the corruption leak into the next case's collectives.

Also pops the transient NVTE_* env vars run_dpa_with_cp sets at the top of each call. run_dpa_with_cp already sets them unconditionally so this is defensive, but cheap insurance against future variants that might not.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Deep-copy ModelConfig in run_dpa_with_cp

The model_configs_{flash,fused}_attn dicts are module-level and shared across pool cases. The THD branch below rewrites config.attn_mask_type in place (causal -> padding_causal, no_mask -> padding). With the persistent-pool runner, the next case looking up the same model key gets the mutated config and fails the "causal or no_mask only" assert.

Caught at benchmark time on cp_2_0 + thd, identical to the cascade the existing batching PR (sudhakars/cp_test_batching_pr) hit and fixed the same way in commit 6355f62.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Skip deterministic configs incompatible with FusedAttention

Mirrors the two pre-emptive skips on the PR-batching branch:

  * non-vanilla softmax with FusedAttention is not deterministic
  * post_scale_bias with requires_grad is not deterministic

Without these skips, the corresponding configs propagate into the pool worker under NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 and fail inside run_dpa_with_cp instead of being marked SKIPPED.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Reseed RNG between pool cases; reset before, not after

The pool worker reused RNG state across cases, which produced small numerical drift on some non-FP8 fused-attention configs (cp_1_0 + thd/p2p, cp_1_0 + sbhd/all_gather) compared to the single-shot worker.

Matches the per-case startup of the single-shot worker: torch.manual_seed(1234) + torch.cuda.manual_seed(1234) at the start of every case, alongside the existing FP8 / env / cache resets.

Moved the reset call from the post-case finally block to the start of _run_one so the first case is also seeded consistently with subsequent cases. Otherwise the first case would inherit the process-default RNG and only the second-and-later cases would be deterministic.

Validated locally: 38 passed, 0 failed (was 36 passed, 2 failed).

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Robustify pool: capture worker stderr, tighten timeout, add timing knob

Three changes that bring the pool's failure semantics on par with the per-batch torchrun approach in PR #2965 and remove a couple of footguns:

1. Capture pool-worker stderr into a ring buffer and attach the tail to crash-path AssertionErrors. Equivalent in spirit to PR #2965's run_distributed(): CI JUnit XML now shows the actual cause (NCCL error, Python traceback, OOM) inline with the failing test, instead of just "pool worker died mid-request" / "timed out". A daemon drainer thread reads stderr line-by-line into a deque(maxlen=200) and also echoes to sys.stderr so pytest's per-test capture still gets every line. Maximum buffered footprint ~40 KB.

2. Tighten POOL_SUBMIT_TIMEOUT_SEC default 600 -> 90. On H100 the slowest observed per-case wall is ~15 s (p99 also 15 s, p50 ~5 s). 90 s gives ~6x headroom over the worst observed case while still detecting a genuine hang within ~1.5 min instead of ~10 min.
   Env var still overrides for slower machines or expanded test matrices.

3. Optional per-case wall-time logging (NVTE_CP_POOL_TIMING=1) prints "[POOL-TIMING] case_idx=N world_size=W wall_s=X.XXX ok=B" to stderr on rank 0 only. Grep-friendly; lets future tuning recalibrate the timeout against the observed distribution. Off by default so normal runs stay quiet.

Validated: 38 passed / 0 failed in 248 s on H100, test_essential=True, with no perf regression vs the un-patched 256 s.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address PR review: NCCL leak, stdout protocol, Windows note

Three fixes responding to #2993 review comments:

P1: NCCL communicator leak on exception (run_attention_with_cp.py)

run_dpa_with_cp() created cp_comm_group (and optionally cp_comm_sub_groups) near the top, but the destroy_process_group() calls ran only on the success path at the end of the function. Any exception in between (tensor assertion, OOM, NCCL error) skipped the cleanup, leaking communicators in pool mode. Long sessions with repeated failures could exhaust NCCL internal tracking.

Wrap the test work in try/finally so the destroy logic always runs. Initialise cp_comm_sub_groups = [] unconditionally so the finally block is safe even when cp_comm_type != "a2a+p2p" (or when an assert fires before the populate loop). Each destroy is itself try/except so a destroy failure on one group doesn't leak the others.

P2: stdout protocol can be corrupted by interleaved chatter

torchrun and ranks 1..N share rank 0's stdout fd. Any non-rank-0 print, NCCL debug line, or torchrun status output interleaves with the JSON response and breaks json.loads, killing the pool with a misleading "json decode error".
Prefix every response with "[CP_POOL_RESP] " in run_attention_with_cp_pool.py and have PoolWorker.submit() scan stdout for sentinel-prefixed lines, echoing non-protocol lines to stderr for visibility. Bounded scan (MAX_NOISE_LINES=1000) so a chatty worker can't stall the parent.

P2 (doc): select.select on a pipe fd is Linux/macOS only

Added a short comment noting Windows portability. CP attention tests run on Linux GPU hosts; this is a documentation issue, not a real bug.

Validated: 38 passed / 0 failed in 270 s on H100, test_essential=True (was 248 s pre-P2; the +22 s is the new sentinel-scan loop's per-line overhead at ~600 ms/case, within noise).

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [PyTorch] Fix stream race on max_logit_per_step in all-gather CP forward

In AttnFuncWithCPAndKVAllGather.forward, max_logit_per_step[i] is written inside `with torch.cuda.stream(flash_attn_streams[i])`. For i=1, flash_attn_streams[1] is cp_stream, i.e. *not* the default stream. Later, at loop iteration i=2, the code reads max_logit_per_step[1] via `torch.maximum(max_logit, max_logit_per_step[i-1])` which runs on the default stream. Without an explicit wait_stream, this is a read-after-write race across streams.

The post-loop `current_stream().wait_stream(cp_stream)` is too late: the race has already fired.

The race is latent: outcome depends on stream scheduling. In a fresh-process subprocess (one-torchrun-per-test path), streams are cleanly initialised and timing happens to put the write before the read. In a long-running persistent-worker process (exposed by PR #2993's pool design) prior workloads shape stream state differently, the read can fire before the write completes, and max_logit ends up with stale values in some heads (~0.3 abs diff, 3/12 elements wrong on the H100 matrix).
Fix: insert `current_stream().wait_stream(flash_attn_streams[i-1])` before the torch.maximum read. No-op when the streams are identical (i=1 case, where flash_attn_streams[0] is current_stream), only fires when reading from cp_stream (i=2 case).

Validated: 8xH100, test_essential=False, 348 passed / 0 failed in 27m 10s (was 323 passed + 5 failed at this commit's parent, all 5 failing on cp_comm_type=all_gather with mismatched max_logit). The failing configs (all_gather + cp_1_0/cp_1_1 + bshd or fp16) now pass under the pool, confirming the race was the sole root cause.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Address PR review (R2): drop dead code in pool worker and PoolWorker

Line-level cleanups from the second reviewer pass on PR #2993. Each item is dead/redundant; none changes behaviour. Full-matrix test_essential=False on 8xH100 still passes 348/0 in 26m 23s after these.

run_attention_with_cp_pool.py:

- Drop _TRANSIENT_ENV_KEYS tuple + pop loop. run_dpa_with_cp already re-sets NVTE_FUSED_ATTN/NVTE_FLASH_ATTN unconditionally at the top and pops the FP8 ones itself. The pop loop was defensive against a hypothetical "future caller that doesn't re-set them" that doesn't exist.
- Drop gc.collect() after torch.cuda.empty_cache(). The cases create no Python reference cycles between iterations and empty_cache only frees CUDA blocks PyTorch already considers free; the combination was no-op here.
- Drop dist.barrier() after dist.gather_object(). gather_object is itself a collective synchronization point: if every rank reaches it, none is ahead. The "surface a wedged communicator here" comment was wishful: a wedged communicator would already wedge the gather.

test_attention_with_cp.py (PoolWorker):

- Drop _MAX_NOISE_LINES = 1000 + the scanned counter + the unreachable post-loop "1000+ lines" branch. select()'s deadline already bounds the loop; the line-count cap was redundant and the over-limit branch was unreachable in practice.
- Inline _stderr_tail() into _diag(). Single caller, single use.
- Drop the _stderr_thread attribute. The drainer is daemon and self-terminates when the pipe closes; we never read the field anywhere, so initialising and nulling it was bookkeeping for no reason.
- Drop the dead assert in submit(): _ensure_alive() on the prior line already guarantees proc/stdin/stdout exist.

Deferred to a follow-up:

- L8 (drop try/except around dist.destroy_process_group). Real semantic change: hides errors that occur when a previous test wedged the communicator. Worth doing but needs its own validation.
- R1 medium items M1 (module-level flag vs NVTE_CP_POOL_PG env var), M2 (redirect rank>0 stdout vs sentinel scan), M3 (explicit CUDA_VISIBLE_DEVICES per pool). Same reasoning: separate PRs.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Address PR review (items 2+3): reuse CP groups across pool cases

world_size and the rank set don't change for the lifetime of one pool, so recreating the world group and a2a+p2p sub-groups per case wastes ~50-100 ms of NCCL setup each. Pre-create them once in the pool worker (new helper _create_cp_comm_groups), stash on the run_attention_with_cp module via module-level _pool_cp_comm_group / _pool_cp_comm_sub_groups pointers, and reuse them from run_dpa_with_cp in pool mode. Pool teardown destroys them once at shutdown.

Also move per-case dist.new_group() calls inside the try/finally in run_dpa_with_cp: a failure mid-loop in the a2a+p2p sub_group population otherwise leaks every communicator created before the failure. The finally now only destroys groups we created locally (cp_comm_group / sub_groups populated in the else-branch), leaving pool-owned groups alone for reuse.

cyanguwa's review feedback on PR #2993.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Flatten try/finally wrap in run_dpa_with_cp

The Round-1 P1 NCCL-communicator-leak fix (e162a9e) wrapped the ~540-line body of run_dpa_with_cp in try/finally.
The wrap itself was tiny but it re-indented every line of the body by one level, inflating the PR diff of run_attention_with_cp.py to ~1000 lines against origin/main.

Items 2+3 (d15bfce) since made the wrap unnecessary:

- In pool mode, cp_comm_group and cp_comm_sub_groups are owned by the pool worker (which destroys them once at pool shutdown). run_dpa_with_cp neither creates nor destroys them, so an in-body exception can't leak communicators.
- In single-shot mode, groups are still created locally, but the subprocess exits at function return; NCCL releases everything at process teardown, so a stray exception leaks communicators only for the milliseconds before the process dies. That is a bounded one-off cost, not the unbounded accumulation that Round-1 flagged for pool mode.

Removing the wrap drops the run_attention_with_cp.py diff against origin/main from ~1000 lines to ~120 lines without changing observable behaviour.

Smoke-tested: 4 representative cases pass.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Set test_essential=True to match shipping default

Round-3 review (greptile, discussion_r3250016711) flagged that the working tree had test_essential=False, i.e. the full ~328-config matrix instead of the ~38-config essential subset that the rest of the CI matrix expects. Flipping back to True so CI doesn't regress baseline on the known H1-style cascade configs that only appear in the full matrix.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Retry once on pool-infrastructure failures with stderr-logged flake trace

The pool worker subprocess can die mid-case due to async NCCL aborts or flaky 4-GPU collective state that doesn't reproduce on a fresh pool. Without retry, these manifest as one-off CI failures attributable to infrastructure, not the PR's content.
Add a single-attempt retry around PoolWorker.submit() that fires only on infrastructure failure modes (pool-worker-died, timeout, broken-pipe-pre-send). Test-assertion failures from the worker (resp["error"]) carry full per-rank tracebacks and propagate without retry, so a real bug still surfaces as FAILED.

Visibility: every retry attempt writes a [POOL-RETRY] line to stderr. pytest captures per-test stderr and writes it into JUnit <testcase>/<system-err>. A flaky test will appear as PASSED in the case row but with a [POOL-RETRY] line in <system-err>, visible to the reviewer and queryable by CI dashboards looking for flake patterns (e.g. "same test_id retries across multiple CI runs"). If both attempts die, a [POOL-RETRY-FAIL] line is also logged with the first error's headline, then the second attempt's full traceback propagates as the test failure.

Smoke-tested: 3 representative cases (p2p, a2a flash; p2p fused) still PASS in 19 s.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [PyTorch] Pool: redirect non-rank-0 stdout to /dev/null; drop sentinel

Replaces the [CP_POOL_RESP] sentinel-prefix protocol with a stronger fix at the source: on rank>0, close stdout at the fd level via dup2 to /dev/null at worker startup. Catches both Python `print` writes and C-level (NCCL, libc, etc.) writes that the sentinel could only mitigate by scanning + skipping non-protocol lines.

With non-rank-0 stdout silenced, rank 0's JSON line is the only thing that reaches the parent's pipe, so PoolWorker._submit_once collapses from a sentinel-scanning while loop to a single select + readline + json.loads.

Closes follow-up M2 from the PR description; addresses greptile's review comment on stdout pollution.

Validated on 8xH100 with the test_essential=True flash-attn pool path (9 passed / 55 skipped / 0 failed in 56s; no JSONDecodeError, no protocol corruption).
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Address PR review (R3): backend-cache, pool isolation, group-kill, decode-safety

- Invalidate DotProductAttention._attention_backends between pool cases so per-case NVTE_FLASH_ATTN/NVTE_FUSED_ATTN toggles take effect instead of reusing the previous case's resolved backend.
- torch.cuda.empty_cache() after each case so a 2-GPU pool doesn't squat on GPUs that an overlapping 4-GPU pool needs.
- PoolWorker subprocess uses start_new_session=True; _kill() uses killpg on the whole process group so torchrun's rank workers don't survive as orphans holding CUDA/NCCL state.
- On a failed worker response, kill the pool before raising so half-aborted CUDA/NCCL/FP8 state from a failed case doesn't leak into the next.
- Guard json.loads with a try/except + diagnostic so any rank-0 stdout pollution surfaces as a clear test failure rather than a silent protocol desync.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1bd9964 commit 86ade9e

4 files changed

Lines changed: 576 additions & 70 deletions

tests/pytorch/attention/run_attention_with_cp.py

Lines changed: 79 additions & 23 deletions
@@ -2,6 +2,7 @@
 #
 # See LICENSE for license information.

+import copy
 import os
 import sys
 import logging
@@ -29,6 +30,15 @@
 )
 from utils import ModelConfig, compare_and_assert

+# Pool mode (NVTE_CP_POOL_PG=1) only: shared CP collective groups, created once
+# per pool by run_attention_with_cp_pool.main() and reused across every case in
+# that pool. world_size and the rank set don't change per case, so re-creating
+# these per call would be wasted NCCL setup (~50-100 ms each). Single-shot
+# subprocess mode leaves these None / [] and run_dpa_with_cp creates/destroys
+# its own groups inline.
+_pool_cp_comm_group = None
+_pool_cp_comm_sub_groups: list = []
+
 dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}


@@ -209,10 +219,13 @@ def run_dpa_with_cp(
     os.environ["NVTE_FUSED_ATTN"] = "0"
     if kernel_backend == "FlashAttention":
         os.environ["NVTE_FLASH_ATTN"] = "1"
-        config = model_configs_flash_attn[model]
+        # Deep-copy: the module-level dict is shared across pool cases; the
+        # THD branch below rewrites attn_mask_type in place, which would
+        # otherwise leak into subsequent cases reusing the same model key.
+        config = copy.deepcopy(model_configs_flash_attn[model])
     if kernel_backend == "FusedAttention":
         os.environ["NVTE_FUSED_ATTN"] = "1"
-        config = model_configs_fused_attn[model]
+        config = copy.deepcopy(model_configs_fused_attn[model])
     assert config.attn_mask_type in [
         "causal",
         "no_mask",
@@ -226,6 +239,9 @@ def run_dpa_with_cp(
     # set up distributed group
     rank = int(os.getenv("RANK", "0"))
     world_size = int(os.getenv("WORLD_SIZE", "1"))
+    # When NVTE_CP_POOL_PG=1, the pool runner owns the lifecycle of the main
+    # process group across many cases; here we only reuse it.
+    _pool_managed_pg = os.getenv("NVTE_CP_POOL_PG", "0") == "1"
     if dist.is_initialized():
         world_size = dist.get_world_size()
         rank = dist.get_rank()
@@ -234,25 +250,35 @@ def run_dpa_with_cp(
         device = rank % device_count
         torch.cuda.set_device(device)
     logging.info(f"[Rank {rank}] Setup: world_size {world_size}")
-    dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
-
-    # set up communication group for CP
+    if not _pool_managed_pg:
+        dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
+
+    # Set up communication group for CP. In pool mode, the pool worker has
+    # already pre-created world-scoped and a2a+p2p sub-groups once and stashed
+    # them in module-level pointers; we reuse those and the pool destroys them
+    # at shutdown. In single-shot mode we create them per call and destroy in
+    # the finally below.
     cp_comm_ranks = range(world_size)
     assert rank in cp_comm_ranks
-    cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
-    if cp_comm_type == "a2a+p2p":
-        assert world_size % 2 == 0, (
-            "{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has cp_size"
-            " = 2."
-        )
-        cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
-        cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
-        cp_comm_sub_groups = []
-        for sub_ranks in cp_comm_sub_ranks:
-            sub_group = dist.new_group(sub_ranks, backend="nccl")
-            if rank in sub_ranks:
-                cp_comm_sub_groups.append(sub_group)
-
+    _reusing_pool_groups = _pool_managed_pg and _pool_cp_comm_group is not None
+    cp_comm_group = None
+    cp_comm_sub_groups: list = []
+    if _reusing_pool_groups:
+        cp_comm_group = _pool_cp_comm_group
+        cp_comm_sub_groups = _pool_cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else []
+    else:
+        cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
+        if cp_comm_type == "a2a+p2p":
+            assert world_size % 2 == 0, (
+                "{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has"
+                " cp_size = 2."
+            )
+            cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
+            cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
+            for sub_ranks in cp_comm_sub_ranks:
+                sub_group = dist.new_group(sub_ranks, backend="nccl")
+                if rank in sub_ranks:
+                    cp_comm_sub_groups.append(sub_group)
     if dtype == "fp8":
         if scaling_mode == "delayed":
             fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha)
@@ -564,7 +590,10 @@ def run_dpa_with_cp(
                 seq_kv_size = dbias.shape[-1]
                 # Reshape to split seq_q dimension
                 dbias = dbias.view(
-                    *shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size
+                    *shape_before_seq,
+                    2 * world_size,
+                    seq_q_size // (2 * world_size),
+                    seq_kv_size,
                 )
                 # Index select on the newly created dimension (now at position seq_q_dim)
                 dbias = dbias.index_select(seq_q_dim, seq_idx)
@@ -754,16 +783,43 @@ def run_dpa_with_cp(
                     )
                 elif qkv_format == "thd":
                     compare_and_assert(
-                        t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8
+                        t,
+                        tensors_cp[i],
+                        names_no_cp[i],
+                        names_cp[i],
+                        atol,
+                        rtol,
+                        rmse_tol,
+                        is_fp8,
                     )
             else:
                 compare_and_assert(
                     t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8
                 )
             logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches")

-    # destroy distribution group
-    dist.destroy_process_group()
+    # Teardown on the success path. Pool mode: cp_comm_group / cp_comm_sub_groups
+    # point at pool-shared groups owned by the pool runner (which destroys them
+    # at pool shutdown), and the main PG is also pool-owned — both branches
+    # below are no-ops. Single-shot mode: destroy what we created here. If the
+    # body above raises, we skip this — the subprocess dies at function return
+    # and NCCL releases the communicators with the process.
+    if not _reusing_pool_groups:
+        if cp_comm_group is not None:
+            try:
+                dist.destroy_process_group(cp_comm_group)
+            except Exception:
+                pass
+        for g in cp_comm_sub_groups:
+            try:
+                dist.destroy_process_group(g)
+            except Exception:
+                pass
+    if not _pool_managed_pg:
+        try:
+            dist.destroy_process_group()
+        except Exception:
+            pass


 def main(**kwargs):
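
Why the `copy.deepcopy` in the diff above matters for a long-lived worker can be reproduced with a minimal sketch. `ToyConfig`, `make_registry`, and `run_case` are hypothetical stand-ins for `ModelConfig`, the module-level `model_configs_*_attn` dicts, and `run_dpa_with_cp`; only the mask-type rewrite named in the commit message is modelled.

```python
import copy
from dataclasses import dataclass


@dataclass
class ToyConfig:
    """Hypothetical stand-in for ModelConfig; only the mutated field is modelled."""

    attn_mask_type: str


def make_registry():
    # Plays the role of the shared, module-level model_configs_*_attn dicts.
    return {"cp_2_0": ToyConfig("causal")}


def run_case(registry, model, qkv_format, use_deepcopy):
    config = registry[model]
    if use_deepcopy:
        config = copy.deepcopy(config)
    assert config.attn_mask_type in ("causal", "no_mask"), "causal or no_mask only"
    if qkv_format == "thd":
        # In-place rewrite, mirroring the THD branch named in the commit message.
        rewrite = {"causal": "padding_causal", "no_mask": "padding"}
        config.attn_mask_type = rewrite[config.attn_mask_type]
    return config.attn_mask_type


shared = make_registry()
run_case(shared, "cp_2_0", "thd", use_deepcopy=True)
run_case(shared, "cp_2_0", "thd", use_deepcopy=True)  # registry entry is still "causal"
```

With `use_deepcopy=False`, the second call on the same registry would trip the assert, which is the cascade the commit message describes for back-to-back cases in one process.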
tests/pytorch/attention/run_attention_with_cp_pool.py

Lines changed: 221 additions & 0 deletions
@@ -0,0 +1,221 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""
+Persistent worker for batched CP attention tests.
+
+Launched ONCE per (pytest session, world_size) by torchrun. All ranks init
+NCCL, then enter a dispatch loop:
+
+    rank 0:
+        read one JSON request line from stdin
+        broadcast it to all ranks
+    all ranks:
+        call run_dpa_with_cp(**kwargs) — the same work function the
+            per-case subprocess design uses, with NVTE_CP_POOL_PG=1 so the
+            function reuses our PG instead of re-initing it
+        torch.cuda.empty_cache() per case
+        all ranks gather (ok, error_msg) to rank 0
+    rank 0:
+        write one JSON response line to stdout
+
+Protocol (line-delimited JSON over rank-0 stdio):
+    request : {"op": "run", "kwargs": {...}}
+              {"op": "shutdown"}
+    response: {"ok": true}
+              {"ok": false, "error": "first failing rank's traceback"}
+"""
+import json
+import os
+import sys
+import time
+import traceback
+
+import torch
+import torch.distributed as dist
+
+# Make sibling modules importable when launched directly.
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
+from run_attention_with_cp import run_dpa_with_cp
+from transformer_engine.pytorch.quantization import FP8GlobalStateManager
+
+
+def _recv_request(rank: int) -> dict:
+    box = [None]
+    if rank == 0:
+        line = sys.stdin.readline()
+        box[0] = {"op": "shutdown"} if not line else json.loads(line)
+    dist.broadcast_object_list(box, src=0)
+    return box[0]
+
+
+def _send_response(rank: int, payload: dict) -> None:
+    if rank == 0:
+        sys.stdout.write(json.dumps(payload) + "\n")
+        sys.stdout.flush()
+
+
+def _silence_non_rank0_stdout(rank: int) -> None:
+    """Redirect non-rank-0 stdout to /dev/null at fd level.
+
+    All ranks share rank 0's stdout fd (torchrun inherits it from the launcher),
+    so Python/library writes on rank>0 would interleave with rank 0's JSON
+    protocol on the parent's pipe. Closing fd 1 at the OS level on rank>0
+    catches both Python (``print``) and C-level (NCCL, etc.) writes.
+    """
+    if rank == 0:
+        return
+    devnull = os.open(os.devnull, os.O_WRONLY)
+    os.dup2(devnull, 1)
+    os.close(devnull)
+    sys.stdout = open(1, "w", closefd=False)
+
+
+def _reset_between_cases() -> None:
+    """Drop state that would otherwise cascade across cases.
+
+    Matches the per-case startup of the single-shot worker
+    (``_run_single_config`` on the per-case-subprocess branch): identical RNG
+    seed at the start of every case, FP8 state cleared, allocator clean.
+    ``run_dpa_with_cp`` re-sets ``NVTE_FUSED_ATTN``/``NVTE_FLASH_ATTN``
+    unconditionally and pops the other transient env vars itself, so no
+    explicit pop is needed here.
+    """
+    torch.manual_seed(1234)
+    torch.cuda.manual_seed(1234)
+    FP8GlobalStateManager.reset()
+    torch.cuda.empty_cache()
+    # Invalidate DPA's module-level backend cache so the per-case
+    # NVTE_FLASH_ATTN/NVTE_FUSED_ATTN env-var toggle actually takes effect
+    # instead of reusing the previous case's resolved backend.
+    try:
+        from transformer_engine.pytorch.attention.dot_product_attention import dot_product_attention
+
+        dot_product_attention._attention_backends["backend_selection_requires_update"] = True
+    except (ImportError, AttributeError, KeyError):
+        pass
+
+
+_case_counter = 0
+
+
+def _run_one(req: dict, rank: int) -> tuple[bool, str]:
+    global _case_counter
+    op = req["op"]
+    if op != "run":
+        return False, f"unknown op: {op}"
+    # Reset BEFORE the case so the first case also starts from a known RNG seed
+    # and clean FP8 state — same as the single-shot worker's per-process startup.
+    _reset_between_cases()
+    t0 = time.monotonic()
+    ok = True
+    err = ""
+    try:
+        run_dpa_with_cp(**req.get("kwargs", {}))
+    except Exception:
+        ok = False
+        err = f"[Rank {rank}] {traceback.format_exc()}"
+    wall = time.monotonic() - t0
+    # Per-case wall time on rank 0, opt-in via NVTE_CP_POOL_TIMING=1.
+    # Used to tune POOL_SUBMIT_TIMEOUT_SEC against the observed distribution.
+    if rank == 0 and int(os.environ.get("NVTE_CP_POOL_TIMING", "0")):
+        _case_counter += 1
+        sys.stderr.write(
+            f"[POOL-TIMING] case_idx={_case_counter} "
+            f"world_size={int(os.environ.get('WORLD_SIZE', 0))} "
+            f"wall_s={wall:.3f} ok={ok}\n"
+        )
+        sys.stderr.flush()
+    return ok, err
+
+
+def _create_cp_comm_groups(rank: int, world_size: int) -> tuple:
+    """Pre-create the CP collective groups for this pool.
+
+    world_size and the rank set are constant for the lifetime of one pool, so
+    the world group and the a2a+p2p sub-groups are deterministic. Creating
+    them once here and reusing them across every case eliminates ~50-100 ms
+    of NCCL setup per case (cyanguwa's review feedback on PR #2993).
+
+    Returns ``(world_group, a2a_p2p_sub_groups)``. ``a2a_p2p_sub_groups`` is
+    empty when world_size is too small to support a2a+p2p (needs an even
+    world_size ≥ 4); cases with cp_comm_type='a2a+p2p' wouldn't be routed to
+    such a pool anyway.
+    """
+    world_group = dist.new_group(range(world_size), backend="nccl")
+    sub_groups: list = []
+    if world_size >= 4 and world_size % 2 == 0:
+        # Mirror the layout in run_attention_with_cp.py: cp_size/2 pairs along
+        # axis 0, plus 2 stride-2 groups along axis 1.
+        cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
+        cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
+        for sub_ranks in cp_comm_sub_ranks:
+            sub_group = dist.new_group(sub_ranks, backend="nccl")
+            if rank in sub_ranks:
+                sub_groups.append(sub_group)
+    return world_group, sub_groups
+
+
+def main() -> None:
+    rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    _silence_non_rank0_stdout(rank)
+    torch.cuda.set_device(rank % torch.cuda.device_count())
+    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    os.environ["NVTE_CP_POOL_PG"] = "1"
+
+    # Stash pool-shared CP groups on the run_attention_with_cp module so
+    # run_dpa_with_cp can read them per case. Imported here (after the env var
+    # is set) to keep import-time side effects minimal.
+    import run_attention_with_cp as _rac
+
+    _rac._pool_cp_comm_group, _rac._pool_cp_comm_sub_groups = _create_cp_comm_groups(
+        rank, world_size
+    )
+
+    try:
+        while True:
+            req = _recv_request(rank)
+            if req.get("op") == "shutdown":
+                break
+
+            ok, msg = _run_one(req, rank)
+
+            gathered: list[tuple[bool, str]] = [None] * world_size  # type: ignore[list-item]
+            # gather_object is itself a collective synchronization point — if
+            # every rank reached it, none is ahead. No extra barrier needed.
+            dist.gather_object((ok, msg), gathered if rank == 0 else None, dst=0)
+
+            if rank == 0:
+                all_ok = all(o for o, _ in gathered)
+                if all_ok:
+                    _send_response(rank, {"ok": True})
+                else:
+                    first_err = next(m for o, m in gathered if not o)
+                    _send_response(rank, {"ok": False, "error": first_err})
+            # Release the allocator cache so this pool doesn't squat on
+            # GPUs that an overlapping different-world-size pool needs.
+            torch.cuda.empty_cache()
+    finally:
+        # Tear down pool-shared CP groups before the main PG (NCCL requires
+        # sub-groups to be destroyed first). Each destroy is independently
+        # guarded so a wedged communicator on one group doesn't leak the rest.
+        if _rac._pool_cp_comm_group is not None:
+            try:
+                dist.destroy_process_group(_rac._pool_cp_comm_group)
+            except Exception:
+                pass
+        for g in _rac._pool_cp_comm_sub_groups:
+            try:
+                dist.destroy_process_group(g)
+            except Exception:
+                pass
+        _rac._pool_cp_comm_group = None
+        _rac._pool_cp_comm_sub_groups = []
+        dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    main()
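
The parent side of this line-delimited JSON protocol lives in PoolWorker in test_attention_with_cp.py, which is not shown on this page. The sketch below only illustrates the "single select + readline + json.loads" shape the commit message describes. It assumes a toy single-process worker in place of the torchrun launch, and `spawn_toy_worker`, `submit`, and `shutdown` are hypothetical names. Calling `select.select` on a pipe works on Linux/macOS only, as the commit message notes.

```python
import json
import select
import subprocess
import sys

# Toy stand-in for the torchrun-launched worker: replies {"ok": true} to every request.
_TOY_WORKER = r"""
import json, sys
while True:
    line = sys.stdin.readline()
    if not line or json.loads(line).get("op") == "shutdown":
        break
    sys.stdout.write(json.dumps({"ok": True}) + "\n")
    sys.stdout.flush()
"""


def spawn_toy_worker():
    return subprocess.Popen(
        [sys.executable, "-u", "-c", _TOY_WORKER],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        bufsize=0,  # unbuffered pipes, so select() and readline() see the same data
    )


def submit(proc, request, timeout_sec=90.0):
    """Send one JSON request line and wait up to timeout_sec for one JSON response line."""
    proc.stdin.write((json.dumps(request) + "\n").encode())
    ready, _, _ = select.select([proc.stdout], [], [], timeout_sec)
    if not ready:
        proc.kill()
        raise TimeoutError(f"pool worker timed out after {timeout_sec} s")
    line = proc.stdout.readline()
    if not line:
        raise RuntimeError("pool worker died mid-request")
    try:
        return json.loads(line)
    except json.JSONDecodeError as exc:
        raise RuntimeError(f"malformed worker response: {line!r}") from exc


def shutdown(proc, timeout_sec=10.0):
    """Ask the worker to exit and return its exit code."""
    proc.stdin.write(b'{"op": "shutdown"}\n')
    proc.stdin.close()
    code = proc.wait(timeout=timeout_sec)
    proc.stdout.close()
    return code


if __name__ == "__main__":
    worker = spawn_toy_worker()
    print(submit(worker, {"op": "run", "kwargs": {}}))
    shutdown(worker)
```

A real pool would additionally respawn the worker after a timeout and attach captured stderr to the raised error; both are left out here to keep the protocol round trip visible.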
