Skip to content

Commit 75b175b

Browse files
iProzdnjzjzpre-commit-ci[bot]
authored
feat(pt/dp): add exponential switch function (#4756)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Added an option to use an exponential switch function for neighbor updates in descriptors, providing a smoother decay of neighbor contributions near cutoff distances. This is enabled via a new parameter across multiple descriptor components. - **Documentation** - Updated parameter documentation to describe the exponential switch function, its mathematical definition, and recommended smoothing parameters. - **Tests** - Enhanced test suites to cover the new exponential switch option, verifying consistent behavior and integration across different configurations. - **Improvements** - Included environment protection parameter in serialization of environment matrices across several descriptor classes for more complete state representation. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn> Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com> Co-authored-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 95ca4ad commit 75b175b

File tree

15 files changed

+150
-12
lines changed

15 files changed

+150
-12
lines changed

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,14 @@ class RepFlowArgs:
123123
smooth_edge_update : bool, optional
124124
Whether to make edge update smooth.
125125
If True, the edge update from angle message will not use self as padding.
126+
use_exp_switch : bool, optional
127+
Whether to use an exponential switch function instead of a polynomial one in the neighbor update.
128+
The exponential switch function ensures neighbor contributions smoothly diminish as the interatomic distance
129+
`r` approaches the cutoff radius `rcut`. Specifically, the function is defined as:
130+
s(r) = \\exp(-\\exp(20 * (r - rcut_smth) / rcut_smth)) for 0 < r \\leq rcut, and s(r) = 0 for r > rcut.
131+
Here, `rcut_smth` is an adjustable smoothing factor and `rcut_smth` should be chosen carefully
132+
according to `rcut`, ensuring s(r) approaches zero smoothly at the cutoff.
133+
Typical recommended values are `rcut_smth` = 5.3 for `rcut` = 6.0, and 3.5 for `rcut` = 4.0.
126134
use_dynamic_sel : bool, optional
127135
Whether to dynamically select neighbors within the cutoff radius.
128136
If True, the exact number of neighbors within the cutoff radius is used
@@ -162,6 +170,7 @@ def __init__(
162170
skip_stat: bool = False,
163171
optim_update: bool = True,
164172
smooth_edge_update: bool = False,
173+
use_exp_switch: bool = False,
165174
use_dynamic_sel: bool = False,
166175
sel_reduce_factor: float = 10.0,
167176
) -> None:
@@ -190,6 +199,7 @@ def __init__(
190199
self.a_compress_use_split = a_compress_use_split
191200
self.optim_update = optim_update
192201
self.smooth_edge_update = smooth_edge_update
202+
self.use_exp_switch = use_exp_switch
193203
self.use_dynamic_sel = use_dynamic_sel
194204
self.sel_reduce_factor = sel_reduce_factor
195205

@@ -223,6 +233,7 @@ def serialize(self) -> dict:
223233
"fix_stat_std": self.fix_stat_std,
224234
"optim_update": self.optim_update,
225235
"smooth_edge_update": self.smooth_edge_update,
236+
"use_exp_switch": self.use_exp_switch,
226237
"use_dynamic_sel": self.use_dynamic_sel,
227238
"sel_reduce_factor": self.sel_reduce_factor,
228239
}
@@ -321,6 +332,7 @@ def init_subclass_params(sub_data, sub_class):
321332
fix_stat_std=self.repflow_args.fix_stat_std,
322333
optim_update=self.repflow_args.optim_update,
323334
smooth_edge_update=self.repflow_args.smooth_edge_update,
335+
use_exp_switch=self.repflow_args.use_exp_switch,
324336
use_dynamic_sel=self.repflow_args.use_dynamic_sel,
325337
sel_reduce_factor=self.repflow_args.sel_reduce_factor,
326338
exclude_types=exclude_types,

deepmd/dpmodel/descriptor/repflows.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,14 @@ class DescrptBlockRepflows(NativeOP, DescriptorBlock):
125125
smooth_edge_update : bool, optional
126126
Whether to make edge update smooth.
127127
If True, the edge update from angle message will not use self as padding.
128+
use_exp_switch : bool, optional
129+
Whether to use an exponential switch function instead of a polynomial one in the neighbor update.
130+
The exponential switch function ensures neighbor contributions smoothly diminish as the interatomic distance
131+
`r` approaches the cutoff radius `rcut`. Specifically, the function is defined as:
132+
s(r) = \\exp(-\\exp(20 * (r - rcut_smth) / rcut_smth)) for 0 < r \\leq rcut, and s(r) = 0 for r > rcut.
133+
Here, `rcut_smth` is an adjustable smoothing factor and `rcut_smth` should be chosen carefully
134+
according to `rcut`, ensuring s(r) approaches zero smoothly at the cutoff.
135+
Typical recommended values are `rcut_smth` = 5.3 for `rcut` = 6.0, and 3.5 for `rcut` = 4.0.
128136
use_dynamic_sel : bool, optional
129137
Whether to dynamically select neighbors within the cutoff radius.
130138
If True, the exact number of neighbors within the cutoff radius is used
@@ -185,6 +193,7 @@ def __init__(
185193
fix_stat_std: float = 0.3,
186194
optim_update: bool = True,
187195
smooth_edge_update: bool = False,
196+
use_exp_switch: bool = False,
188197
use_dynamic_sel: bool = False,
189198
sel_reduce_factor: float = 10.0,
190199
seed: Optional[Union[int, list[int]]] = None,
@@ -218,6 +227,7 @@ def __init__(
218227
self.a_compress_use_split = a_compress_use_split
219228
self.optim_update = optim_update
220229
self.smooth_edge_update = smooth_edge_update
230+
self.use_exp_switch = use_exp_switch
221231
self.use_dynamic_sel = use_dynamic_sel
222232
self.sel_reduce_factor = sel_reduce_factor
223233
if self.use_dynamic_sel and not self.smooth_edge_update:
@@ -290,10 +300,16 @@ def __init__(
290300

291301
wanted_shape = (self.ntypes, self.nnei, 4)
292302
self.env_mat_edge = EnvMat(
293-
self.e_rcut, self.e_rcut_smth, protection=self.env_protection
303+
self.e_rcut,
304+
self.e_rcut_smth,
305+
protection=self.env_protection,
306+
use_exp_switch=self.use_exp_switch,
294307
)
295308
self.env_mat_angle = EnvMat(
296-
self.a_rcut, self.a_rcut_smth, protection=self.env_protection
309+
self.a_rcut,
310+
self.a_rcut_smth,
311+
protection=self.env_protection,
312+
use_exp_switch=self.use_exp_switch,
297313
)
298314
self.mean = np.zeros(wanted_shape, dtype=PRECISION_DICT[self.precision])
299315
self.stddev = np.ones(wanted_shape, dtype=PRECISION_DICT[self.precision])
@@ -647,6 +663,7 @@ def serialize(self):
647663
"precision": self.precision,
648664
"fix_stat_std": self.fix_stat_std,
649665
"optim_update": self.optim_update,
666+
"use_exp_switch": self.use_exp_switch,
650667
"smooth_edge_update": self.smooth_edge_update,
651668
"use_dynamic_sel": self.use_dynamic_sel,
652669
"sel_reduce_factor": self.sel_reduce_factor,

deepmd/dpmodel/utils/env_mat.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,32 @@ def compute_smooth_weight(
3535
return vv
3636

3737

38+
@support_array_api(version="2023.12")
39+
def compute_exp_sw(
40+
distance: np.ndarray,
41+
rmin: float,
42+
rmax: float,
43+
):
44+
"""Compute the exponential switch function for neighbor update."""
45+
if rmin >= rmax:
46+
raise ValueError("rmin should be less than rmax.")
47+
xp = array_api_compat.array_namespace(distance)
48+
distance = xp.clip(distance, min=0.0, max=rmax)
49+
C = 20
50+
a = C / rmin
51+
b = rmin
52+
exp_sw = xp.exp(-xp.exp(a * (distance - b)))
53+
return exp_sw
54+
55+
3856
def _make_env_mat(
3957
nlist,
4058
coord,
4159
rcut: float,
4260
ruct_smth: float,
4361
radial_only: bool = False,
4462
protection: float = 0.0,
63+
use_exp_switch: bool = False,
4564
):
4665
"""Make smooth environment matrix."""
4766
xp = array_api_compat.array_namespace(nlist)
@@ -66,7 +85,11 @@ def _make_env_mat(
6685
length = length + xp.astype(~xp.expand_dims(mask, axis=-1), length.dtype)
6786
t0 = 1 / (length + protection)
6887
t1 = diff / (length + protection) ** 2
69-
weight = compute_smooth_weight(length, ruct_smth, rcut)
88+
weight = (
89+
compute_smooth_weight(length, ruct_smth, rcut)
90+
if not use_exp_switch
91+
else compute_exp_sw(length, ruct_smth, rcut)
92+
)
7093
weight = weight * xp.astype(xp.expand_dims(mask, axis=-1), weight.dtype)
7194
if radial_only:
7295
env_mat = t0 * weight
@@ -81,10 +104,12 @@ def __init__(
81104
rcut,
82105
rcut_smth,
83106
protection: float = 0.0,
107+
use_exp_switch: bool = False,
84108
) -> None:
85109
self.rcut = rcut
86110
self.rcut_smth = rcut_smth
87111
self.protection = protection
112+
self.use_exp_switch = use_exp_switch
88113

89114
def call(
90115
self,
@@ -142,6 +167,7 @@ def _call(self, nlist, coord_ext, radial_only):
142167
self.rcut_smth,
143168
radial_only=radial_only,
144169
protection=self.protection,
170+
use_exp_switch=self.use_exp_switch,
145171
)
146172
return em, diff, ww
147173

@@ -151,6 +177,8 @@ def serialize(
151177
return {
152178
"rcut": self.rcut,
153179
"rcut_smth": self.rcut_smth,
180+
"protection": self.protection,
181+
"use_exp_switch": self.use_exp_switch,
154182
}
155183

156184
@classmethod

deepmd/pd/model/descriptor/dpa1.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,9 @@ def serialize(self) -> dict:
506506
"precision": RESERVED_PRECISION_DICT[obj.prec],
507507
"embeddings": obj.filter_layers.serialize(),
508508
"attention_layers": obj.dpa1_attention.serialize(),
509-
"env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(),
509+
"env_mat": DPEnvMat(
510+
obj.rcut, obj.rcut_smth, obj.env_protection
511+
).serialize(),
510512
"type_embedding": self.type_embedding.embedding.serialize(),
511513
"exclude_types": obj.exclude_types,
512514
"env_protection": obj.env_protection,

deepmd/pd/model/descriptor/se_a.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,9 @@ def serialize(self) -> dict:
337337
# make deterministic
338338
"precision": RESERVED_PRECISION_DICT[obj.prec],
339339
"embeddings": obj.filter_layers.serialize(),
340-
"env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(),
340+
"env_mat": DPEnvMat(
341+
obj.rcut, obj.rcut_smth, obj.env_protection
342+
).serialize(),
341343
"exclude_types": obj.exclude_types,
342344
"env_protection": obj.env_protection,
343345
"@variables": {

deepmd/pt/model/descriptor/dpa1.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,9 @@ def serialize(self) -> dict:
508508
"precision": RESERVED_PRECISION_DICT[obj.prec],
509509
"embeddings": obj.filter_layers.serialize(),
510510
"attention_layers": obj.dpa1_attention.serialize(),
511-
"env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(),
511+
"env_mat": DPEnvMat(
512+
obj.rcut, obj.rcut_smth, obj.env_protection
513+
).serialize(),
512514
"type_embedding": self.type_embedding.embedding.serialize(),
513515
"exclude_types": obj.exclude_types,
514516
"env_protection": obj.env_protection,

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def init_subclass_params(sub_data, sub_class):
150150
fix_stat_std=self.repflow_args.fix_stat_std,
151151
optim_update=self.repflow_args.optim_update,
152152
smooth_edge_update=self.repflow_args.smooth_edge_update,
153+
use_exp_switch=self.repflow_args.use_exp_switch,
153154
use_dynamic_sel=self.repflow_args.use_dynamic_sel,
154155
sel_reduce_factor=self.repflow_args.sel_reduce_factor,
155156
exclude_types=exclude_types,

deepmd/pt/model/descriptor/env_mat.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch
44

55
from deepmd.pt.utils.preprocess import (
6+
compute_exp_sw,
67
compute_smooth_weight,
78
)
89

@@ -14,6 +15,7 @@ def _make_env_mat(
1415
ruct_smth: float,
1516
radial_only: bool = False,
1617
protection: float = 0.0,
18+
use_exp_switch: bool = False,
1719
):
1820
"""Make smooth environment matrix."""
1921
bsz, natoms, nnei = nlist.shape
@@ -33,7 +35,11 @@ def _make_env_mat(
3335
length = length + ~mask.unsqueeze(-1)
3436
t0 = 1 / (length + protection)
3537
t1 = diff / (length + protection) ** 2
36-
weight = compute_smooth_weight(length, ruct_smth, rcut)
38+
weight = (
39+
compute_smooth_weight(length, ruct_smth, rcut)
40+
if not use_exp_switch
41+
else compute_exp_sw(length, ruct_smth, rcut)
42+
)
3743
weight = weight * mask.unsqueeze(-1)
3844
if radial_only:
3945
env_mat = t0 * weight
@@ -52,6 +58,7 @@ def prod_env_mat(
5258
rcut_smth: float,
5359
radial_only: bool = False,
5460
protection: float = 0.0,
61+
use_exp_switch: bool = False,
5562
):
5663
"""Generate smooth environment matrix from atom coordinates and other context.
5764
@@ -64,6 +71,7 @@ def prod_env_mat(
6471
- rcut_smth: Smooth hyper-parameter for pair force & energy.
6572
- radial_only: Whether to return a full description or a radial-only descriptor.
6673
- protection: Protection parameter to prevent division by zero errors during calculations.
74+
- use_exp_switch: Whether to use the exponential switch function.
6775
6876
Returns
6977
-------
@@ -76,6 +84,7 @@ def prod_env_mat(
7684
rcut_smth,
7785
radial_only,
7886
protection=protection,
87+
use_exp_switch=use_exp_switch,
7988
) # shape [n_atom, dim, 4 or 1]
8089
t_avg = mean[atype] # [n_atom, dim, 4 or 1]
8190
t_std = stddev[atype] # [n_atom, dim, 4 or 1]

deepmd/pt/model/descriptor/repflows.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,14 @@ class DescrptBlockRepflows(DescriptorBlock):
136136
smooth_edge_update : bool, optional
137137
Whether to make edge update smooth.
138138
If True, the edge update from angle message will not use self as padding.
139+
use_exp_switch : bool, optional
140+
Whether to use an exponential switch function instead of a polynomial one in the neighbor update.
141+
The exponential switch function ensures neighbor contributions smoothly diminish as the interatomic distance
142+
`r` approaches the cutoff radius `rcut`. Specifically, the function is defined as:
143+
s(r) = \\exp(-\\exp(20 * (r - rcut_smth) / rcut_smth)) for 0 < r \\leq rcut, and s(r) = 0 for r > rcut.
144+
Here, `rcut_smth` is an adjustable smoothing factor and `rcut_smth` should be chosen carefully
145+
according to `rcut`, ensuring s(r) approaches zero smoothly at the cutoff.
146+
Typical recommended values are `rcut_smth` = 5.3 for `rcut` = 6.0, and 3.5 for `rcut` = 4.0.
139147
use_dynamic_sel : bool, optional
140148
Whether to dynamically select neighbors within the cutoff radius.
141149
If True, the exact number of neighbors within the cutoff radius is used
@@ -198,6 +206,7 @@ def __init__(
198206
precision: str = "float64",
199207
fix_stat_std: float = 0.3,
200208
smooth_edge_update: bool = False,
209+
use_exp_switch: bool = False,
201210
use_dynamic_sel: bool = False,
202211
sel_reduce_factor: float = 10.0,
203212
optim_update: bool = True,
@@ -232,6 +241,7 @@ def __init__(
232241
self.a_compress_use_split = a_compress_use_split
233242
self.optim_update = optim_update
234243
self.smooth_edge_update = smooth_edge_update
244+
self.use_exp_switch = use_exp_switch
235245
self.use_dynamic_sel = use_dynamic_sel
236246
self.sel_reduce_factor = sel_reduce_factor
237247
if self.use_dynamic_sel and not self.smooth_edge_update:
@@ -425,6 +435,7 @@ def forward(
425435
self.e_rcut,
426436
self.e_rcut_smth,
427437
protection=self.env_protection,
438+
use_exp_switch=self.use_exp_switch,
428439
)
429440
nlist_mask = nlist != -1
430441
sw = torch.squeeze(sw, -1)
@@ -446,6 +457,7 @@ def forward(
446457
self.a_rcut,
447458
self.a_rcut_smth,
448459
protection=self.env_protection,
460+
use_exp_switch=self.use_exp_switch,
449461
)
450462
a_nlist_mask = a_nlist != -1
451463
a_sw = torch.squeeze(a_sw, -1)

deepmd/pt/model/descriptor/se_a.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,9 @@ def serialize(self) -> dict:
381381
# make deterministic
382382
"precision": RESERVED_PRECISION_DICT[obj.prec],
383383
"embeddings": obj.filter_layers.serialize(),
384-
"env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(),
384+
"env_mat": DPEnvMat(
385+
obj.rcut, obj.rcut_smth, obj.env_protection
386+
).serialize(),
385387
"exclude_types": obj.exclude_types,
386388
"env_protection": obj.env_protection,
387389
"@variables": {

0 commit comments

Comments
 (0)