Skip to content

Commit e9c71d0

Browse files
committed
Reapply "step1: all to all"
This reverts commit 64bfcc6.
1 parent 64bfcc6 commit e9c71d0

5 files changed

Lines changed: 515 additions & 0 deletions

File tree

CLAUDE.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ ______________________________________________________________________
4444
1. **复用 DPA3 代码的方式****copy 一份过来改**,不要 `from deepmd-kit-moe.xxx import ...`。需要参考时去 `/mnt/data_nas/zhangd/claude_space/deepmd-kit-moe` 找。详见子 agent `dpa3-ref-searcher`
4545
1. **每个 Step 必须配套 UT**:不写 UT 不能进下一个 Step。UT 通过 → 才能集成。详见 `SPEC.md` §6 的测试矩阵。
4646
1. **多卡 UT 用 torchrun 跑**:模板见 skill `multi-gpu-test-template`
47+
1. **代码风格检查**:每个 Step 完成后必须运行 `ruff check` 并修复所有问题,然后重新验证测试通过。Ruff 路径:`/root/miniconda3/bin/ruff`
4748

4849
______________________________________________________________________
4950

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
"""SeZM MoE (Mixture-of-Experts) modules for Expert Parallelism + Data Parallelism.
3+
4+
This package implements MoE components for the SeZM descriptor:
5+
- Communication primitives (A2A with second-order derivatives)
6+
- Router (top-k gating)
7+
- Expert collections (routing + shared experts)
8+
- MoE convolution layer (replaces SO2 linear stack)
9+
"""
10+
11+
from __future__ import (
12+
annotations,
13+
)
14+
15+
__all__ = []
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
"""Differentiable All-to-All communication operators for SeZM MoE Expert Parallelism.
3+
4+
Provides `_AllToAllDouble`, a recursive autograd Function whose backward
5+
calls `.apply()` again, creating a fresh autograd node so that
6+
`create_graph=True` (required for force -> virial second derivatives)
7+
works correctly to arbitrary order.
8+
9+
Public API
10+
----------
11+
all_to_all_differentiable(x, send_splits, recv_splits, group)
12+
When *group* is ``None`` (single-GPU / no EP), returns *x* unchanged.
13+
Otherwise dispatches through ``_AllToAllDouble``.
14+
"""
15+
16+
from __future__ import (
17+
annotations,
18+
)
19+
20+
from typing import (
21+
Any,
22+
)
23+
24+
import torch
25+
import torch.distributed as dist
26+
from torch.autograd import (
27+
Function,
28+
)
29+
30+
31+
def _a2a_raw(
32+
x: torch.Tensor,
33+
send_splits: list[int],
34+
recv_splits: list[int],
35+
group: dist.ProcessGroup,
36+
) -> torch.Tensor:
37+
"""Raw All-to-All without autograd.
38+
39+
Parameters
40+
----------
41+
x : Tensor
42+
Input tensor whose first dimension equals ``sum(send_splits)``.
43+
send_splits : list[int]
44+
Number of rows to send to each rank.
45+
recv_splits : list[int]
46+
Number of rows to receive from each rank.
47+
group : ProcessGroup
48+
The communication group.
49+
50+
Returns
51+
-------
52+
Tensor
53+
Output tensor with first dimension ``sum(recv_splits)``.
54+
"""
55+
total_recv = sum(recv_splits)
56+
out = torch.empty((total_recv, *x.shape[1:]), dtype=x.dtype, device=x.device)
57+
dist.all_to_all_single(
58+
out,
59+
x.contiguous(),
60+
output_split_sizes=recv_splits,
61+
input_split_sizes=send_splits,
62+
group=group,
63+
)
64+
return out
65+
66+
67+
class _AllToAllDouble(Function):
68+
"""Recursively differentiable All-to-All.
69+
70+
The backward pass calls ``.apply()`` with swapped send/recv splits,
71+
which creates a *new* autograd node. This means the graph built by
72+
``create_graph=True`` (1st backward) can itself be differentiated
73+
(2nd backward), giving correct second-order derivatives through
74+
the communication boundary.
75+
76+
The layer-sequential structure of SeZM guarantees that all ranks
77+
execute A2A calls in the same order, so deadlocks cannot occur.
78+
"""
79+
80+
@staticmethod
81+
def forward(
82+
ctx: Any,
83+
x: torch.Tensor,
84+
send_splits: list[int],
85+
recv_splits: list[int],
86+
group: dist.ProcessGroup,
87+
) -> torch.Tensor:
88+
ctx.group = group
89+
ctx.send_splits = send_splits
90+
ctx.recv_splits = recv_splits
91+
return _a2a_raw(x, send_splits, recv_splits, group)
92+
93+
@staticmethod
94+
def backward(
95+
ctx: Any, grad_output: torch.Tensor
96+
) -> tuple[torch.Tensor, None, None, None]:
97+
# Recursive call: backward of this node is itself an A2A with
98+
# swapped splits. Because we call .apply(), a new autograd node
99+
# is inserted into the graph, enabling higher-order derivatives.
100+
grad_input = _AllToAllDouble.apply(
101+
grad_output,
102+
ctx.recv_splits,
103+
ctx.send_splits,
104+
ctx.group,
105+
)
106+
return grad_input, None, None, None
107+
108+
109+
def all_to_all_differentiable(
110+
x: torch.Tensor,
111+
send_splits: list[int],
112+
recv_splits: list[int],
113+
group: dist.ProcessGroup | None,
114+
) -> torch.Tensor:
115+
"""Public API for differentiable All-to-All.
116+
117+
Parameters
118+
----------
119+
x : Tensor
120+
Input tensor.
121+
send_splits : list[int]
122+
Number of rows to send to each rank.
123+
recv_splits : list[int]
124+
Number of rows to receive from each rank.
125+
group : ProcessGroup or None
126+
Communication group. When ``None`` (single-GPU / no EP),
127+
*x* is returned unchanged with gradients flowing through.
128+
129+
Returns
130+
-------
131+
Tensor
132+
Result of All-to-All, or *x* itself when ``group is None``.
133+
"""
134+
if group is None:
135+
return x
136+
return _AllToAllDouble.apply(x, send_splits, recv_splits, group)
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
"""Unit tests for SeZM MoE All-to-All communication primitive (single-GPU)."""
3+
4+
import unittest
5+
6+
import torch
7+
8+
from deepmd.pt.model.descriptor.sezm_nn.moe.a2a_ops import (
9+
all_to_all_differentiable,
10+
)
11+
12+
13+
class TestAllToAllSingleGPU(unittest.TestCase):
14+
"""Single-GPU tests for _AllToAllDouble communication primitive."""
15+
16+
def test_single_gpu_passthrough(self):
17+
"""group=None should return x unchanged with gradients flowing through."""
18+
x = torch.randn(10, 8, requires_grad=True, device="cpu")
19+
send_splits = [3, 3, 4]
20+
recv_splits = [2, 5, 3]
21+
22+
out = all_to_all_differentiable(x, send_splits, recv_splits, group=None)
23+
24+
# Output should be identical to input
25+
self.assertIs(out, x, "group=None should return input tensor unchanged")
26+
27+
# Gradient should flow through
28+
loss = out.sum()
29+
loss.backward()
30+
self.assertIsNotNone(x.grad, "Gradient should flow through when group=None")
31+
self.assertTrue(
32+
torch.allclose(x.grad, torch.ones_like(x)),
33+
"Gradient should be all ones for sum() loss",
34+
)
35+
36+
def test_shape_preservation(self):
37+
"""Forward pass should preserve trailing dimensions."""
38+
# Test various shapes
39+
test_cases = [
40+
((10, 8), [3, 3, 4], [2, 5, 3]),
41+
((15, 16, 32), [5, 5, 5], [4, 6, 5]),
42+
((8, 4, 4, 64), [2, 3, 3], [3, 2, 3]),
43+
]
44+
45+
for shape, send_splits, recv_splits in test_cases:
46+
with self.subTest(shape=shape):
47+
x = torch.randn(*shape, device="cpu")
48+
out = all_to_all_differentiable(x, send_splits, recv_splits, group=None)
49+
50+
# First dimension should match sum(recv_splits)
51+
expected_shape = (sum(recv_splits), *shape[1:])
52+
self.assertEqual(
53+
out.shape,
54+
expected_shape,
55+
f"Output shape mismatch for input shape {shape}",
56+
)
57+
58+
def test_first_backward(self):
59+
"""loss.backward() should produce non-zero gradients."""
60+
x = torch.randn(10, 8, requires_grad=True, device="cpu")
61+
send_splits = [3, 3, 4]
62+
recv_splits = [2, 5, 3]
63+
64+
out = all_to_all_differentiable(x, send_splits, recv_splits, group=None)
65+
loss = (out**2).sum()
66+
loss.backward()
67+
68+
self.assertIsNotNone(x.grad, "Gradient should exist after backward")
69+
self.assertTrue(
70+
(x.grad.abs() > 1e-6).any(), "Gradient should contain non-zero values"
71+
)
72+
73+
def test_second_backward(self):
74+
"""create_graph=True + second backward should produce non-zero gradients."""
75+
x = torch.randn(10, 8, requires_grad=True, device="cpu")
76+
send_splits = [3, 3, 4]
77+
recv_splits = [2, 5, 3]
78+
79+
# First forward
80+
out = all_to_all_differentiable(x, send_splits, recv_splits, group=None)
81+
loss = (out**2).sum()
82+
83+
# First backward with create_graph=True
84+
(grad_x,) = torch.autograd.grad(loss, x, create_graph=True, retain_graph=True)
85+
86+
self.assertIsNotNone(grad_x, "First-order gradient should exist")
87+
self.assertTrue(
88+
grad_x.requires_grad, "First-order gradient should require grad"
89+
)
90+
91+
# Second backward
92+
loss2 = (grad_x**2).sum()
93+
loss2.backward()
94+
95+
self.assertIsNotNone(x.grad, "Second-order gradient should exist")
96+
self.assertTrue(
97+
(x.grad.abs() > 1e-6).any(),
98+
"Second-order gradient should contain non-zero values",
99+
)
100+
101+
def test_gradgradcheck_fp64(self):
102+
"""torch.autograd.gradgradcheck should pass in fp64."""
103+
# Use smaller tensors for gradgradcheck (it's expensive)
104+
x = torch.randn(6, 4, dtype=torch.float64, requires_grad=True, device="cpu")
105+
send_splits = [2, 2, 2]
106+
recv_splits = [1, 3, 2]
107+
108+
def func(inp):
109+
return all_to_all_differentiable(inp, send_splits, recv_splits, group=None)
110+
111+
# gradgradcheck verifies second-order derivatives
112+
result = torch.autograd.gradgradcheck(
113+
func, x, eps=1e-6, atol=1e-4, rtol=1e-3, raise_exception=False
114+
)
115+
116+
self.assertTrue(
117+
result, "gradgradcheck failed: second-order derivatives are incorrect"
118+
)
119+
120+
121+
if __name__ == "__main__":
122+
unittest.main()

0 commit comments

Comments
 (0)