Skip to content

Commit fbabf9d

Browse files
authored
Merge branch 'master' into 0304_add_chg_spin
Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com>
2 parents a6c1dc2 + 4a29836 commit fbabf9d

100 files changed

Lines changed: 6201 additions & 209 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

deepmd/backend/pt_expt.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
class PyTorchExportableBackend(Backend):
3333
"""PyTorch exportable backend."""
3434

35-
name = "PyTorch Exportable"
35+
name = "PyTorch-Exportable"
3636
"""The formal name of the backend."""
3737
features: ClassVar[Backend.Feature] = (
3838
Backend.Feature.ENTRY_POINT
@@ -63,7 +63,7 @@ def entry_point_hook(self) -> Callable[["Namespace"], None]:
6363
Callable[[Namespace], None]
6464
The entry point hook of the backend.
6565
"""
66-
from deepmd.pt.entrypoints.main import main as deepmd_main
66+
from deepmd.pt_expt.entrypoints.main import main as deepmd_main
6767

6868
return deepmd_main
6969

@@ -76,7 +76,11 @@ def deep_eval(self) -> type["DeepEvalBackend"]:
7676
type[DeepEvalBackend]
7777
The Deep Eval backend of the backend.
7878
"""
79-
raise NotImplementedError
79+
from deepmd.pt_expt.infer.deep_eval import (
80+
DeepEval,
81+
)
82+
83+
return DeepEval
8084

8185
@property
8286
def neighbor_stat(self) -> type["NeighborStat"]:
@@ -87,7 +91,11 @@ def neighbor_stat(self) -> type["NeighborStat"]:
8791
type[NeighborStat]
8892
The neighbor statistics of the backend.
8993
"""
90-
raise NotImplementedError
94+
from deepmd.pt_expt.utils.neighbor_stat import (
95+
NeighborStat,
96+
)
97+
98+
return NeighborStat
9199

92100
@property
93101
def serialize_hook(self) -> Callable[[str], dict]:
@@ -98,7 +106,11 @@ def serialize_hook(self) -> Callable[[str], dict]:
98106
Callable[[str], dict]
99107
The serialize hook of the backend.
100108
"""
101-
raise NotImplementedError
109+
from deepmd.pt_expt.utils.serialization import (
110+
serialize_from_file,
111+
)
112+
113+
return serialize_from_file
102114

103115
@property
104116
def deserialize_hook(self) -> Callable[[str, dict], None]:
@@ -109,4 +121,8 @@ def deserialize_hook(self) -> Callable[[str, dict], None]:
109121
Callable[[str, dict], None]
110122
The deserialize hook of the backend.
111123
"""
112-
raise NotImplementedError
124+
from deepmd.pt_expt.utils.serialization import (
125+
deserialize_to_file,
126+
)
127+
128+
return deserialize_to_file

deepmd/dpmodel/array_api.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,15 @@ def xp_take_along_axis(arr: Array, indices: Array, axis: int) -> Array:
3232
# torch.take_along_dim requires int64 indices
3333
if array_api_compat.is_torch_array(indices):
3434
indices = xp.astype(indices, xp.int64)
35+
if array_api_compat.is_torch_array(arr):
36+
# Use torch.gather directly for torch.export dynamic shape compatibility.
37+
# array_api_compat's take_along_axis / torch.take_along_dim specializes
38+
# the source dimension size to a constant during torch.export tracing,
39+
# breaking dynamic shape export. torch.gather is the underlying
40+
# primitive and handles symbolic shapes correctly.
41+
import torch
42+
43+
return torch.gather(arr, axis, indices)
3544
if Version(xp.__array_api_version__) >= Version("2024.12"):
3645
# see: https://github.com/data-apis/array-api-strict/blob/d086c619a58f35c38240592ef994aa19ca7beebc/array_api_strict/_indexing_functions.py#L30-L39
3746
return xp.take_along_axis(arr, indices, axis=axis)
@@ -62,6 +71,24 @@ def xp_take_along_axis(arr: Array, indices: Array, axis: int) -> Array:
6271
return xp_swapaxes(out, axis, -1)
6372

6473

74+
def xp_take_first_n(arr: Array, dim: int, n: int) -> Array:
75+
"""Take the first *n* elements along *dim*.
76+
77+
For torch tensors, uses ``torch.index_select`` so that
78+
``torch.export`` does not emit a contiguity guard that would
79+
prevent the ``nall == nloc`` (no-PBC) case from working.
80+
For numpy / jax, uses regular slicing.
81+
"""
82+
if array_api_compat.is_torch_array(arr):
83+
import torch
84+
85+
indices = torch.arange(n, dtype=torch.int64, device=arr.device)
86+
return torch.index_select(arr, dim, indices)
87+
slices = [slice(None)] * arr.ndim
88+
slices[dim] = slice(0, n)
89+
return arr[tuple(slices)]
90+
91+
6592
def xp_scatter_sum(input: Array, dim: int, index: Array, src: Array) -> Array:
6693
"""Reduces all values from the src tensor to the indices specified in the index tensor.
6794

deepmd/dpmodel/atomic_model/base_atomic_model.py

Lines changed: 74 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from deepmd.dpmodel.array_api import (
1515
Array,
16+
xp_take_first_n,
1617
)
1718
from deepmd.dpmodel.common import (
1819
NativeOP,
@@ -62,6 +63,45 @@ def __init__(
6263
self.rcond = rcond
6364
self.preset_out_bias = preset_out_bias
6465
self.data_stat_protect = data_stat_protect
66+
self._observed_type: list[str] | None = None
67+
68+
@property
69+
def observed_type(self) -> list[str] | None:
70+
"""Get the observed element type list from data statistics."""
71+
return self._observed_type
72+
73+
def _collect_and_set_observed_type(
74+
self,
75+
sampled_func: Callable[[], list[dict]],
76+
stat_file_path: DPPath | None,
77+
preset_observed_type: list[str] | None,
78+
) -> None:
79+
"""Collect observed types with priority: preset > stat_file > compute.
80+
81+
Parameters
82+
----------
83+
sampled_func
84+
The lazy sampled function to get data frames.
85+
stat_file_path
86+
The path to the statistics files (should already include type_map suffix).
87+
preset_observed_type
88+
User-specified observed types that take highest priority.
89+
"""
90+
from deepmd.dpmodel.utils.stat import (
91+
_restore_observed_type_from_file,
92+
_save_observed_type_to_file,
93+
collect_observed_types,
94+
)
95+
96+
if preset_observed_type is not None:
97+
self._observed_type = preset_observed_type
98+
else:
99+
observed = _restore_observed_type_from_file(stat_file_path)
100+
if observed is None:
101+
sampled = sampled_func()
102+
observed = collect_observed_types(sampled, self.type_map)
103+
_save_observed_type_to_file(stat_file_path, observed)
104+
self._observed_type = observed
65105

66106
def init_out_stat(self) -> None:
67107
"""Initialize the output bias."""
@@ -211,7 +251,7 @@ def forward_common_atomic(
211251
"""
212252
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist)
213253
_, nloc, _ = nlist.shape
214-
atype = extended_atype[:, :nloc]
254+
atype = xp_take_first_n(extended_atype, 1, nloc)
215255
if self.pair_excl is not None:
216256
pair_mask = self.pair_excl.build_type_exclude_mask(nlist, extended_atype)
217257
# exclude neighbors in the nlist
@@ -229,7 +269,7 @@ def forward_common_atomic(
229269
ret_dict = self.apply_out_stat(ret_dict, atype)
230270

231271
# nf x nloc
232-
atom_mask = ext_atom_mask[:, :nloc]
272+
atom_mask = xp_take_first_n(ext_atom_mask, 1, nloc)
233273
if self.atom_excl is not None:
234274
atom_mask = xp.logical_and(
235275
atom_mask, self.atom_excl.build_type_exclude_mask(atype)
@@ -271,6 +311,29 @@ def get_compute_stats_distinguish_types(self) -> bool:
271311
"""Get whether the fitting net computes stats which are not distinguished between different types of atoms."""
272312
return True
273313

314+
def compute_or_load_stat(
315+
self,
316+
sampled_func: Callable[[], list[dict]],
317+
stat_file_path: DPPath | None = None,
318+
compute_or_load_out_stat: bool = True,
319+
preset_observed_type: list[str] | None = None,
320+
) -> None:
321+
"""Compute or load the statistics parameters of the model,
322+
such as mean and standard deviation of descriptors or the energy bias of the fitting net.
323+
324+
Parameters
325+
----------
326+
sampled_func
327+
The lazy sampled function to get data frames from different data systems.
328+
stat_file_path
329+
The path to the stat file.
330+
compute_or_load_out_stat : bool
331+
Whether to compute the output statistics.
332+
If False, it will only compute the input statistics
333+
(e.g. mean and standard deviation of descriptors).
334+
"""
335+
raise NotImplementedError
336+
274337
def compute_or_load_out_stat(
275338
self,
276339
merged: Callable[[], list[dict]] | list[dict],
@@ -332,19 +395,19 @@ def wrapped_sampler() -> list[dict]:
332395
atom_exclude_types = self.atom_excl.get_exclude_types()
333396
for sample in sampled:
334397
sample["atom_exclude_types"] = list(atom_exclude_types)
335-
if (
336-
"find_fparam" not in sampled[0]
337-
and "fparam" not in sampled[0]
338-
and self.has_default_fparam()
339-
):
398+
# For systems where fparam is missing (find_fparam == 0),
399+
# fill with default fparam if available and mark as found.
400+
if self.has_default_fparam():
340401
default_fparam = self.get_default_fparam()
341402
if default_fparam is not None:
342403
default_fparam_np = np.array(default_fparam)
343404
for sample in sampled:
344-
nframe = sample["atype"].shape[0]
345-
sample["fparam"] = np.tile(
346-
default_fparam_np.reshape(1, -1), (nframe, 1)
347-
)
405+
if "find_fparam" in sample and not sample["find_fparam"]:
406+
nframe = sample["atype"].shape[0]
407+
sample["fparam"] = np.tile(
408+
default_fparam_np.reshape(1, -1), (nframe, 1)
409+
)
410+
sample["find_fparam"] = np.bool_(True)
348411
return sampled
349412

350413
return wrapped_sampler

deepmd/dpmodel/atomic_model/dp_atomic_model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from deepmd.dpmodel.array_api import (
1010
Array,
11+
xp_take_first_n,
1112
)
1213
from deepmd.dpmodel.descriptor.base_descriptor import (
1314
BaseDescriptor,
@@ -181,7 +182,7 @@ def forward_atomic(
181182
182183
"""
183184
nframes, nloc, nnei = nlist.shape
184-
atype = extended_atype[:, :nloc]
185+
atype = xp_take_first_n(extended_atype, 1, nloc)
185186

186187
if self.fitting_net.get_dim_fparam() > 0 and fparam is None:
187188
# use default fparam
@@ -221,6 +222,7 @@ def compute_or_load_stat(
221222
sampled_func: Callable[[], list[dict]],
222223
stat_file_path: DPPath | None = None,
223224
compute_or_load_out_stat: bool = True,
225+
preset_observed_type: list[str] | None = None,
224226
) -> None:
225227
"""Compute or load the statistics parameters of the model,
226228
such as mean and standard deviation of descriptors or the energy bias of the fitting net.
@@ -247,6 +249,10 @@ def compute_or_load_stat(
247249
if compute_or_load_out_stat:
248250
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
249251

252+
self._collect_and_set_observed_type(
253+
wrapped_sampler, stat_file_path, preset_observed_type
254+
)
255+
250256
def change_type_map(
251257
self, type_map: list[str], model_with_new_type_stat: Any | None = None
252258
) -> None:

deepmd/dpmodel/atomic_model/linear_atomic_model.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,7 @@ def compute_or_load_stat(
349349
sampled_func: Callable[[], list[dict]],
350350
stat_file_path: DPPath | None = None,
351351
compute_or_load_out_stat: bool = True,
352+
preset_observed_type: list[str] | None = None,
352353
) -> None:
353354
"""Compute or load the statistics parameters of the model.
354355
@@ -364,9 +365,21 @@ def compute_or_load_stat(
364365
compute_or_load_out_stat : bool
365366
Whether to compute the output statistics.
366367
"""
368+
# Compute observed type once at parent level, then propagate to
369+
# sub-models via preset_observed_type to avoid redundant computation.
370+
obs_stat_path = stat_file_path
371+
if obs_stat_path is not None and self.type_map is not None:
372+
obs_stat_path = obs_stat_path / " ".join(self.type_map)
373+
self._collect_and_set_observed_type(
374+
sampled_func, obs_stat_path, preset_observed_type
375+
)
376+
367377
for md in self.models:
368378
md.compute_or_load_stat(
369-
sampled_func, stat_file_path, compute_or_load_out_stat=False
379+
sampled_func,
380+
stat_file_path,
381+
compute_or_load_out_stat=False,
382+
preset_observed_type=self._observed_type,
370383
)
371384

372385
if stat_file_path is not None and self.type_map is not None:

deepmd/dpmodel/atomic_model/pairtab_atomic_model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def compute_or_load_stat(
216216
sampled_func: Callable[[], list[dict]],
217217
stat_file_path: DPPath | None = None,
218218
compute_or_load_out_stat: bool = True,
219+
preset_observed_type: list[str] | None = None,
219220
) -> None:
220221
"""Compute or load the statistics parameters of the model.
221222
@@ -235,6 +236,15 @@ def compute_or_load_stat(
235236
wrapped_sampler = self._make_wrapped_sampler(sampled_func)
236237
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
237238

239+
if stat_file_path is not None and self.type_map is not None:
240+
stat_file_path /= " ".join(self.type_map)
241+
242+
self._collect_and_set_observed_type(
243+
sampled_func if callable(sampled_func) else lambda: sampled_func,
244+
stat_file_path,
245+
preset_observed_type,
246+
)
247+
238248
def forward_atomic(
239249
self,
240250
extended_coord: Array,

deepmd/dpmodel/descriptor/dpa1.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,8 @@ class DescrptDPA1(NativeOP, BaseDescriptor):
242242
arXiv preprint arXiv:2208.08236.
243243
"""
244244

245+
_update_sel_cls = UpdateSel
246+
245247
def __init__(
246248
self,
247249
rcut: float,
@@ -663,7 +665,7 @@ def update_sel(
663665
The minimum distance between two atoms
664666
"""
665667
local_jdata_cpy = local_jdata.copy()
666-
min_nbor_dist, sel = UpdateSel().update_one_sel(
668+
min_nbor_dist, sel = cls._update_sel_cls().update_one_sel(
667669
train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], True
668670
)
669671
local_jdata_cpy["sel"] = sel[0]

deepmd/dpmodel/descriptor/dpa2.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,8 @@ class DescrptDPA2(NativeOP, BaseDescriptor):
441441
Comput Mater 10, 293 (2024). https://doi.org/10.1038/s41524-024-01493-2
442442
"""
443443

444+
_update_sel_cls = UpdateSel
445+
444446
def __init__(
445447
self,
446448
ntypes: int,
@@ -1115,7 +1117,7 @@ def update_sel(
11151117
The minimum distance between two atoms
11161118
"""
11171119
local_jdata_cpy = local_jdata.copy()
1118-
update_sel = UpdateSel()
1120+
update_sel = cls._update_sel_cls()
11191121
min_nbor_dist, repinit_sel = update_sel.update_one_sel(
11201122
train_data,
11211123
type_map,

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,8 @@ class DescrptDPA3(NativeOP, BaseDescriptor):
338338
arXiv preprint arXiv:2506.01686 (2025).
339339
"""
340340

341+
_update_sel_cls = UpdateSel
342+
341343
def __init__(
342344
self,
343345
ntypes: int,
@@ -799,7 +801,7 @@ def update_sel(
799801
The minimum distance between two atoms
800802
"""
801803
local_jdata_cpy = local_jdata.copy()
802-
update_sel = UpdateSel()
804+
update_sel = cls._update_sel_cls()
803805
min_nbor_dist, repflow_e_sel = update_sel.update_one_sel(
804806
train_data,
805807
type_map,

deepmd/dpmodel/descriptor/make_base_descriptor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class BD(ABC, PluginVariant, make_plugin_registry("descriptor")):
5151
def __new__(cls, *args: Any, **kwargs: Any) -> Any:
5252
if cls is BD:
5353
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
54-
return super().__new__(cls)
54+
return object.__new__(cls)
5555

5656
@abstractmethod
5757
def get_rcut(self) -> float:

0 commit comments

Comments
 (0)