Skip to content

Commit c4ed1f9

Browse files
committed
feat(pt/dpmodel): add sequential_update for dpa3
1 parent a7e9fed commit c4ed1f9

8 files changed

Lines changed: 917 additions & 0 deletions

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,12 @@ class RepFlowArgs:
170170
In the dynamic selection case, neighbor-scale normalization will use `e_sel / sel_reduce_factor`
171171
or `a_sel / sel_reduce_factor` instead of the raw `e_sel` or `a_sel` values,
172172
accommodating larger selection numbers.
173+
sequential_update : bool, optional
174+
Whether to use sequential update mode within each repflow layer.
175+
When True, updates are applied sequentially: edge self → angle self (using updated edge)
176+
→ edge angle (using updated angle) → node (using final edge),
177+
instead of the default parallel mode where all updates use original embeddings.
178+
Currently only supports ``update_style='res_residual'``.
173179
"""
174180

175181
def __init__(
@@ -201,6 +207,7 @@ def __init__(
201207
use_exp_switch: bool = False,
202208
use_dynamic_sel: bool = False,
203209
sel_reduce_factor: float = 10.0,
210+
sequential_update: bool = False,
204211
) -> None:
205212
self.n_dim = n_dim
206213
self.e_dim = e_dim
@@ -231,6 +238,15 @@ def __init__(
231238
self.use_exp_switch = use_exp_switch
232239
self.use_dynamic_sel = use_dynamic_sel
233240
self.sel_reduce_factor = sel_reduce_factor
241+
self.sequential_update = sequential_update
242+
if self.sequential_update:
243+
if self.update_style != "res_residual":
244+
raise ValueError(
245+
"sequential_update only supports update_style='res_residual', "
246+
f"got '{self.update_style}'!"
247+
)
248+
if not self.update_angle:
249+
raise ValueError("sequential_update requires update_angle=True!")
234250

235251
def __getitem__(self, key: str) -> Any:
236252
if hasattr(self, key):
@@ -266,6 +282,7 @@ def serialize(self) -> dict:
266282
"use_exp_switch": self.use_exp_switch,
267283
"use_dynamic_sel": self.use_dynamic_sel,
268284
"sel_reduce_factor": self.sel_reduce_factor,
285+
"sequential_update": self.sequential_update,
269286
}
270287

271288
@classmethod
@@ -404,6 +421,7 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any:
404421
use_exp_switch=self.repflow_args.use_exp_switch,
405422
use_dynamic_sel=self.repflow_args.use_dynamic_sel,
406423
sel_reduce_factor=self.repflow_args.sel_reduce_factor,
424+
sequential_update=self.repflow_args.sequential_update,
407425
use_loc_mapping=use_loc_mapping,
408426
exclude_types=exclude_types,
409427
env_protection=env_protection,

0 commit comments

Comments
 (0)