|
1 | 1 | # SPDX-License-Identifier: MIT |
2 | 2 | # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. |
3 | 3 |
|
| 4 | +"""All-gather correctness: eager and under HIP-graph capture, identical and |
| 5 | +varying inputs, with and without the trailing cross-rank barrier. |
| 6 | +
|
| 7 | +The eager + barrier path (the original coverage) hides a cross-rank |
| 8 | +write-visibility issue that only surfaces when the trailing barrier is dropped |
| 9 | +AND the input changes between back-to-back ops — the regime cudagraph capture |
| 10 | +forces on vLLM (the host barrier is illegal under capture, so async_op=True, and |
| 11 | +the captured step replays every token with fresh activations). `mode` separates |
| 12 | +the cause from the trigger: |
| 13 | + eager_barrier : eager, async_op=False — trailing ctx.barrier() (correct baseline) |
| 14 | + eager_nobarrier : eager, async_op=True — no barrier, no graph (isolates the barrier) |
| 15 | + graph : capture + replay, async_op=True (the vLLM regime) |
| 16 | +`vary=False` replays identical input (a stale read looks correct); `vary=True` |
| 17 | +feeds fresh input each step. impl: torch (known-good control via torch.distributed), |
| 18 | +triton and gluon (the two iris backends, selected by config.use_gluon). |
4 | 19 | """ |
5 | | -Test suite for all-gather collective operation using Gluon. |
6 | | -""" |
7 | | - |
8 | | -import os |
9 | 20 |
|
10 | 21 | import pytest |
11 | 22 | import torch |
12 | 23 | import torch.distributed as dist |
13 | 24 |
|
14 | | -# Try to import Gluon, skip tests if not available |
15 | 25 | try: |
16 | 26 | import iris |
17 | 27 | from iris.ccl import Config |
|
22 | 32 | GLUON_AVAILABLE = False |
23 | 33 |
|
24 | 34 |
|
# Number of back-to-back all-gather steps (eager calls or graph replays) that
# each parametrized test case issues and checks.
NUM_REPLAYS = 200
| 36 | + |
| 37 | + |
def _all_gather(impl, src, stage_buf, result, shmem, config, async_op):
    """Copy ``src`` into ``stage_buf`` and all-gather ``stage_buf`` into ``result``.

    ``impl == "torch"`` dispatches to ``dist.all_gather_into_tensor``; any other
    value dispatches to ``shmem.ccl.all_gather`` with ``config`` and ``async_op``
    forwarded (both are ignored on the torch path). Kept at module level, taking
    ``shmem`` as an argument instead of closing over it, so the calling test can
    ``del shmem`` for IPC cleanup.
    """
    # Staging is common to both backends and always happens first.
    stage_buf.copy_(src)
    if impl != "torch":
        shmem.ccl.all_gather(result, stage_buf, config=config, async_op=async_op)
        return
    dist.all_gather_into_tensor(result, stage_buf)
| 46 | + |
| 47 | + |
def _make_buffers(impl, shmem, rank, world_size, M, N, dtype, block_size_m, block_size_n):
    """Build the ``(stage_buf, result, config)`` triple for the requested ``impl``.

    * ``"torch"``: uninitialised ``torch.empty`` tensors on ``cuda:{rank}`` and
      ``None`` for the config.
    * anything else (the iris backends): zero-filled buffers from ``shmem.zeros``
      plus a ``Config`` whose ``use_gluon`` is true only for ``impl == "gluon"``.

    The staging buffer is ``(M, N)``; the result is ``(world_size * M, N)``, so
    row-block ``r`` is where rank ``r``'s input is expected to land.
    """
    in_shape = (M, N)
    out_shape = (world_size * M, N)

    if impl != "torch":
        # Staging buffer is allocated before the result buffer.
        stage = shmem.zeros(in_shape, dtype=dtype)
        result = shmem.zeros(out_shape, dtype=dtype)
        config = Config(use_gluon=(impl == "gluon"), block_size_m=block_size_m, block_size_n=block_size_n)
        return stage, result, config

    device = f"cuda:{rank}"
    stage = torch.empty(in_shape, dtype=dtype, device=device)
    result = torch.empty(out_shape, dtype=dtype, device=device)
    return stage, result, None
| 60 | + |
| 61 | + |
def _slice_errors(result, world_size, base):
    """Return ``max |result - expected|`` for each peer slice as Python floats.

    ``result`` is split into ``world_size`` equal row-blocks and block ``r`` is
    compared against the scalar ``base + r`` in ``result``'s own dtype. All
    slices are reduced in one pass and fetched with a single device-to-host
    copy. A NaN anywhere in a slice makes that slice's entry NaN, so callers
    must treat NaN as a mismatch.
    """
    expected = (torch.arange(world_size, device=result.device) + base).to(result.dtype)
    blocks = result.reshape(world_size, -1)
    return (blocks - expected.unsqueeze(1)).abs().amax(dim=1).tolist()


@pytest.mark.skipif(not GLUON_AVAILABLE, reason="Gluon not available")
@pytest.mark.parametrize("impl", ["torch", "triton", "gluon"])
@pytest.mark.parametrize("mode", ["eager_barrier", "eager_nobarrier", "graph"])
@pytest.mark.parametrize("vary", [False, True])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"])
@pytest.mark.parametrize(
    "M, N, block_size_m, block_size_n",
    [(64, 8192, 32, 1024), (256, 8192, 32, 1024)],
)
def test_all_gather_gluon(impl, mode, vary, dtype, M, N, block_size_m, block_size_n):
    """Run NUM_REPLAYS all-gathers and check every peer slice after each one.

    Rank r fills its whole input with ``1 + r + replay % 16`` (exact integers),
    so output block r must equal ``1 + r + replay % 16``; with ``atol = 0.5`` any
    mismatch of 1 or more, or a NaN, counts as a failed step.

    Parameters:
        impl: ``"torch"`` (``torch.distributed`` control), ``"triton"`` or
            ``"gluon"`` (iris backends, chosen through ``Config.use_gluon``).
        mode: ``"eager_barrier"`` runs eagerly with ``async_op=False``;
            ``"eager_nobarrier"`` runs eagerly with ``async_op=True``;
            ``"graph"`` captures one step and replays it with ``async_op=True``.
        vary: when False every step uses the same input (replay index 0); when
            True the input changes each step.
        dtype, M, N, block_size_m, block_size_n: buffer dtype/shape and the
            block sizes forwarded to the iris ``Config``.

    Intended reading of results: torch and eager_barrier are the controls and
    must pass every cell; eager_nobarrier failing isolates the missing
    cross-rank barrier without any graph involved; graph failing is the
    captured-replay regime. The per-peer-slice tallies in the report show which
    slices were wrong.
    """
    if not dist.is_initialized():
        pytest.skip("torch.distributed not initialized")
    if impl == "torch" and mode == "eager_nobarrier":
        pytest.skip("torch has no barrier knob; eager_barrier already covers eager torch")

    # Resolve (impl, mode) up front; the body runs straight-line off these.
    async_op = mode != "eager_barrier"
    capture = mode == "graph"

    shmem = iris.iris(2**33)  # 8 GB
    rank, world_size = shmem.get_rank(), shmem.get_num_ranks()
    src = torch.empty((M, N), dtype=dtype, device=f"cuda:{rank}")
    stage_buf, result, config = _make_buffers(impl, shmem, rank, world_size, M, N, dtype, block_size_m, block_size_n)
    shmem.barrier()

    def fill_src(replay):
        src.fill_(float(1 + rank + (replay % 16)))

    # Warmup (runs lazy JIT/setup), then capture the step once if in graph mode.
    fill_src(0)
    _all_gather(impl, src, stage_buf, result, shmem, config, async_op)
    torch.cuda.synchronize()
    shmem.barrier()

    graph = None
    if capture:
        graph = torch.cuda.CUDAGraph()
        # The context manager captures on a side stream and always ends the
        # capture, so an exception inside the captured call cannot leave the
        # stream stuck in capture mode.
        with torch.cuda.graph(graph):
            _all_gather(impl, src, stage_buf, result, shmem, config, async_op)

    atol = 0.5  # exact integer inputs; >=1 mismatch is a real drop
    failures = []  # (step, max|diff|, bad_slices)
    block_fail = [0] * world_size  # steps each peer slice dropped
    try:
        for step in range(NUM_REPLAYS):
            replay = step if vary else 0
            fill_src(replay)
            if graph is not None:
                graph.replay()
            else:
                _all_gather(impl, src, stage_buf, result, shmem, config, async_op)
            torch.cuda.synchronize()

            diffs = _slice_errors(result, world_size, 1 + replay % 16)
            # `not (d <= atol)` instead of `d > atol`: NaN compares False in
            # both directions, so only this form flags a NaN slice as bad.
            bad = [r for r, d in enumerate(diffs) if not d <= atol]
            for r in bad:
                block_fail[r] += 1
            if bad:
                failures.append((step, round(max(diffs[r] for r in bad), 4), bad))

        print(
            f"[rank {rank}] all_gather impl={impl} mode={mode} vary={vary} dtype={dtype} "
            f"{M}x{N}: {NUM_REPLAYS - len(failures)}/{NUM_REPLAYS} ok; "
            f"per-peer-slice fail counts={block_fail}" + (f"; first FAIL={failures[0]}" if failures else ""),
            flush=True,
        )
        # The message is only evaluated on failure, so failures[0] is safe here.
        assert not failures, (
            f"impl={impl} mode={mode} vary={vary} dtype={dtype} {M}x{N}: "
            f"{len(failures)}/{NUM_REPLAYS} steps wrong (first {failures[0]}; per-peer-slice "
            f"fail counts={block_fail}). torch and eager_barrier must pass; eager_nobarrier "
            f"failing isolates the missing cross-rank barrier (no cudagraph); graph is the vLLM regime."
        )
    finally:
        # Drop the captured graph first, then barrier so all ranks finish
        # before the shmem instance is released and garbage-collected.
        if graph is not None:
            del graph
        shmem.barrier()
        del shmem
        import gc

        gc.collect()
0 commit comments