Skip to content

Commit 78a8e57

Browse files
feat(pt/dp): add distance init for DPA3 edge feat (#4760)
<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

- **New Features**
  - Introduced a configurable option to initialize edge features using the direct distance (`r`) instead of its reciprocal (`1/r`) across DPA3 descriptors.
  - The new parameter `edge_init_use_dist` is available in configuration schemas, function arguments, and test parameterizations, allowing flexible control over edge feature initialization.
- **Bug Fixes**
  - Updated test logic to conditionally skip unsupported configurations when the new initialization option is enabled, ensuring robust testing across backends.
- **Tests**
  - Enhanced test coverage to validate behavior with the new edge initialization method.
  - Expanded parameterization to include both `True` and `False` options for `edge_init_use_dist`, improving test comprehensiveness.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 75b175b commit 78a8e57

File tree

7 files changed

+63
-4
lines changed

7 files changed

+63
-4
lines changed

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -123,6 +123,9 @@ class RepFlowArgs:
123123
smooth_edge_update : bool, optional
124124
Whether to make edge update smooth.
125125
If True, the edge update from angle message will not use self as padding.
126+
edge_init_use_dist : bool, optional
127+
Whether to use direct distance r to initialize the edge features instead of 1/r.
128+
Note that when using this option, the activation function will not be used when initializing edge features.
126129
use_exp_switch : bool, optional
127130
Whether to use an exponential switch function instead of a polynomial one in the neighbor update.
128131
The exponential switch function ensures neighbor contributions smoothly diminish as the interatomic distance
@@ -170,6 +173,7 @@ def __init__(
170173
skip_stat: bool = False,
171174
optim_update: bool = True,
172175
smooth_edge_update: bool = False,
176+
edge_init_use_dist: bool = False,
173177
use_exp_switch: bool = False,
174178
use_dynamic_sel: bool = False,
175179
sel_reduce_factor: float = 10.0,
@@ -199,6 +203,7 @@ def __init__(
199203
self.a_compress_use_split = a_compress_use_split
200204
self.optim_update = optim_update
201205
self.smooth_edge_update = smooth_edge_update
206+
self.edge_init_use_dist = edge_init_use_dist
202207
self.use_exp_switch = use_exp_switch
203208
self.use_dynamic_sel = use_dynamic_sel
204209
self.sel_reduce_factor = sel_reduce_factor
@@ -233,6 +238,7 @@ def serialize(self) -> dict:
233238
"fix_stat_std": self.fix_stat_std,
234239
"optim_update": self.optim_update,
235240
"smooth_edge_update": self.smooth_edge_update,
241+
"edge_init_use_dist": self.edge_init_use_dist,
236242
"use_exp_switch": self.use_exp_switch,
237243
"use_dynamic_sel": self.use_dynamic_sel,
238244
"sel_reduce_factor": self.sel_reduce_factor,
@@ -332,6 +338,7 @@ def init_subclass_params(sub_data, sub_class):
332338
fix_stat_std=self.repflow_args.fix_stat_std,
333339
optim_update=self.repflow_args.optim_update,
334340
smooth_edge_update=self.repflow_args.smooth_edge_update,
341+
edge_init_use_dist=self.repflow_args.edge_init_use_dist,
335342
use_exp_switch=self.repflow_args.use_exp_switch,
336343
use_dynamic_sel=self.repflow_args.use_dynamic_sel,
337344
sel_reduce_factor=self.repflow_args.sel_reduce_factor,

deepmd/dpmodel/descriptor/repflows.py

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -125,6 +125,9 @@ class DescrptBlockRepflows(NativeOP, DescriptorBlock):
125125
smooth_edge_update : bool, optional
126126
Whether to make edge update smooth.
127127
If True, the edge update from angle message will not use self as padding.
128+
edge_init_use_dist : bool, optional
129+
Whether to use direct distance r to initialize the edge features instead of 1/r.
130+
Note that when using this option, the activation function will not be used when initializing edge features.
128131
use_exp_switch : bool, optional
129132
Whether to use an exponential switch function instead of a polynomial one in the neighbor update.
130133
The exponential switch function ensures neighbor contributions smoothly diminish as the interatomic distance
@@ -193,6 +196,7 @@ def __init__(
193196
fix_stat_std: float = 0.3,
194197
optim_update: bool = True,
195198
smooth_edge_update: bool = False,
199+
edge_init_use_dist: bool = False,
196200
use_exp_switch: bool = False,
197201
use_dynamic_sel: bool = False,
198202
sel_reduce_factor: float = 10.0,
@@ -227,6 +231,7 @@ def __init__(
227231
self.a_compress_use_split = a_compress_use_split
228232
self.optim_update = optim_update
229233
self.smooth_edge_update = smooth_edge_update
234+
self.edge_init_use_dist = edge_init_use_dist
230235
self.use_exp_switch = use_exp_switch
231236
self.use_dynamic_sel = use_dynamic_sel
232237
self.sel_reduce_factor = sel_reduce_factor
@@ -510,7 +515,11 @@ def call(
510515
# get edge and angle embedding input
511516
# nb x nloc x nnei x 1, nb x nloc x nnei x 3
512517
# edge_input, h2 = xp.split(dmatrix, [1], axis=-1)
513-
edge_input = dmatrix[:, :, :, :1]
518+
# nb x nloc x nnei x 1
519+
if self.edge_init_use_dist:
520+
edge_input = xp.linalg.vector_norm(diff, axis=-1, keepdims=True)
521+
else:
522+
edge_input = dmatrix[:, :, :, :1]
514523
h2 = dmatrix[:, :, :, 1:]
515524

516525
# nf x nloc x a_nnei x 3
@@ -552,7 +561,10 @@ def call(
552561

553562
# get edge and angle embedding
554563
# nb x nloc x nnei x e_dim [OR] n_edge x e_dim
555-
edge_ebd = self.act(self.edge_embd(edge_input))
564+
if not self.edge_init_use_dist:
565+
edge_ebd = self.act(self.edge_embd(edge_input))
566+
else:
567+
edge_ebd = self.edge_embd(edge_input)
556568
# nf x nloc x a_nnei x a_nnei x a_dim [OR] n_angle x a_dim
557569
angle_ebd = self.angle_embd(angle_input)
558570

@@ -663,6 +675,7 @@ def serialize(self):
663675
"precision": self.precision,
664676
"fix_stat_std": self.fix_stat_std,
665677
"optim_update": self.optim_update,
678+
"edge_init_use_dist": self.edge_init_use_dist,
666679
"use_exp_switch": self.use_exp_switch,
667680
"smooth_edge_update": self.smooth_edge_update,
668681
"use_dynamic_sel": self.use_dynamic_sel,

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -150,6 +150,7 @@ def init_subclass_params(sub_data, sub_class):
150150
fix_stat_std=self.repflow_args.fix_stat_std,
151151
optim_update=self.repflow_args.optim_update,
152152
smooth_edge_update=self.repflow_args.smooth_edge_update,
153+
edge_init_use_dist=self.repflow_args.edge_init_use_dist,
153154
use_exp_switch=self.repflow_args.use_exp_switch,
154155
use_dynamic_sel=self.repflow_args.use_dynamic_sel,
155156
sel_reduce_factor=self.repflow_args.sel_reduce_factor,

deepmd/pt/model/descriptor/repflows.py

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -136,6 +136,9 @@ class DescrptBlockRepflows(DescriptorBlock):
136136
smooth_edge_update : bool, optional
137137
Whether to make edge update smooth.
138138
If True, the edge update from angle message will not use self as padding.
139+
edge_init_use_dist : bool, optional
140+
Whether to use direct distance r to initialize the edge features instead of 1/r.
141+
Note that when using this option, the activation function will not be used when initializing edge features.
139142
use_exp_switch : bool, optional
140143
Whether to use an exponential switch function instead of a polynomial one in the neighbor update.
141144
The exponential switch function ensures neighbor contributions smoothly diminish as the interatomic distance
@@ -206,6 +209,7 @@ def __init__(
206209
precision: str = "float64",
207210
fix_stat_std: float = 0.3,
208211
smooth_edge_update: bool = False,
212+
edge_init_use_dist: bool = False,
209213
use_exp_switch: bool = False,
210214
use_dynamic_sel: bool = False,
211215
sel_reduce_factor: float = 10.0,
@@ -241,6 +245,7 @@ def __init__(
241245
self.a_compress_use_split = a_compress_use_split
242246
self.optim_update = optim_update
243247
self.smooth_edge_update = smooth_edge_update
248+
self.edge_init_use_dist = edge_init_use_dist
244249
self.use_exp_switch = use_exp_switch
245250
self.use_dynamic_sel = use_dynamic_sel
246251
self.sel_reduce_factor = sel_reduce_factor
@@ -483,6 +488,10 @@ def forward(
483488
# get edge and angle embedding input
484489
# nb x nloc x nnei x 1, nb x nloc x nnei x 3
485490
edge_input, h2 = torch.split(dmatrix, [1, 3], dim=-1)
491+
if self.edge_init_use_dist:
492+
# nb x nloc x nnei x 1
493+
edge_input = torch.linalg.norm(diff, dim=-1, keepdim=True)
494+
486495
# nf x nloc x a_nnei x 3
487496
normalized_diff_i = a_diff / (
488497
torch.linalg.norm(a_diff, dim=-1, keepdim=True) + 1e-6
@@ -519,7 +528,10 @@ def forward(
519528
)
520529
# get edge and angle embedding
521530
# nb x nloc x nnei x e_dim [OR] n_edge x e_dim
522-
edge_ebd = self.act(self.edge_embd(edge_input))
531+
if not self.edge_init_use_dist:
532+
edge_ebd = self.act(self.edge_embd(edge_input))
533+
else:
534+
edge_ebd = self.edge_embd(edge_input)
523535
# nf x nloc x a_nnei x a_nnei x a_dim [OR] n_angle x a_dim
524536
angle_ebd = self.angle_embd(angle_input)
525537

deepmd/utils/argcheck.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1497,6 +1497,10 @@ def dpa3_repflow_args():
14971497
"Whether to make edge update smooth. "
14981498
"If True, the edge update from angle message will not use self as padding."
14991499
)
1500+
doc_edge_init_use_dist = (
1501+
"Whether to use direct distance r to initialize the edge features instead of 1/r. "
1502+
"Note that when using this option, the activation function will not be used when initializing edge features."
1503+
)
15001504
doc_use_exp_switch = (
15011505
"Whether to use an exponential switch function instead of a polynomial one in the neighbor update. "
15021506
"The exponential switch function ensures neighbor contributions smoothly diminish as the interatomic distance "
@@ -1620,6 +1624,14 @@ def dpa3_repflow_args():
16201624
default=False, # For compatibility. This will be True in the future
16211625
doc=doc_smooth_edge_update,
16221626
),
1627+
Argument(
1628+
"edge_init_use_dist",
1629+
bool,
1630+
optional=True,
1631+
default=False,
1632+
alias=["edge_use_dist"],
1633+
doc=doc_edge_init_use_dist,
1634+
),
16231635
Argument(
16241636
"use_exp_switch",
16251637
bool,

source/tests/consistent/descriptor/test_dpa3.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,7 @@
6565
(1, 2), # a_compress_e_rate
6666
(True,), # a_compress_use_split
6767
(True, False), # optim_update
68+
(True, False), # edge_init_use_dist
6869
(True, False), # use_exp_switch
6970
(True, False), # use_dynamic_sel
7071
(0.3, 0.0), # fix_stat_std
@@ -82,6 +83,7 @@ def data(self) -> dict:
8283
a_compress_e_rate,
8384
a_compress_use_split,
8485
optim_update,
86+
edge_init_use_dist,
8587
use_exp_switch,
8688
use_dynamic_sel,
8789
fix_stat_std,
@@ -107,6 +109,7 @@ def data(self) -> dict:
107109
"a_compress_e_rate": a_compress_e_rate,
108110
"a_compress_use_split": a_compress_use_split,
109111
"optim_update": optim_update,
112+
"edge_init_use_dist": edge_init_use_dist,
110113
"use_exp_switch": use_exp_switch,
111114
"use_dynamic_sel": use_dynamic_sel,
112115
"smooth_edge_update": True,
@@ -137,6 +140,7 @@ def skip_pt(self) -> bool:
137140
a_compress_e_rate,
138141
a_compress_use_split,
139142
optim_update,
143+
edge_init_use_dist,
140144
use_exp_switch,
141145
use_dynamic_sel,
142146
fix_stat_std,
@@ -155,6 +159,7 @@ def skip_pd(self) -> bool:
155159
a_compress_e_rate,
156160
a_compress_use_split,
157161
optim_update,
162+
edge_init_use_dist,
158163
use_exp_switch,
159164
use_dynamic_sel,
160165
fix_stat_std,
@@ -164,6 +169,7 @@ def skip_pd(self) -> bool:
164169
return (
165170
not INSTALLED_PD
166171
or precision == "bfloat16"
172+
or edge_init_use_dist
167173
or use_exp_switch
168174
or use_dynamic_sel
169175
) # not supported yet
@@ -178,6 +184,7 @@ def skip_dp(self) -> bool:
178184
a_compress_e_rate,
179185
a_compress_use_split,
180186
optim_update,
187+
edge_init_use_dist,
181188
use_exp_switch,
182189
use_dynamic_sel,
183190
fix_stat_std,
@@ -196,6 +203,7 @@ def skip_tf(self) -> bool:
196203
a_compress_e_rate,
197204
a_compress_use_split,
198205
optim_update,
206+
edge_init_use_dist,
199207
use_exp_switch,
200208
use_dynamic_sel,
201209
fix_stat_std,
@@ -256,6 +264,7 @@ def setUp(self) -> None:
256264
a_compress_e_rate,
257265
a_compress_use_split,
258266
optim_update,
267+
edge_init_use_dist,
259268
use_exp_switch,
260269
use_dynamic_sel,
261270
fix_stat_std,
@@ -337,6 +346,7 @@ def rtol(self) -> float:
337346
a_compress_e_rate,
338347
a_compress_use_split,
339348
optim_update,
349+
edge_init_use_dist,
340350
use_exp_switch,
341351
use_dynamic_sel,
342352
fix_stat_std,
@@ -361,6 +371,7 @@ def atol(self) -> float:
361371
a_compress_e_rate,
362372
a_compress_use_split,
363373
optim_update,
374+
edge_init_use_dist,
364375
use_exp_switch,
365376
use_dynamic_sel,
366377
fix_stat_std,

source/tests/universal/dpmodel/descriptor/test_descriptor.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -482,6 +482,7 @@ def DescriptorParamDPA3(
482482
a_compress_use_split=False,
483483
optim_update=True,
484484
smooth_edge_update=False,
485+
edge_init_use_dist=False,
485486
use_exp_switch=False,
486487
fix_stat_std=0.3,
487488
use_dynamic_sel=False,
@@ -511,6 +512,7 @@ def DescriptorParamDPA3(
511512
"optim_update": optim_update,
512513
"use_exp_switch": use_exp_switch,
513514
"smooth_edge_update": smooth_edge_update,
515+
"edge_init_use_dist": edge_init_use_dist,
514516
"fix_stat_std": fix_stat_std,
515517
"n_multi_edge_message": n_multi_edge_message,
516518
"axis_neuron": 2,
@@ -549,9 +551,10 @@ def DescriptorParamDPA3(
549551
"a_compress_use_split": (True,),
550552
"optim_update": (True, False),
551553
"smooth_edge_update": (True,),
554+
"edge_init_use_dist": (True, False),
552555
"use_exp_switch": (True, False),
553556
"fix_stat_std": (0.3,),
554-
"n_multi_edge_message": (1, 2),
557+
"n_multi_edge_message": (1,),
555558
"use_dynamic_sel": (True, False),
556559
"env_protection": (0.0, 1e-8),
557560
"precision": ("float64",),

0 commit comments

Comments
 (0)