Skip to content

Commit 8a350d9

Browse files
committed
fix uts
1 parent 5dd2d04 commit 8a350d9

6 files changed

Lines changed: 22 additions & 4 deletions

File tree

deepmd/dpmodel/descriptor/repflows.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1294,7 +1294,7 @@ def call(
12941294
)
12951295
nb, nloc, nnei = nlist.shape
12961296
nall = node_ebd_ext.shape[1]
1297-
n_edge = int(nlist_mask.sum().item())
1297+
n_edge = int(xp.sum(nlist_mask.astype(xp.int32)))
12981298
node_ebd = node_ebd_ext[:, :nloc, :]
12991299
assert (nb, nloc) == node_ebd.shape[:2]
13001300
if not self.use_dynamic_sel:

deepmd/dpmodel/utils/network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1044,7 +1044,7 @@ def get_graph_index(
10441044
a_nlist_mask[:, :, :, None], a_nlist_mask[:, :, None, :]
10451045
)
10461046

1047-
n_edge = int(xp.asarray(xp.sum(nlist_mask)))
1047+
n_edge = int(xp.sum(nlist_mask.astype(xp.int32)))
10481048

10491049
# following: get n2e_index, n_ext2e_index, n2a_index, eij2a_index, eik2a_index
10501050

deepmd/pd/model/descriptor/dpa3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,8 @@ def init_subclass_params(sub_data, sub_class):
150150
fix_stat_std=self.repflow_args.fix_stat_std,
151151
optim_update=self.repflow_args.optim_update,
152152
smooth_edge_update=self.repflow_args.smooth_edge_update,
153+
use_dynamic_sel=self.repflow_args.use_dynamic_sel,
154+
sel_reduce_factor=self.repflow_args.sel_reduce_factor,
153155
exclude_types=exclude_types,
154156
env_protection=env_protection,
155157
precision=precision,

deepmd/pd/model/descriptor/repflow_layer.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ def __init__(
5252
axis_neuron: int = 4,
5353
update_angle: bool = True,
5454
optim_update: bool = True,
55+
use_dynamic_sel: bool = False,
56+
sel_reduce_factor: float = 10.0,
5557
smooth_edge_update: bool = False,
5658
activation_function: str = "silu",
5759
update_style: str = "res_residual",
@@ -98,6 +100,10 @@ def __init__(
98100
self.prec = PRECISION_DICT[precision]
99101
self.optim_update = optim_update
100102
self.smooth_edge_update = smooth_edge_update
103+
self.use_dynamic_sel = use_dynamic_sel
104+
self.sel_reduce_factor = sel_reduce_factor
105+
self.dynamic_e_sel = self.nnei / self.sel_reduce_factor
106+
self.dynamic_a_sel = self.a_sel / self.sel_reduce_factor
101107

102108
assert update_residual_init in [
103109
"norm",
@@ -812,7 +818,7 @@ def serialize(self) -> dict:
812818
"""
813819
data = {
814820
"@class": "RepFlowLayer",
815-
"@version": 1,
821+
"@version": 2,
816822
"e_rcut": self.e_rcut,
817823
"e_rcut_smth": self.e_rcut_smth,
818824
"e_sel": self.e_sel,
@@ -836,6 +842,8 @@ def serialize(self) -> dict:
836842
"precision": self.precision,
837843
"optim_update": self.optim_update,
838844
"smooth_edge_update": self.smooth_edge_update,
845+
"use_dynamic_sel": self.use_dynamic_sel,
846+
"sel_reduce_factor": self.sel_reduce_factor,
839847
"node_self_mlp": self.node_self_mlp.serialize(),
840848
"node_sym_linear": self.node_sym_linear.serialize(),
841849
"node_edge_linear": self.node_edge_linear.serialize(),
@@ -878,7 +886,7 @@ def deserialize(cls, data: dict) -> "RepFlowLayer":
878886
The dict to deserialize from.
879887
"""
880888
data = data.copy()
881-
check_version_compatibility(data.pop("@version"), 1, 1)
889+
check_version_compatibility(data.pop("@version"), 2, 1)
882890
data.pop("@class")
883891
update_angle = data["update_angle"]
884892
a_compress_rate = data["a_compress_rate"]

deepmd/pd/model/descriptor/repflows.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,8 @@ def __init__(
159159
precision: str = "float64",
160160
fix_stat_std: float = 0.3,
161161
smooth_edge_update: bool = False,
162+
use_dynamic_sel: bool = False,
163+
sel_reduce_factor: float = 10.0,
162164
optim_update: bool = True,
163165
seed: Optional[Union[int, list[int]]] = None,
164166
) -> None:
@@ -191,6 +193,9 @@ def __init__(
191193
self.a_compress_use_split = a_compress_use_split
192194
self.optim_update = optim_update
193195
self.smooth_edge_update = smooth_edge_update
196+
self.use_dynamic_sel = use_dynamic_sel # not supported yet
197+
self.sel_reduce_factor = sel_reduce_factor
198+
assert not self.use_dynamic_sel, "Dynamic selection is not supported yet."
194199

195200
self.n_dim = n_dim
196201
self.e_dim = e_dim
@@ -243,6 +248,8 @@ def __init__(
243248
update_residual_init=self.update_residual_init,
244249
precision=precision,
245250
optim_update=self.optim_update,
251+
use_dynamic_sel=self.use_dynamic_sel,
252+
sel_reduce_factor=self.sel_reduce_factor,
246253
smooth_edge_update=self.smooth_edge_update,
247254
seed=child_seed(child_seed(seed, 1), ii),
248255
)

source/tests/universal/dpmodel/descriptor/test_descriptor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,7 @@ def DescriptorParamDPA3(
509509
"n_multi_edge_message": n_multi_edge_message,
510510
"axis_neuron": 2,
511511
"use_dynamic_sel": use_dynamic_sel,
512+
"sel_reduce_factor": 1.0,
512513
"update_angle": update_angle,
513514
"update_style": update_style,
514515
"update_residual": update_residual,

0 commit comments

Comments
 (0)