Skip to content

Commit ff588a8

Browse files
committed
feat: add dynamic sel
1 parent d6a66e3 commit ff588a8

7 files changed

Lines changed: 904 additions & 106 deletions

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ def __init__(
2525
update_residual_init: str = "const",
2626
skip_stat: bool = False,
2727
optim_update: bool = True,
28+
smooth_edge_update: bool = False,
29+
use_dynamic_sel: bool = False,
30+
sel_reduce_factor: float = 10.0,
2831
) -> None:
2932
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
3033
@@ -102,6 +105,9 @@ def __init__(
102105
self.a_compress_e_rate = a_compress_e_rate
103106
self.a_compress_use_split = a_compress_use_split
104107
self.optim_update = optim_update
108+
self.smooth_edge_update = smooth_edge_update
109+
self.use_dynamic_sel = use_dynamic_sel
110+
self.sel_reduce_factor = sel_reduce_factor
105111

106112
def __getitem__(self, key):
107113
if hasattr(self, key):

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,9 @@ def init_subclass_params(sub_data, sub_class):
164164
update_residual_init=self.repflow_args.update_residual_init,
165165
optim_update=self.repflow_args.optim_update,
166166
skip_stat=self.repflow_args.skip_stat,
167+
smooth_edge_update=self.repflow_args.smooth_edge_update,
168+
use_dynamic_sel=self.repflow_args.use_dynamic_sel,
169+
sel_reduce_factor=self.repflow_args.sel_reduce_factor,
167170
exclude_types=exclude_types,
168171
env_protection=env_protection,
169172
precision=precision,
@@ -192,8 +195,8 @@ def init_subclass_params(sub_data, sub_class):
192195
self.env_protection = env_protection
193196
self.trainable = trainable
194197

195-
assert self.repflows.e_rcut > self.repflows.a_rcut
196-
assert self.repflows.e_sel > self.repflows.a_sel
198+
assert self.repflows.e_rcut >= self.repflows.a_rcut
199+
assert self.repflows.e_sel >= self.repflows.a_sel
197200

198201
self.rcut = self.repflows.get_rcut()
199202
self.rcut_smth = self.repflows.get_rcut_smth()

0 commit comments

Comments
 (0)