Skip to content

Commit bca0178

Browse files
committed
tests/ccl: graph-capture + varying-input coverage for gluon/triton all_gather & all_to_all
1 parent 0c3be44 commit bca0178

2 files changed

Lines changed: 213 additions & 139 deletions

File tree

tests/ccl/test_all_gather_gluon.py

Lines changed: 114 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,27 @@
11
# SPDX-License-Identifier: MIT
22
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
33

4+
"""All-gather correctness: eager and under HIP-graph capture, identical and
5+
varying inputs, with and without the trailing cross-rank barrier.
6+
7+
The eager + barrier path (the original coverage) hides a cross-rank
8+
write-visibility issue that only surfaces when the trailing barrier is dropped
9+
AND the input changes between back-to-back ops — the regime cudagraph capture
10+
forces on vLLM (the host barrier is illegal under capture, so async_op=True, and
11+
the captured step replays every token with fresh activations). `mode` separates
12+
the cause from the trigger:
13+
eager_barrier : eager, async_op=False — trailing ctx.barrier() (correct baseline)
14+
eager_nobarrier : eager, async_op=True — no barrier, no graph (isolates the barrier)
15+
graph : capture + replay, async_op=True (the vLLM regime)
16+
`vary=False` replays identical input (a stale read looks correct); `vary=True`
17+
feeds fresh input each step. impl: torch (known-good control via torch.distributed),
18+
triton and gluon (the two iris backends, selected by config.use_gluon).
419
"""
5-
Test suite for all-gather collective operation using Gluon.
6-
"""
7-
8-
import os
920

1021
import pytest
1122
import torch
1223
import torch.distributed as dist
1324

14-
# Try to import Gluon, skip tests if not available
1525
try:
1626
import iris
1727
from iris.ccl import Config
@@ -22,84 +32,122 @@
2232
GLUON_AVAILABLE = False
2333

2434

35+
# Number of back-to-back all-gather steps (eager calls or graph replays) that
# each parametrized case runs and checks.
NUM_REPLAYS = 200
36+
37+
def _all_gather(impl, src, stage_buf, result, shmem, config, async_op):
    """Stage ``src`` into the input buffer, then all-gather it into ``result``.

    Kept at module level (no closure over ``shmem``) so the test can
    ``del shmem`` for IPC cleanup.

    Args:
        impl: ``"torch"`` routes through ``torch.distributed``; any other value
            routes through ``shmem.ccl.all_gather``.
        src: Tensor holding this step's input; copied into ``stage_buf``.
        stage_buf: Input buffer handed to the collective.
        result: Output buffer the collective writes into.
        shmem: Object exposing ``ccl.all_gather``; unused for ``"torch"``.
        config: Forwarded as ``config=`` to ``shmem.ccl.all_gather``; unused
            for ``"torch"``.
        async_op: Forwarded as ``async_op=`` to ``shmem.ccl.all_gather``;
            unused for ``"torch"``.
    """
    # The copy is part of the step so that, under graph capture, each replay
    # re-stages whatever the caller last wrote into ``src``.
    stage_buf.copy_(src)
    if impl == "torch":
        dist.all_gather_into_tensor(result, stage_buf)
        return
    shmem.ccl.all_gather(result, stage_buf, config=config, async_op=async_op)
46+
47+
def _make_buffers(impl, shmem, rank, world_size, M, N, dtype, block_size_m, block_size_n):
    """Resolve ``impl`` to ``(stage_buf, result, config)`` in one place.

    ``"torch"`` gets plain device tensors on ``cuda:{rank}`` and no config.
    ``"triton"`` and ``"gluon"`` get buffers from ``shmem.zeros`` and a
    ``Config`` whose ``use_gluon`` flag is true only for ``"gluon"``.
    The output is ``(world_size * M, N)`` — block r holds rank r's input.

    Args:
        impl: One of ``"torch"``, ``"triton"``, ``"gluon"``.
        shmem: Object exposing ``zeros(shape, dtype=...)``; unused for ``"torch"``.
        rank: This process's rank; selects the device for ``"torch"``.
        world_size: Number of ranks; scales the output's first dimension.
        M, N: Per-rank input shape.
        dtype: Element type for both buffers.
        block_size_m, block_size_n: Forwarded to ``Config``; unused for ``"torch"``.

    Returns:
        ``(stage_buf, result, config)``; ``config`` is ``None`` for ``"torch"``.

    Raises:
        ValueError: if ``impl`` is not one of the three known backends.
    """
    # Reject typos up front: otherwise an unknown name would silently run the
    # iris path with use_gluon=False and be reported under the wrong label.
    if impl not in ("torch", "triton", "gluon"):
        raise ValueError(f"unknown impl {impl!r}; expected 'torch', 'triton' or 'gluon'")

    in_shape = (M, N)
    out_shape = (world_size * M, N)
    if impl == "torch":
        device = f"cuda:{rank}"
        stage = torch.empty(in_shape, dtype=dtype, device=device)
        result = torch.empty(out_shape, dtype=dtype, device=device)
        return stage, result, None

    stage = shmem.zeros(in_shape, dtype=dtype)
    result = shmem.zeros(out_shape, dtype=dtype)
    config = Config(use_gluon=(impl == "gluon"), block_size_m=block_size_m, block_size_n=block_size_n)
    return stage, result, config
60+
61+
@pytest.mark.skipif(not GLUON_AVAILABLE, reason="Gluon not available")
@pytest.mark.parametrize("impl", ["torch", "triton", "gluon"])
@pytest.mark.parametrize("mode", ["eager_barrier", "eager_nobarrier", "graph"])
@pytest.mark.parametrize("vary", [False, True])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"])
@pytest.mark.parametrize(
    "M, N, block_size_m, block_size_n",
    [(64, 8192, 32, 1024), (256, 8192, 32, 1024)],
)
def test_all_gather_gluon(impl, mode, vary, dtype, M, N, block_size_m, block_size_n):
    """Rank r fills its whole input with 1 + r + replay%16 (exact integers), so
    output block r must equal 1 + r + replay%16 — any >=1 mismatch is a real drop.

    Controls: torch and eager_barrier must pass every cell (correct when synced),
    so eager_nobarrier failing isolates the missing cross-rank barrier (no
    cudagraph involved) and graph failing is the vLLM regime. Per-peer-slice fail
    tallies show which peer slices dropped (structured vs scattered).

    Every step is checked; the test prints a per-rank summary and then asserts
    that no step failed. Warmup, capture and the replay loop all run inside the
    ``try`` so the ``finally`` cleanup (barrier, ``del shmem``, gc) also runs
    when setup raises.
    """
    if not dist.is_initialized():
        pytest.skip("torch.distributed not initialized")
    if impl == "torch" and mode == "eager_nobarrier":
        pytest.skip("torch has no barrier knob; eager_barrier already covers eager torch")

    # Resolve (impl, mode) up front; the body runs straight-line off these.
    async_op = mode != "eager_barrier"
    capture = mode == "graph"

    shmem = iris.iris(2**33)  # 8 GB
    graph = None  # bound before the try so the finally can always drop it
    try:
        rank, world_size = shmem.get_rank(), shmem.get_num_ranks()
        src = torch.empty((M, N), dtype=dtype, device=f"cuda:{rank}")
        stage_buf, result, config = _make_buffers(
            impl, shmem, rank, world_size, M, N, dtype, block_size_m, block_size_n
        )
        shmem.barrier()

        def fill_src(replay):
            src.fill_(float(1 + rank + (replay % 16)))

        # Warmup (runs lazy JIT/setup), then capture the step once if in graph mode.
        fill_src(0)
        _all_gather(impl, src, stage_buf, result, shmem, config, async_op)
        torch.cuda.synchronize()
        shmem.barrier()

        if capture:
            stream = torch.cuda.Stream()
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                graph = torch.cuda.CUDAGraph()
                graph.capture_begin()
                try:
                    _all_gather(impl, src, stage_buf, result, shmem, config, async_op)
                finally:
                    # Always close the capture so a failure inside the step
                    # does not leave the stream in capture mode.
                    graph.capture_end()
            torch.cuda.current_stream().wait_stream(stream)

        atol = 0.5  # exact integer inputs; >=1 mismatch is a real drop
        failures = []  # (step, max|diff|, bad_slices)
        block_fail = [0] * world_size  # steps each peer slice dropped

        for i in range(NUM_REPLAYS):
            # With vary=False every step re-sends the replay-0 pattern.
            replay = i if vary else 0
            fill_src(replay)
            if capture:
                graph.replay()
            else:
                _all_gather(impl, src, stage_buf, result, shmem, config, async_op)
            torch.cuda.synchronize()
            # Max abs error of each rank's output slice against its expected constant.
            diffs = [
                torch.abs(result[r * M : (r + 1) * M] - float(1 + r + (replay % 16))).max().item()
                for r in range(world_size)
            ]
            bad = [r for r in range(world_size) if diffs[r] > atol]
            for r in bad:
                block_fail[r] += 1
            if bad:
                failures.append((i, round(max(diffs[r] for r in bad), 4), bad))
        print(
            f"[rank {rank}] all_gather impl={impl} mode={mode} vary={vary} dtype={dtype} "
            f"{M}x{N}: {NUM_REPLAYS - len(failures)}/{NUM_REPLAYS} ok; "
            f"per-peer-slice fail counts={block_fail}" + (f"; first FAIL={failures[0]}" if failures else ""),
            flush=True,
        )
        # The message is only evaluated on failure, so failures[0] is safe here.
        assert not failures, (
            f"impl={impl} mode={mode} vary={vary} dtype={dtype} {M}x{N}: "
            f"{len(failures)}/{NUM_REPLAYS} steps wrong (first {failures[0]}; per-peer-slice "
            f"fail counts={block_fail}). torch and eager_barrier must pass; eager_nobarrier "
            f"failing isolates the missing cross-rank barrier (no cudagraph); graph is the vLLM regime."
        )
    finally:
        # Drop the captured graph first, then barrier, release shmem and collect.
        del graph
        shmem.barrier()
        del shmem
        import gc

        gc.collect()

0 commit comments

Comments
 (0)