Skip to content

Commit 0ca01d9

Browse files
authored
Merge branch 'devel' into perf-topk
2 parents c189b28 + 75b175b commit 0ca01d9

36 files changed

Lines changed: 54439 additions & 362 deletions

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ repos:
2929
exclude: ^source/3rdparty
3030
- repo: https://github.com/astral-sh/ruff-pre-commit
3131
# Ruff version.
32-
rev: v0.11.9
32+
rev: v0.11.10
3333
hooks:
3434
- id: ruff
3535
args: ["--fix"]

deepmd/dpmodel/array_api.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,37 @@ def xp_scatter_sum(input, dim, index: np.ndarray, src: np.ndarray) -> np.ndarray
9292
)
9393
else:
9494
raise NotImplementedError("Only JAX arrays are supported.")
95+
96+
97+
def xp_add_at(x, indices, values):
98+
"""Adds values to the specified indices of x in place or returns new x (for JAX)."""
99+
xp = array_api_compat.array_namespace(x, indices, values)
100+
if array_api_compat.is_numpy_array(x):
101+
# NumPy: supports np.add.at (in-place)
102+
xp.add.at(x, indices, values)
103+
return x
104+
105+
elif array_api_compat.is_jax_array(x):
106+
# JAX: functional update, not in-place
107+
return x.at[indices].add(values)
108+
else:
109+
# Fallback for array_api_strict: use basic indexing only
110+
# may need a more efficient way to do this
111+
n = indices.shape[0]
112+
for i in range(n):
113+
idx = int(indices[i])
114+
x[idx, ...] = x[idx, ...] + values[i, ...]
115+
return x
116+
117+
118+
def xp_bincount(x, weights=None, minlength=0):
119+
"""Counts the number of occurrences of each value in x."""
120+
xp = array_api_compat.array_namespace(x)
121+
if array_api_compat.is_numpy_array(x) or array_api_compat.is_jax_array(x):
122+
result = xp.bincount(x, weights=weights, minlength=minlength)
123+
else:
124+
if weights is None:
125+
weights = xp.ones_like(x)
126+
result = xp.zeros((max(minlength, int(xp.max(x)) + 1),), dtype=weights.dtype)
127+
result = xp_add_at(result, x, weights)
128+
return result

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,26 @@ class RepFlowArgs:
123123
smooth_edge_update : bool, optional
124124
Whether to make edge update smooth.
125125
If True, the edge update from angle message will not use self as padding.
126+
use_exp_switch : bool, optional
127+
Whether to use an exponential switch function instead of a polynomial one in the neighbor update.
128+
The exponential switch function ensures neighbor contributions smoothly diminish as the interatomic distance
129+
`r` approaches the cutoff radius `rcut`. Specifically, the function is defined as:
130+
s(r) = \\exp(-\\exp(20 * (r - rcut_smth) / rcut_smth)) for 0 < r \\leq rcut, and s(r) = 0 for r > rcut.
131+
Here, `rcut_smth` is an adjustable smoothing factor and `rcut_smth` should be chosen carefully
132+
according to `rcut`, ensuring s(r) approaches zero smoothly at the cutoff.
133+
Typical recommended values are `rcut_smth` = 5.3 for `rcut` = 6.0, and 3.5 for `rcut` = 4.0.
134+
use_dynamic_sel : bool, optional
135+
Whether to dynamically select neighbors within the cutoff radius.
136+
If True, the exact number of neighbors within the cutoff radius is used
137+
without padding to a fixed selection numbers.
138+
When enabled, users can safely set larger values for `e_sel` or `a_sel` (e.g., 1200 or 300, respectively)
139+
to guarantee capturing all neighbors within the cutoff radius.
140+
Note that when using dynamic selection, the `smooth_edge_update` must be True.
141+
sel_reduce_factor : float, optional
142+
Reduction factor applied to neighbor-scale normalization when `use_dynamic_sel` is True.
143+
In the dynamic selection case, neighbor-scale normalization will use `e_sel / sel_reduce_factor`
144+
or `a_sel / sel_reduce_factor` instead of the raw `e_sel` or `a_sel` values,
145+
accommodating larger selection numbers.
126146
"""
127147

128148
def __init__(
@@ -150,6 +170,9 @@ def __init__(
150170
skip_stat: bool = False,
151171
optim_update: bool = True,
152172
smooth_edge_update: bool = False,
173+
use_exp_switch: bool = False,
174+
use_dynamic_sel: bool = False,
175+
sel_reduce_factor: float = 10.0,
153176
) -> None:
154177
self.n_dim = n_dim
155178
self.e_dim = e_dim
@@ -176,6 +199,9 @@ def __init__(
176199
self.a_compress_use_split = a_compress_use_split
177200
self.optim_update = optim_update
178201
self.smooth_edge_update = smooth_edge_update
202+
self.use_exp_switch = use_exp_switch
203+
self.use_dynamic_sel = use_dynamic_sel
204+
self.sel_reduce_factor = sel_reduce_factor
179205

180206
def __getitem__(self, key):
181207
if hasattr(self, key):
@@ -207,6 +233,9 @@ def serialize(self) -> dict:
207233
"fix_stat_std": self.fix_stat_std,
208234
"optim_update": self.optim_update,
209235
"smooth_edge_update": self.smooth_edge_update,
236+
"use_exp_switch": self.use_exp_switch,
237+
"use_dynamic_sel": self.use_dynamic_sel,
238+
"sel_reduce_factor": self.sel_reduce_factor,
210239
}
211240

212241
@classmethod
@@ -303,6 +332,9 @@ def init_subclass_params(sub_data, sub_class):
303332
fix_stat_std=self.repflow_args.fix_stat_std,
304333
optim_update=self.repflow_args.optim_update,
305334
smooth_edge_update=self.repflow_args.smooth_edge_update,
335+
use_exp_switch=self.repflow_args.use_exp_switch,
336+
use_dynamic_sel=self.repflow_args.use_dynamic_sel,
337+
sel_reduce_factor=self.repflow_args.sel_reduce_factor,
306338
exclude_types=exclude_types,
307339
env_protection=env_protection,
308340
precision=precision,

0 commit comments

Comments
 (0)