Skip to content

Commit 2c165af

Browse files
committed
step1b: fix ut
1 parent 50c9535 commit 2c165af

5 files changed

Lines changed: 145 additions & 92 deletions

File tree

CLAUDE.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ ______________________________________________________________________
6363
1. **每个 Step 必须配套 UT**:不写 UT 不能进下一个 Step。UT 通过 → 才能集成。详见 `SPEC.md` §6 的测试矩阵。
6464
1. **多卡 UT 用 torchrun 跑**:模板见 skill `multi-gpu-test-template`
6565
1. **代码风格检查**:每个 Step 完成后必须运行 `ruff check` 并修复所有问题,然后重新验证测试通过。Ruff 路径:`/root/miniconda3/bin/ruff`
66+
1. **多卡 UT 默认 4/8 GPU**:除非测试目标明确只适合 2 GPU smoke test,否则多卡 UT 必须至少覆盖 4 GPU;当前环境有 8 张 GPU 时,Step 验收必须优先跑 8 GPU。报告测试结果时写明实际 `torchrun --nproc_per_node` 数量和 backend。
6667

6768
______________________________________________________________________
6869

PROGRESS.md

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,9 @@ This file tracks implementation and validation status. `SPEC.md` remains the des
2929
- Single-process Step 1 tests: PASS
3030
- Command used: `pytest source/tests/pt/test_sezm_moe_a2a.py -q`
3131
- Result: 5 tests passed, 3 subtests passed
32-
- Multi-process Step 1 smoke test: PASS
33-
- Runner: Cursor `multi-gpu-tester` subagent
34-
- Command shape: `torchrun --nproc_per_node=2 ... source/tests/pt/test_sezm_moe_a2a_multigpu.py`
35-
- Result: 4 tests passed, no hang
36-
- Multi-process Step 1 4-process test: PASS
37-
- Command shape: `torchrun --nproc_per_node=4 ... source/tests/pt/test_sezm_moe_a2a_multigpu.py`
38-
- Result: 4 tests passed on all ranks, no hang
32+
- Multi-process Step 1 8-rank CUDA/NCCL test: PASS
33+
- Command shape: `torchrun --nproc_per_node=8 ... source/tests/pt/test_sezm_moe_a2a_multigpu.py`
34+
- Result: 6 tests passed on all 8 ranks, no hang
3935
- Step 1 ruff check: PASS
4036
- Command used: `/root/miniconda3/bin/ruff check deepmd/pt/model/descriptor/sezm_nn/moe/a2a_ops.py source/tests/pt/test_sezm_moe_a2a.py source/tests/pt/test_sezm_moe_a2a_multigpu.py`
4137
- DPA3 reference subagent smoke test: PASS
@@ -48,6 +44,7 @@ This file tracks implementation and validation status. `SPEC.md` remains the des
4844
- `pytest` 9.0.3 is installed in `/mnt/data_nas/zhangd/conda_env/torch-modern`.
4945
- `/root/miniconda3/bin/ruff` is available and reports `ruff 0.15.6`.
5046
- Existing Step 1 tests are runnable via `pytest`, `unittest`, and standalone `torchrun`.
47+
- Multi-rank tests use CUDA/NCCL when CUDA is available and fall back to CPU/Gloo only when CUDA is unavailable.
5148

5249
## Not Started
5350

SPEC.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,13 @@ ______________________________________________________________________
673673

674674
每个测试都对应一个 pytest 文件。命名:`test_sezm_moe_<topic>.py``test_sezm_moe_<topic>_multigpu.py`
675675

676+
### 多卡 UT 的 GPU 数量规则
677+
678+
- 多卡 UT 必须与单卡 UT 分开写成独立文件:`test_sezm_moe_<topic>.py``test_sezm_moe_<topic>_multigpu.py`
679+
- 除非测试目标明确只是 2 GPU smoke test,否则多卡 UT 至少覆盖 4 GPU。
680+
- 当前开发环境有 8 张 GPU 时,Step 验收必须优先跑 8 GPU,并在报告中写明 `torchrun --nproc_per_node`、backend(NCCL/Gloo)和通过的 rank 数。
681+
- 对 A2A、梯度同步、checkpoint resharding、二阶导不死锁等跨 rank 行为,2 GPU 结果只能作为 smoke test,不能替代 4/8 GPU 验收。
682+
676683
______________________________________________________________________
677684

678685
## 8. 配置 schema

source/tests/pt/test_sezm_moe_a2a.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,10 @@ def test_second_backward(self):
9898
"Second-order gradient should contain non-zero values",
9999
)
100100

101-
def test_gradgradcheck_fp64(self):
102-
"""torch.autograd.gradgradcheck should pass in fp64."""
103-
# Use smaller tensors for gradgradcheck (it's expensive)
101+
def test_short_circuit_gradgradcheck_fp64(self):
102+
"""group=None short-circuit should pass gradgradcheck in fp64."""
103+
# This verifies the single-process passthrough path only. The real
104+
# _AllToAllDouble gradgradcheck lives in the multi-GPU test file.
104105
x = torch.randn(6, 4, dtype=torch.float64, requires_grad=True, device="cpu")
105106
send_splits = [2, 2, 2]
106107
recv_splits = [1, 3, 2]

source/tests/pt/test_sezm_moe_a2a_multigpu.py

Lines changed: 129 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
Run with:
55
torchrun --nproc_per_node=2 source/tests/pt/test_sezm_moe_a2a_multigpu.py
66
torchrun --nproc_per_node=4 source/tests/pt/test_sezm_moe_a2a_multigpu.py
7+
torchrun --nproc_per_node=8 source/tests/pt/test_sezm_moe_a2a_multigpu.py
78
"""
89

910
import unittest
@@ -19,20 +20,52 @@
1920
def setup_dist():
2021
"""Initialize distributed environment."""
2122
if not dist.is_initialized():
22-
dist.init_process_group(backend="gloo")
23+
backend = "nccl" if torch.cuda.is_available() else "gloo"
24+
dist.init_process_group(backend=backend)
2325
rank = dist.get_rank()
2426
world_size = dist.get_world_size()
25-
# Use CPU for multi-GPU tests (gloo backend)
26-
device = torch.device("cpu")
27+
if torch.cuda.is_available():
28+
torch.cuda.set_device(rank % torch.cuda.device_count())
29+
device = torch.device("cuda", rank % torch.cuda.device_count())
30+
else:
31+
device = torch.device("cpu")
2732
return rank, world_size, device
2833

2934

3035
def cleanup_dist():
3136
"""Clean up distributed environment."""
3237
if dist.is_initialized():
38+
dist.barrier()
3339
dist.destroy_process_group()
3440

3541

42+
def make_cyclic_splits(rank, world_size):
43+
"""Return deterministic asymmetric splits valid for any world size."""
44+
send_splits = [((rank + 2 * peer) % 5) + 1 for peer in range(world_size)]
45+
recv_splits = [((peer + 2 * rank) % 5) + 1 for peer in range(world_size)]
46+
return send_splits, recv_splits
47+
48+
49+
def make_encoded_input(rank, send_splits, device):
50+
"""Build rows whose values encode source rank, target rank, and row id."""
51+
rows = []
52+
for peer, count in enumerate(send_splits):
53+
for row_id in range(count):
54+
rows.append([float(rank), float(peer), float(row_id)])
55+
return torch.tensor(rows, dtype=torch.float64, device=device)
56+
57+
58+
def make_expected_encoded_output(rank, world_size, device):
59+
"""Expected all-to-all output for make_encoded_input and make_cyclic_splits."""
60+
rows = []
61+
for source_rank in range(world_size):
62+
source_send_splits, _ = make_cyclic_splits(source_rank, world_size)
63+
count = source_send_splits[rank]
64+
for row_id in range(count):
65+
rows.append([float(source_rank), float(rank), float(row_id)])
66+
return torch.tensor(rows, dtype=torch.float64, device=device)
67+
68+
3669
class TestAllToAllMultiGPU(unittest.TestCase):
3770
"""Multi-GPU tests for _AllToAllDouble communication primitive."""
3871

@@ -44,44 +77,19 @@ def setUpClass(cls):
4477

4578
@classmethod
4679
def tearDownClass(cls):
47-
"""Clean up distributed environment."""
48-
cleanup_dist()
80+
"""Keep the process group alive until run_tests aggregates results."""
4981

50-
def test_forward_shape(self):
51-
"""Forward pass should produce correct output shape across ranks."""
52-
# Each rank sends different amounts
53-
# Constraint: rank i's send_splits[j] == rank j's recv_splits[i]
54-
if self.world_size == 2:
55-
send_splits = [3, 5] if self.rank == 0 else [2, 6]
56-
recv_splits = [3, 2] if self.rank == 0 else [5, 6]
57-
elif self.world_size == 4:
58-
# Matrix: send[i][j] = recv[j][i]
59-
# rank 0 sends: [2, 3, 1, 4] -> rank 0 recvs: [2, 5, 3, 7]
60-
# rank 1 sends: [5, 2, 4, 3] -> rank 1 recvs: [3, 2, 6, 4]
61-
# rank 2 sends: [3, 6, 1, 2] -> rank 2 recvs: [1, 4, 1, 5]
62-
# rank 3 sends: [7, 4, 5, 1] -> rank 3 recvs: [4, 3, 2, 1]
63-
if self.rank == 0:
64-
send_splits = [2, 3, 1, 4]
65-
recv_splits = [2, 5, 3, 7]
66-
elif self.rank == 1:
67-
send_splits = [5, 2, 4, 3]
68-
recv_splits = [3, 2, 6, 4]
69-
elif self.rank == 2:
70-
send_splits = [3, 6, 1, 2]
71-
recv_splits = [1, 4, 1, 5]
72-
else: # rank 3
73-
send_splits = [7, 4, 5, 1]
74-
recv_splits = [4, 3, 2, 1]
75-
else:
76-
self.skipTest(f"Test not configured for world_size={self.world_size}")
82+
def test_forward_values_and_shape(self):
83+
"""Forward pass should move the correct rows across ranks."""
84+
send_splits, recv_splits = make_cyclic_splits(self.rank, self.world_size)
7785

7886
total_send = sum(send_splits)
7987
total_recv = sum(recv_splits)
8088

81-
x = torch.randn(total_send, 8, device=self.device, requires_grad=True)
89+
x = make_encoded_input(self.rank, send_splits, self.device).requires_grad_(True)
8290
out = all_to_all_differentiable(x, send_splits, recv_splits, self.group)
91+
expected = make_expected_encoded_output(self.rank, self.world_size, self.device)
8392

84-
# Check output shape
8593
self.assertEqual(
8694
out.shape[0],
8795
total_recv,
@@ -92,20 +100,17 @@ def test_forward_shape(self):
92100
x.shape[1:],
93101
f"Rank {self.rank}: trailing dimensions should be preserved",
94102
)
103+
torch.testing.assert_close(out, expected)
95104

96105
def test_backward_no_deadlock(self):
97106
"""Backward pass should not deadlock."""
98-
if self.world_size == 2:
99-
send_splits = [4, 4]
100-
recv_splits = [4, 4]
101-
elif self.world_size == 4:
102-
send_splits = [2, 2, 2, 2]
103-
recv_splits = [2, 2, 2, 2]
104-
else:
105-
self.skipTest(f"Test not configured for world_size={self.world_size}")
107+
send_splits = [2] * self.world_size
108+
recv_splits = [2] * self.world_size
106109

107110
total_send = sum(send_splits)
108-
x = torch.randn(total_send, 8, device=self.device, requires_grad=True)
111+
x = torch.randn(
112+
total_send, 8, device=self.device, dtype=torch.float64, requires_grad=True
113+
)
109114

110115
out = all_to_all_differentiable(x, send_splits, recv_splits, self.group)
111116
loss = (out**2).sum()
@@ -120,17 +125,13 @@ def test_backward_no_deadlock(self):
120125

121126
def test_second_backward_no_deadlock(self):
122127
"""Second backward (create_graph=True) should not deadlock."""
123-
if self.world_size == 2:
124-
send_splits = [3, 3]
125-
recv_splits = [3, 3]
126-
elif self.world_size == 4:
127-
send_splits = [2, 2, 2, 2]
128-
recv_splits = [2, 2, 2, 2]
129-
else:
130-
self.skipTest(f"Test not configured for world_size={self.world_size}")
128+
send_splits = [2] * self.world_size
129+
recv_splits = [2] * self.world_size
131130

132131
total_send = sum(send_splits)
133-
x = torch.randn(total_send, 8, device=self.device, requires_grad=True)
132+
x = torch.randn(
133+
total_send, 8, device=self.device, dtype=torch.float64, requires_grad=True
134+
)
134135

135136
# First forward
136137
out = all_to_all_differentiable(x, send_splits, recv_splits, self.group)
@@ -157,36 +158,19 @@ def test_second_backward_no_deadlock(self):
157158

158159
def test_asymmetric_splits(self):
159160
"""Test with asymmetric send/recv splits across ranks."""
160-
# Constraint: rank i's send_splits[j] == rank j's recv_splits[i]
161-
if self.world_size == 2:
162-
# Rank 0 sends more to rank 1, rank 1 sends more to rank 0
163-
send_splits = [2, 6] if self.rank == 0 else [5, 3]
164-
recv_splits = [2, 5] if self.rank == 0 else [6, 3]
165-
elif self.world_size == 4:
166-
# Matrix: send[i][j] = recv[j][i]
167-
# rank 0 sends: [1, 2, 3, 4] -> rank 0 recvs: [1, 3, 2, 4]
168-
# rank 1 sends: [3, 2, 1, 4] -> rank 1 recvs: [2, 2, 3, 3]
169-
# rank 2 sends: [2, 3, 4, 1] -> rank 2 recvs: [3, 1, 4, 2]
170-
# rank 3 sends: [4, 3, 2, 1] -> rank 3 recvs: [4, 4, 1, 1]
171-
if self.rank == 0:
172-
send_splits = [1, 2, 3, 4]
173-
recv_splits = [1, 3, 2, 4]
174-
elif self.rank == 1:
175-
send_splits = [3, 2, 1, 4]
176-
recv_splits = [2, 2, 3, 3]
177-
elif self.rank == 2:
178-
send_splits = [2, 3, 4, 1]
179-
recv_splits = [3, 1, 4, 2]
180-
else: # rank 3
181-
send_splits = [4, 3, 2, 1]
182-
recv_splits = [4, 4, 1, 1]
183-
else:
184-
self.skipTest(f"Test not configured for world_size={self.world_size}")
161+
send_splits, recv_splits = make_cyclic_splits(self.rank, self.world_size)
162+
self.assertNotEqual(
163+
send_splits,
164+
recv_splits,
165+
f"Rank {self.rank}: split pattern should be asymmetric",
166+
)
185167

186168
total_send = sum(send_splits)
187169
total_recv = sum(recv_splits)
188170

189-
x = torch.randn(total_send, 16, device=self.device, requires_grad=True)
171+
x = torch.randn(
172+
total_send, 16, device=self.device, dtype=torch.float64, requires_grad=True
173+
)
190174
out = all_to_all_differentiable(x, send_splits, recv_splits, self.group)
191175

192176
# Check shape
@@ -198,12 +182,73 @@ def test_asymmetric_splits(self):
198182
loss.backward()
199183
self.assertIsNotNone(x.grad)
200184

185+
def test_three_layer_second_backward_no_deadlock(self):
186+
"""Three chained A2A ops should support second backward."""
187+
send_splits = [1] * self.world_size
188+
recv_splits = [1] * self.world_size
189+
x = torch.randn(
190+
self.world_size,
191+
4,
192+
dtype=torch.float64,
193+
device=self.device,
194+
requires_grad=True,
195+
)
196+
197+
y = x
198+
for _ in range(3):
199+
y = all_to_all_differentiable(y, send_splits, recv_splits, self.group)
200+
201+
loss = (y**2).sum()
202+
(grad_x,) = torch.autograd.grad(loss, x, create_graph=True, retain_graph=True)
203+
(grad_x**2).sum().backward()
204+
self.assertIsNotNone(x.grad, f"Rank {self.rank}: second-order grad missing")
205+
self.assertTrue(
206+
(x.grad.abs() > 1e-6).any(),
207+
f"Rank {self.rank}: second-order grad should be non-zero",
208+
)
209+
210+
def test_gradgradcheck_fp64_world_group(self):
211+
"""Gradgradcheck should exercise _AllToAllDouble with WORLD group."""
212+
torch.manual_seed(20260518)
213+
if self.device.type == "cuda":
214+
torch.cuda.manual_seed_all(20260518)
215+
216+
send_splits = [1] * self.world_size
217+
recv_splits = [1] * self.world_size
218+
x = torch.randn(
219+
self.world_size,
220+
2,
221+
dtype=torch.float64,
222+
device=self.device,
223+
requires_grad=True,
224+
)
225+
226+
def func(inp):
227+
out = all_to_all_differentiable(
228+
inp, send_splits, recv_splits, group=self.group
229+
)
230+
# Pick the row sourced from this rank so per-rank gradgradcheck
231+
# perturbs only the input that can affect the local output.
232+
return out.narrow(0, self.rank, 1)
233+
234+
result = torch.autograd.gradgradcheck(
235+
func,
236+
(x,),
237+
eps=1e-6,
238+
atol=1e-4,
239+
raise_exception=False,
240+
)
241+
self.assertTrue(
242+
result,
243+
f"Rank {self.rank}: distributed gradgradcheck failed",
244+
)
245+
201246

202247
def run_tests():
203248
"""Run all tests and report results."""
204249
import sys
205250

206-
rank, world_size, _ = setup_dist()
251+
rank, world_size, device = setup_dist()
207252

208253
# Only rank 0 prints header
209254
if rank == 0:
@@ -217,18 +262,20 @@ def run_tests():
217262
result = runner.run(suite)
218263

219264
# Synchronize results across ranks (before cleanup)
220-
success = torch.tensor([1 if result.wasSuccessful() else 0], dtype=torch.int32)
265+
success = torch.tensor(
266+
[1 if result.wasSuccessful() else 0], dtype=torch.int32, device=device
267+
)
221268
if dist.is_initialized():
222269
dist.all_reduce(success, op=dist.ReduceOp.MIN)
223270

224271
if rank == 0:
225272
if success.item() == 1:
226273
sys.stdout.write(f"\n{'=' * 70}\n")
227-
sys.stdout.write(f"✓ All tests passed on all {world_size} ranks\n")
274+
sys.stdout.write(f"PASS: all tests passed on all {world_size} ranks\n")
228275
sys.stdout.write(f"{'=' * 70}\n\n")
229276
else:
230277
sys.stdout.write(f"\n{'=' * 70}\n")
231-
sys.stdout.write("✗ Tests failed on at least one rank\n")
278+
sys.stdout.write("FAIL: tests failed on at least one rank\n")
232279
sys.stdout.write(f"{'=' * 70}\n\n")
233280

234281
cleanup_dist()

0 commit comments

Comments
 (0)