Skip to content

Commit ea3fc02

Browse files
authored
Merge branch 'devel' into D0520_dpa3_dist
Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com>
2 parents cff3a51 + 95ca4ad commit ea3fc02

24 files changed

Lines changed: 54220 additions & 193 deletions

File tree

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ repos:
2929
exclude: ^source/3rdparty
3030
- repo: https://github.com/astral-sh/ruff-pre-commit
3131
# Ruff version.
32-
rev: v0.11.9
32+
rev: v0.11.10
3333
hooks:
3434
- id: ruff
3535
args: ["--fix"]

deepmd/dpmodel/array_api.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,37 @@ def xp_scatter_sum(input, dim, index: np.ndarray, src: np.ndarray) -> np.ndarray
9292
)
9393
else:
9494
raise NotImplementedError("Only JAX arrays are supported.")
95+
96+
97+
def xp_add_at(x, indices, values):
98+
"""Adds values to the specified indices of x in place or returns new x (for JAX)."""
99+
xp = array_api_compat.array_namespace(x, indices, values)
100+
if array_api_compat.is_numpy_array(x):
101+
# NumPy: supports np.add.at (in-place)
102+
xp.add.at(x, indices, values)
103+
return x
104+
105+
elif array_api_compat.is_jax_array(x):
106+
# JAX: functional update, not in-place
107+
return x.at[indices].add(values)
108+
else:
109+
# Fallback for array_api_strict: use basic indexing only
110+
# may need a more efficient way to do this
111+
n = indices.shape[0]
112+
for i in range(n):
113+
idx = int(indices[i])
114+
x[idx, ...] = x[idx, ...] + values[i, ...]
115+
return x
116+
117+
118+
def xp_bincount(x, weights=None, minlength=0):
119+
"""Counts the number of occurrences of each value in x."""
120+
xp = array_api_compat.array_namespace(x)
121+
if array_api_compat.is_numpy_array(x) or array_api_compat.is_jax_array(x):
122+
result = xp.bincount(x, weights=weights, minlength=minlength)
123+
else:
124+
if weights is None:
125+
weights = xp.ones_like(x)
126+
result = xp.zeros((max(minlength, int(xp.max(x)) + 1),), dtype=weights.dtype)
127+
result = xp_add_at(result, x, weights)
128+
return result

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,18 @@ class RepFlowArgs:
126126
edge_init_use_dist : bool, optional
127127
Whether to use direct distance r to initialize the edge features instead of 1/r.
128128
Note that when using this option, the activation function will not be used when initializing edge features.
129+
use_dynamic_sel : bool, optional
130+
Whether to dynamically select neighbors within the cutoff radius.
131+
If True, the exact number of neighbors within the cutoff radius is used
132+
without padding to a fixed selection numbers.
133+
When enabled, users can safely set larger values for `e_sel` or `a_sel` (e.g., 1200 or 300, respectively)
134+
to guarantee capturing all neighbors within the cutoff radius.
135+
Note that when using dynamic selection, the `smooth_edge_update` must be True.
136+
sel_reduce_factor : float, optional
137+
Reduction factor applied to neighbor-scale normalization when `use_dynamic_sel` is True.
138+
In the dynamic selection case, neighbor-scale normalization will use `e_sel / sel_reduce_factor`
139+
or `a_sel / sel_reduce_factor` instead of the raw `e_sel` or `a_sel` values,
140+
accommodating larger selection numbers.
129141
"""
130142

131143
def __init__(
@@ -154,6 +166,8 @@ def __init__(
154166
optim_update: bool = True,
155167
smooth_edge_update: bool = False,
156168
edge_init_use_dist: bool = False,
169+
use_dynamic_sel: bool = False,
170+
sel_reduce_factor: float = 10.0,
157171
) -> None:
158172
self.n_dim = n_dim
159173
self.e_dim = e_dim
@@ -181,6 +195,8 @@ def __init__(
181195
self.optim_update = optim_update
182196
self.smooth_edge_update = smooth_edge_update
183197
self.edge_init_use_dist = edge_init_use_dist
198+
self.use_dynamic_sel = use_dynamic_sel
199+
self.sel_reduce_factor = sel_reduce_factor
184200

185201
def __getitem__(self, key):
186202
if hasattr(self, key):
@@ -213,6 +229,8 @@ def serialize(self) -> dict:
213229
"optim_update": self.optim_update,
214230
"smooth_edge_update": self.smooth_edge_update,
215231
"edge_init_use_dist": self.edge_init_use_dist,
232+
"use_dynamic_sel": self.use_dynamic_sel,
233+
"sel_reduce_factor": self.sel_reduce_factor,
216234
}
217235

218236
@classmethod
@@ -310,6 +328,8 @@ def init_subclass_params(sub_data, sub_class):
310328
optim_update=self.repflow_args.optim_update,
311329
smooth_edge_update=self.repflow_args.smooth_edge_update,
312330
edge_init_use_dist=self.repflow_args.edge_init_use_dist,
331+
use_dynamic_sel=self.repflow_args.use_dynamic_sel,
332+
sel_reduce_factor=self.repflow_args.sel_reduce_factor,
313333
exclude_types=exclude_types,
314334
env_protection=env_protection,
315335
precision=precision,

0 commit comments

Comments
 (0)