Skip to content

Commit 056d784

Browse files
committed
step9: add args to dpa3
1 parent fca1939 commit 056d784

4 files changed

Lines changed: 508 additions & 0 deletions

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,14 @@ class RepFlowArgs:
170170
In the dynamic selection case, neighbor-scale normalization will use `e_sel / sel_reduce_factor`
171171
or `a_sel / sel_reduce_factor` instead of the raw `e_sel` or `a_sel` values,
172172
accommodating larger selection numbers.
173+
use_moe : bool, optional
174+
Whether to use Mixture-of-Experts for the MLP layers in each RepFlowLayer.
175+
n_routing_experts : int, optional
176+
Total number of routing experts across all GPUs.
177+
moe_topk : int, optional
178+
Number of experts selected per token.
179+
n_shared_experts : int, optional
180+
Number of shared experts (replicated on every GPU).
173181
"""
174182

175183
def __init__(
@@ -201,6 +209,10 @@ def __init__(
201209
use_exp_switch: bool = False,
202210
use_dynamic_sel: bool = False,
203211
sel_reduce_factor: float = 10.0,
212+
use_moe: bool = False,
213+
n_routing_experts: int = 0,
214+
moe_topk: int = 0,
215+
n_shared_experts: int = 0,
204216
) -> None:
205217
self.n_dim = n_dim
206218
self.e_dim = e_dim
@@ -231,6 +243,10 @@ def __init__(
231243
self.use_exp_switch = use_exp_switch
232244
self.use_dynamic_sel = use_dynamic_sel
233245
self.sel_reduce_factor = sel_reduce_factor
246+
self.use_moe = use_moe
247+
self.n_routing_experts = n_routing_experts
248+
self.moe_topk = moe_topk
249+
self.n_shared_experts = n_shared_experts
234250

235251
def __getitem__(self, key: str) -> Any:
236252
if hasattr(self, key):
@@ -266,6 +282,10 @@ def serialize(self) -> dict:
266282
"use_exp_switch": self.use_exp_switch,
267283
"use_dynamic_sel": self.use_dynamic_sel,
268284
"sel_reduce_factor": self.sel_reduce_factor,
285+
"use_moe": self.use_moe,
286+
"n_routing_experts": self.n_routing_experts,
287+
"moe_topk": self.moe_topk,
288+
"n_shared_experts": self.n_shared_experts,
269289
}
270290

271291
@classmethod

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,10 @@ def __init__(
122122
use_loc_mapping: bool = True,
123123
type_map: list[str] | None = None,
124124
add_chg_spin_ebd: bool = False,
125+
# MoE EP params (not part of RepFlowArgs, set at runtime).
126+
ep_group=None,
127+
ep_rank: int = 0,
128+
ep_size: int = 1,
125129
) -> None:
126130
super().__init__()
127131

@@ -173,6 +177,13 @@ def init_subclass_params(sub_data: Any, sub_class: Any) -> Any:
173177
precision=precision,
174178
seed=child_seed(seed, 1),
175179
trainable=trainable,
180+
use_moe=self.repflow_args.use_moe,
181+
n_routing_experts=self.repflow_args.n_routing_experts,
182+
moe_topk=self.repflow_args.moe_topk,
183+
n_shared_experts=self.repflow_args.n_shared_experts,
184+
ep_group=ep_group,
185+
ep_rank=ep_rank,
186+
ep_size=ep_size,
176187
)
177188

178189
self.use_econf_tebd = use_econf_tebd

deepmd/pt/model/descriptor/repflows.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,13 @@ def __init__(
226226
optim_update: bool = True,
227227
seed: int | list[int] | None = None,
228228
trainable: bool = True,
229+
use_moe: bool = False,
230+
n_routing_experts: int = 0,
231+
moe_topk: int = 0,
232+
n_shared_experts: int = 0,
233+
ep_group=None,
234+
ep_rank: int = 0,
235+
ep_size: int = 1,
229236
) -> None:
230237
super().__init__()
231238
self.e_rcut = float(e_rcut)
@@ -256,6 +263,7 @@ def __init__(
256263
self.a_compress_use_split = a_compress_use_split
257264
self.use_loc_mapping = use_loc_mapping
258265
self.optim_update = optim_update
266+
self.use_moe = use_moe
259267
self.smooth_edge_update = smooth_edge_update
260268
self.edge_init_use_dist = edge_init_use_dist
261269
self.use_exp_switch = use_exp_switch
@@ -335,6 +343,13 @@ def __init__(
335343
smooth_edge_update=self.smooth_edge_update,
336344
seed=child_seed(child_seed(seed, 1), ii),
337345
trainable=trainable,
346+
use_moe=use_moe,
347+
n_routing_experts=n_routing_experts,
348+
moe_topk=moe_topk,
349+
n_shared_experts=n_shared_experts,
350+
ep_group=ep_group,
351+
ep_rank=ep_rank,
352+
ep_size=ep_size,
338353
)
339354
)
340355
self.layers = torch.nn.ModuleList(layers)
@@ -656,6 +671,7 @@ def forward(
656671
a_sw,
657672
edge_index=edge_index,
658673
angle_index=angle_index,
674+
type_embedding=atype_embd if self.use_moe else None,
659675
)
660676

661677
# nb x nloc x 3 x e_dim

0 commit comments

Comments (0)