
Commit fca1939

step8: ep+dp gradient
1 parent f47818c commit fca1939

2 files changed

Lines changed: 737 additions & 0 deletions

File tree

deepmd/pt/utils/moe_ep_dp.py

Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""MoE Expert-Parallelism + Data-Parallelism process group management.

Provides:
- ``init_ep_dp_groups``: create EP and DP process groups from a flat world.
- ``sync_moe_gradients``: all-reduce gradients with correct group/divisor.
- ``_is_routing_expert_param``: classify parameter names.
"""

from __future__ import annotations

import torch
import torch.distributed as dist


def init_ep_dp_groups(
    ep_size: int = 1,
) -> tuple[object | None, object | None, int, int, int, int]:
    """Initialize EP and DP process groups from a flat world.

    The world of ``world_size`` GPUs is viewed as a 2-D grid::

        world_size = ep_size × dp_size

    GPU layout (ep_size=2, dp_size=2, world_size=4)::

                    EP rank 0      EP rank 1
        DP rank 0:  GPU 0          GPU 1        ← ep_group_0
        DP rank 1:  GPU 2          GPU 3        ← ep_group_1
                    ↑ dp_group_0   ↑ dp_group_1

    Parameters
    ----------
    ep_size : int
        Number of GPUs per expert-parallel group. When ``ep_size <= 1``
        or distributed is not initialised, no groups are created.

    Returns
    -------
    ep_group : ProcessGroup or None
        The EP group this rank belongs to (for All-to-All).
    dp_group : ProcessGroup or None
        The DP group this rank belongs to (for routing-expert gradient sync).
    ep_rank : int
        This rank's position inside its EP group.
    ep_size : int
        Size of the EP group (echoed back, or 1 if disabled).
    dp_rank : int
        This rank's position inside its DP group.
    dp_size : int
        Size of the DP group.
    """
    if ep_size <= 1 or not dist.is_initialized():
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        rank = dist.get_rank() if dist.is_initialized() else 0
        return (None, None, 0, 1, rank, world_size)

    world_size = dist.get_world_size()
    world_rank = dist.get_rank()

    if world_size % ep_size != 0:
        raise ValueError(
            f"world_size ({world_size}) must be divisible by ep_size ({ep_size})"
        )

    dp_size = world_size // ep_size

    # Build EP groups: each row of the GPU grid.
    # ALL ranks must call new_group for every group (NCCL requirement).
    my_ep_group = None
    for dp_idx in range(dp_size):
        ranks = [dp_idx * ep_size + i for i in range(ep_size)]
        group = dist.new_group(ranks)
        if world_rank in ranks:
            my_ep_group = group

    # Build DP groups: each column of the GPU grid.
    my_dp_group = None
    for ep_idx in range(ep_size):
        ranks = [dp_idx * ep_size + ep_idx for dp_idx in range(dp_size)]
        group = dist.new_group(ranks)
        if world_rank in ranks:
            my_dp_group = group

    ep_rank = world_rank % ep_size
    dp_rank = world_rank // ep_size

    return (my_ep_group, my_dp_group, ep_rank, ep_size, dp_rank, dp_size)

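For context, a minimal sketch of how a training entry point might build these groups, assuming a ``torchrun`` launch with an NCCL backend; the ``setup_parallelism`` helper and its default ``ep_size`` are illustrative, not part of this commit::

    import torch
    import torch.distributed as dist

    from deepmd.pt.utils.moe_ep_dp import init_ep_dp_groups

    def setup_parallelism(ep_size: int = 2):
        # torchrun provides RANK / WORLD_SIZE / LOCAL_RANK in the environment.
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
        # Carve the flat world into a dp_size × ep_size grid; with 4 ranks and
        # ep_size=2 this yields EP groups {0,1}, {2,3} and DP groups {0,2}, {1,3},
        # matching the diagram in the docstring above.
        return init_ep_dp_groups(ep_size=ep_size)
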
def _is_routing_expert_param(name: str) -> bool:
    """Check whether a parameter belongs to a routing expert.

    Routing expert parameters contain ``.routing_experts.`` in their
    fully-qualified name. Examples::

        moe_phase1.node_self_experts.routing_experts.0.mlp.matrix → True
        moe_phase1.edge_experts.shared_experts.0.mlp.matrix       → False
        node_router.gate.matrix                                    → False
        n_residual.0                                               → False
    """
    return ".routing_experts." in name

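This name-based split matters because, as ``sync_moe_gradients`` below shows, routing-expert gradients are synchronised only within a DP group, while every other parameter is synchronised across the full world. A doctest-style check reusing the example names from the docstring::

    >>> _is_routing_expert_param("moe_phase1.node_self_experts.routing_experts.0.mlp.matrix")
    True
    >>> _is_routing_expert_param("moe_phase1.edge_experts.shared_experts.0.mlp.matrix")
    False
    >>> _is_routing_expert_param("node_router.gate.matrix")
    False
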
def sync_moe_gradients(
    model: torch.nn.Module,
    dp_group: object | None,
    world_group: object | None,
    dp_size: int,
    world_size: int,
) -> None:
    """All-reduce gradients with the correct group and divisor.

    Must be called **after** ``loss.backward()`` and **before**
    ``optimizer.step()``.

    Parameters
    ----------
    model : torch.nn.Module
        The model whose parameter gradients should be synchronised.
    dp_group : ProcessGroup or None
        DP group for routing-expert gradient all-reduce.
    world_group : ProcessGroup or None
        World group for all other parameters. ``None`` uses the
        default process group (all ranks).
    dp_size : int
        Number of ranks in the DP group.
    world_size : int
        Total number of ranks.
    """
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        if _is_routing_expert_param(name):
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=dp_group)
            param.grad.div_(dp_size)
        else:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=world_group)
            param.grad.div_(world_size)
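A sketch of where this fits in a training step, assuming DDP's automatic gradient averaging is not used so that this function is the only gradient synchronisation; ``model``, ``optimizer``, ``loss_fn`` and ``dataloader`` are placeholders, not names from this commit. With ``world_size=4`` and ``ep_size=2``, a routing-expert gradient is summed over the 2 ranks of its DP group and divided by ``dp_size=2``, while every other gradient is summed over all 4 ranks and divided by ``world_size=4``::

    ep_group, dp_group, ep_rank, ep_size, dp_rank, dp_size = init_ep_dp_groups(ep_size=2)
    world_size = dist.get_world_size()

    for batch in dataloader:
        optimizer.zero_grad(set_to_none=True)
        loss = loss_fn(model(batch))      # forward; the model may use ep_group internally
        loss.backward()                   # local gradients only, no DDP hooks
        sync_moe_gradients(
            model,
            dp_group=dp_group,
            world_group=None,             # None -> default group, i.e. all ranks
            dp_size=dp_size,
            world_size=world_size,
        )
        optimizer.step()                  # replicated parameters stay identical across ranks
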
