Skip to content

Commit 53cd84c

Browse files
committed
rename args and add sel_reduce_factor
1 parent fd307c9 commit 53cd84c

6 files changed

Lines changed: 63 additions & 46 deletions

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ def __init__(
3232
angle_init_use_sin: bool = False,
3333
smooth_edge_update: bool = False,
3434
angle_multi_freq: Optional[str] = None,
35-
no_sel: bool = False,
35+
use_dynamic_sel: bool = False,
36+
sel_reduce_factor: float = 10.0,
3637
) -> None:
3738
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
3839
@@ -114,7 +115,8 @@ def __init__(
114115
self.angle_init_use_sin = angle_init_use_sin
115116
self.smooth_edge_update = smooth_edge_update
116117
self.angle_multi_freq = angle_multi_freq
117-
self.no_sel = no_sel
118+
self.use_dynamic_sel = use_dynamic_sel
119+
self.sel_reduce_factor = sel_reduce_factor
118120

119121
def __getitem__(self, key):
120122
if hasattr(self, key):

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,8 @@ def init_subclass_params(sub_data, sub_class):
167167
angle_init_use_sin=self.repflow_args.angle_init_use_sin,
168168
smooth_edge_update=self.repflow_args.smooth_edge_update,
169169
angle_multi_freq=self.repflow_args.angle_multi_freq,
170-
no_sel=self.repflow_args.no_sel,
170+
use_dynamic_sel=self.repflow_args.use_dynamic_sel,
171+
sel_reduce_factor=self.repflow_args.sel_reduce_factor,
171172
exclude_types=exclude_types,
172173
env_protection=env_protection,
173174
precision=precision,

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 38 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ def __init__(
5555
axis_neuron: int = 4,
5656
update_angle: bool = True, # angle
5757
optim_update: bool = True,
58-
no_sel: bool = False,
58+
use_dynamic_sel: bool = False,
59+
sel_reduce_factor: float = 10.0,
5960
smooth_edge_update: bool = False,
6061
activation_function: str = "silu",
6162
update_style: str = "res_residual",
@@ -102,7 +103,10 @@ def __init__(
102103
self.prec = PRECISION_DICT[precision]
103104
self.optim_update = optim_update
104105
self.smooth_edge_update = smooth_edge_update
105-
self.no_sel = no_sel
106+
self.use_dynamic_sel = use_dynamic_sel
107+
self.sel_reduce_factor = sel_reduce_factor
108+
self.dynamic_e_sel = self.nnei / self.sel_reduce_factor
109+
self.dynamic_a_sel = self.a_sel / self.sel_reduce_factor
106110

107111
assert update_residual_init in [
108112
"norm",
@@ -324,7 +328,7 @@ def _cal_hg(
324328
return h2g2
325329

326330
@staticmethod
327-
def _cal_hg_nosel(
331+
def _cal_hg_dynamic(
328332
flat_edge_ebd: torch.Tensor,
329333
flat_h2: torch.Tensor,
330334
flat_sw: torch.Tensor,
@@ -447,7 +451,7 @@ def symmetrization_op(
447451
g1_13 = self._cal_grrg(h2g2, axis_neuron)
448452
return g1_13
449453

450-
def symmetrization_op_nosel(
454+
def symmetrization_op_dynamic(
451455
self,
452456
flat_edge_ebd: torch.Tensor,
453457
flat_h2: torch.Tensor,
@@ -487,7 +491,7 @@ def symmetrization_op_nosel(
487491
Atomic invariant rep, with shape nb x nloc x (axis_neuron x e_dim)
488492
"""
489493
# nb x nloc x 3 x e_dim
490-
h2g2 = self._cal_hg_nosel(
494+
h2g2 = self._cal_hg_dynamic(
491495
flat_edge_ebd,
492496
flat_h2,
493497
flat_sw,
@@ -552,7 +556,7 @@ def optim_angle_update(
552556
) + bias
553557
return result_update
554558

555-
def optim_angle_update_nosel(
559+
def optim_angle_update_dynamic(
556560
self,
557561
flat_angle_ebd: torch.Tensor,
558562
node_ebd: torch.Tensor,
@@ -655,7 +659,7 @@ def optim_edge_update(
655659
) + bias
656660
return result_update
657661

658-
def optim_edge_update_nosel(
662+
def optim_edge_update_dynamic(
659663
self,
660664
node_ebd: torch.Tensor,
661665
node_ebd_ext: torch.Tensor,
@@ -770,7 +774,7 @@ def forward(
770774
node_ebd, _ = torch.split(node_ebd_ext, [nloc, nall - nloc], dim=1)
771775
n_edge = nlist_mask.sum().item()
772776
assert (nb, nloc) == node_ebd.shape[:2]
773-
if not self.no_sel:
777+
if not self.use_dynamic_sel:
774778
assert (nb, nloc, nnei, 3) == h2.shape
775779
else:
776780
assert (n_edge, 3) == h2.shape
@@ -786,7 +790,7 @@ def forward(
786790
# nb x nloc x nnei x n_dim [OR] n_edge x n_dim
787791
nei_node_ebd = (
788792
_make_nei_g1(node_ebd_ext, nlist)
789-
if not self.no_sel
793+
if not self.use_dynamic_sel
790794
else torch.index_select(
791795
node_ebd_ext.reshape(-1, self.n_dim), 0, n_ext2e_index
792796
)
@@ -810,15 +814,15 @@ def forward(
810814
sw,
811815
self.axis_neuron,
812816
)
813-
if not self.no_sel
814-
else self.symmetrization_op_nosel(
817+
if not self.use_dynamic_sel
818+
else self.symmetrization_op_dynamic(
815819
edge_ebd,
816820
h2,
817821
sw,
818822
owner=n2e_index,
819823
num_owner=nb * nloc,
820824
nloc=nloc,
821-
scale_factor=1.0 / (float(nnei) ** 0.5),
825+
scale_factor=self.dynamic_e_sel ** (-0.5),
822826
axis_neuron=self.axis_neuron,
823827
)
824828
)
@@ -830,23 +834,23 @@ def forward(
830834
sw,
831835
self.axis_neuron,
832836
)
833-
if not self.no_sel
834-
else self.symmetrization_op_nosel(
837+
if not self.use_dynamic_sel
838+
else self.symmetrization_op_dynamic(
835839
nei_node_ebd,
836840
h2,
837841
sw,
838842
owner=n2e_index,
839843
num_owner=nb * nloc,
840844
nloc=nloc,
841-
scale_factor=1.0 / (float(nnei) ** 0.5),
845+
scale_factor=self.dynamic_e_sel ** (-0.5),
842846
axis_neuron=self.axis_neuron,
843847
)
844848
)
845849
node_sym = self.act(self.node_sym_linear(torch.cat(node_sym_list, dim=-1)))
846850
n_update_list.append(node_sym)
847851

848852
if not self.optim_update:
849-
if not self.no_sel:
853+
if not self.use_dynamic_sel:
850854
# nb x nloc x nnei x (n_dim * 2 + e_dim)
851855
edge_info = torch.cat(
852856
[
@@ -885,8 +889,8 @@ def forward(
885889
nlist,
886890
"node",
887891
)
888-
if not self.no_sel
889-
else self.optim_edge_update_nosel(
892+
if not self.use_dynamic_sel
893+
else self.optim_edge_update_dynamic(
890894
node_ebd,
891895
node_ebd_ext,
892896
edge_ebd,
@@ -897,15 +901,15 @@ def forward(
897901
)
898902
node_edge_update = (
899903
(torch.sum(node_edge_update * sw.unsqueeze(-1), dim=-2) / self.nnei)
900-
if not self.no_sel
904+
if not self.use_dynamic_sel
901905
else (
902906
aggregate(
903907
node_edge_update * sw.unsqueeze(-1),
904908
n2e_index,
905909
average=False,
906910
num_owner=nb * nloc,
907911
).reshape(nb, nloc, -1)
908-
/ self.nnei
912+
/ self.dynamic_e_sel
909913
)
910914
)
911915

@@ -934,8 +938,8 @@ def forward(
934938
nlist,
935939
"edge",
936940
)
937-
if not self.no_sel
938-
else self.optim_edge_update_nosel(
941+
if not self.use_dynamic_sel
942+
else self.optim_edge_update_dynamic(
939943
node_ebd,
940944
node_ebd_ext,
941945
edge_ebd,
@@ -965,7 +969,7 @@ def forward(
965969
node_ebd_for_angle = node_ebd
966970
edge_ebd_for_angle = edge_ebd
967971

968-
if not self.no_sel:
972+
if not self.use_dynamic_sel:
969973
# nb x nloc x a_nnei x e_dim
970974
edge_ebd_for_angle = edge_ebd_for_angle[:, :, : self.a_sel, :]
971975
# nb x nloc x a_nnei x e_dim
@@ -979,7 +983,7 @@ def forward(
979983
node_ebd_for_angle.unsqueeze(2).unsqueeze(2),
980984
(1, 1, self.a_sel, self.a_sel, 1),
981985
)
982-
if not self.no_sel
986+
if not self.use_dynamic_sel
983987
else torch.index_select(
984988
node_ebd_for_angle.reshape(-1, self.n_a_compress_dim),
985989
0,
@@ -992,15 +996,15 @@ def forward(
992996
torch.tile(
993997
edge_ebd_for_angle.unsqueeze(2), (1, 1, self.a_sel, 1, 1)
994998
)
995-
if not self.no_sel
999+
if not self.use_dynamic_sel
9961000
else torch.index_select(edge_ebd_for_angle, 0, eik2a_index)
9971001
)
9981002
# nb x nloc x a_nnei x (a_nnei) x e_dim [OR] n_angle x e_dim
9991003
edge_for_angle_j = (
10001004
torch.tile(
10011005
edge_ebd_for_angle.unsqueeze(3), (1, 1, 1, self.a_sel, 1)
10021006
)
1003-
if not self.no_sel
1007+
if not self.use_dynamic_sel
10041008
else torch.index_select(edge_ebd_for_angle, 0, eij2a_index)
10051009
)
10061010
# nb x nloc x a_nnei x a_nnei x (e_dim + e_dim) [OR] n_angle x (e_dim + e_dim)
@@ -1030,8 +1034,8 @@ def forward(
10301034
edge_ebd_for_angle,
10311035
"edge",
10321036
)
1033-
if not self.no_sel
1034-
else self.optim_angle_update_nosel(
1037+
if not self.use_dynamic_sel
1038+
else self.optim_angle_update_dynamic(
10351039
angle_ebd,
10361040
node_ebd_for_angle,
10371041
edge_ebd_for_angle,
@@ -1042,7 +1046,7 @@ def forward(
10421046
)
10431047
)
10441048

1045-
if not self.no_sel:
1049+
if not self.use_dynamic_sel:
10461050
# nb x nloc x a_nnei x a_nnei x e_dim
10471051
weighted_edge_angle_update = (
10481052
edge_angle_update
@@ -1074,13 +1078,13 @@ def forward(
10741078
eij2a_index,
10751079
average=False,
10761080
num_owner=n_edge,
1077-
) / (self.a_sel**0.5)
1081+
) / (self.dynamic_a_sel**0.5)
10781082
if not self.smooth_edge_update:
10791083
# will be deprecated in the future
10801084
# not support dynamic index, will pass anyway
1081-
if self.no_sel:
1085+
if self.use_dynamic_sel:
10821086
raise NotImplementedError(
1083-
"smooth_edge_update must be True when using dynamic_sel!"
1087+
"smooth_edge_update must be True when use_dynamic_sel is True!"
10841088
)
10851089
full_mask = torch.concat(
10861090
[
@@ -1115,8 +1119,8 @@ def forward(
11151119
edge_ebd_for_angle,
11161120
"angle",
11171121
)
1118-
if not self.no_sel
1119-
else self.optim_angle_update_nosel(
1122+
if not self.use_dynamic_sel
1123+
else self.optim_angle_update_dynamic(
11201124
angle_ebd,
11211125
node_ebd_for_angle,
11221126
edge_ebd_for_angle,

deepmd/pt/model/descriptor/repflows.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ def __init__(
107107
angle_init_use_sin: bool = False,
108108
smooth_edge_update: bool = False,
109109
angle_multi_freq: Optional[str] = None,
110-
no_sel: bool = False,
110+
use_dynamic_sel: bool = False,
111+
sel_reduce_factor: float = 10.0,
111112
optim_update: bool = True,
112113
seed: Optional[Union[int, list[int]]] = None,
113114
) -> None:
@@ -213,7 +214,8 @@ def __init__(
213214
self.smooth_angle_init = smooth_angle_init
214215
self.angle_init_use_sin = angle_init_use_sin
215216
self.smooth_edge_update = smooth_edge_update
216-
self.no_sel = no_sel
217+
self.use_dynamic_sel = use_dynamic_sel
218+
self.sel_reduce_factor = sel_reduce_factor
217219
self.angle_multi_freq = angle_multi_freq
218220
self.angle_use_multi_freq = angle_multi_freq is not None
219221
self.angle_multi_freq_list_float = (
@@ -290,7 +292,8 @@ def __init__(
290292
update_residual_init=self.update_residual_init,
291293
precision=precision,
292294
optim_update=self.optim_update,
293-
no_sel=self.no_sel,
295+
use_dynamic_sel=self.use_dynamic_sel,
296+
sel_reduce_factor=self.sel_reduce_factor,
294297
smooth_edge_update=self.smooth_edge_update,
295298
seed=child_seed(child_seed(seed, 1), ii),
296299
)
@@ -497,7 +500,7 @@ def forward(
497500
)
498501
angle_input = torch.cat(angle_input_list, dim=-1) / (torch.pi**0.5)
499502

500-
if self.no_sel:
503+
if self.use_dynamic_sel:
501504
# get graph index
502505
edge_index, angle_index = get_graph_index(
503506
nlist, nlist_mask, a_nlist_mask, nall
@@ -613,15 +616,15 @@ def forward(
613616
# nb x nloc x 3 x e_dim
614617
h2g2 = (
615618
RepFlowLayer._cal_hg(edge_ebd, h2, nlist_mask, sw)
616-
if not self.no_sel
617-
else RepFlowLayer._cal_hg_nosel(
619+
if not self.use_dynamic_sel
620+
else RepFlowLayer._cal_hg_dynamic(
618621
edge_ebd,
619622
h2,
620623
sw,
621624
owner=edge_index[:, 0],
622625
num_owner=nframes * nloc,
623626
nloc=nloc,
624-
scale_factor=1.0 / (float(nnei) ** 0.5),
627+
scale_factor=(self.nnei / self.sel_reduce_factor) ** (-0.5),
625628
)
626629
)
627630
# (nb x nloc) x e_dim x 3

deepmd/utils/argcheck.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1593,11 +1593,17 @@ def dpa3_repflow_args():
15931593
default=None,
15941594
),
15951595
Argument(
1596-
"no_sel",
1596+
"use_dynamic_sel",
15971597
bool,
15981598
optional=True,
15991599
default=False,
16001600
),
1601+
Argument(
1602+
"sel_reduce_factor",
1603+
float,
1604+
optional=True,
1605+
default=10.0,
1606+
),
16011607
]
16021608

16031609

source/tests/pt/model/test_nosel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def test_consistency(
8787
update_residual_init=ruri,
8888
optim_update=optim,
8989
smooth_edge_update=True,
90+
sel_reduce_factor=1.0, # test consistent when sel_reduce_factor == 1.0
9091
)
9192

9293
# dpa3 new impl
@@ -101,7 +102,7 @@ def test_consistency(
101102
seed=GLOBAL_SEED,
102103
).to(env.DEVICE)
103104

104-
repflow.no_sel = True
105+
repflow.use_dynamic_sel = True
105106

106107
# dpa3 new impl
107108
dd1 = DescrptDPA3(

0 commit comments

Comments
 (0)