Skip to content

Commit 51fbd76

Browse files
feat(pt, dpmodel): use data stat for observed type (#5269)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Models now record, persist, and expose observed element-type lists in model metadata; evaluation/display prefer that metadata when present. * Training and model APIs accept an optional preset observed-type override to control observed-type selection. * Utilities added to collect, save, and restore observed-type lists for caching and reuse. * **Tests** * Added unit and integration tests covering collection, file I/O, training integration, metadata propagation, fallback behavior, and user presets. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 3b24cf6 commit 51fbd76

22 files changed

Lines changed: 1110 additions & 10 deletions

deepmd/dpmodel/atomic_model/base_atomic_model.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,45 @@ def __init__(
6262
self.rcond = rcond
6363
self.preset_out_bias = preset_out_bias
6464
self.data_stat_protect = data_stat_protect
65+
self._observed_type: list[str] | None = None
66+
67+
@property
68+
def observed_type(self) -> list[str] | None:
69+
"""Get the observed element type list from data statistics."""
70+
return self._observed_type
71+
72+
def _collect_and_set_observed_type(
73+
self,
74+
sampled_func: Callable[[], list[dict]],
75+
stat_file_path: DPPath | None,
76+
preset_observed_type: list[str] | None,
77+
) -> None:
78+
"""Collect observed types with priority: preset > stat_file > compute.
79+
80+
Parameters
81+
----------
82+
sampled_func
83+
The lazy sampled function to get data frames.
84+
stat_file_path
85+
The path to the statistics files (should already include type_map suffix).
86+
preset_observed_type
87+
User-specified observed types that take highest priority.
88+
"""
89+
from deepmd.dpmodel.utils.stat import (
90+
_restore_observed_type_from_file,
91+
_save_observed_type_to_file,
92+
collect_observed_types,
93+
)
94+
95+
if preset_observed_type is not None:
96+
self._observed_type = preset_observed_type
97+
else:
98+
observed = _restore_observed_type_from_file(stat_file_path)
99+
if observed is None:
100+
sampled = sampled_func()
101+
observed = collect_observed_types(sampled, self.type_map)
102+
_save_observed_type_to_file(stat_file_path, observed)
103+
self._observed_type = observed
65104

66105
def init_out_stat(self) -> None:
67106
"""Initialize the output bias."""
@@ -271,6 +310,29 @@ def get_compute_stats_distinguish_types(self) -> bool:
271310
"""Get whether the fitting net computes stats which are not distinguished between different types of atoms."""
272311
return True
273312

313+
def compute_or_load_stat(
314+
self,
315+
sampled_func: Callable[[], list[dict]],
316+
stat_file_path: DPPath | None = None,
317+
compute_or_load_out_stat: bool = True,
318+
preset_observed_type: list[str] | None = None,
319+
) -> None:
320+
"""Compute or load the statistics parameters of the model,
321+
such as mean and standard deviation of descriptors or the energy bias of the fitting net.
322+
323+
Parameters
324+
----------
325+
sampled_func
326+
The lazy sampled function to get data frames from different data systems.
327+
stat_file_path
328+
The path to the stat file.
329+
compute_or_load_out_stat : bool
330+
Whether to compute the output statistics.
331+
If False, it will only compute the input statistics
332+
(e.g. mean and standard deviation of descriptors).
333+
"""
334+
raise NotImplementedError
335+
274336
def compute_or_load_out_stat(
275337
self,
276338
merged: Callable[[], list[dict]] | list[dict],

deepmd/dpmodel/atomic_model/dp_atomic_model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ def compute_or_load_stat(
201201
sampled_func: Callable[[], list[dict]],
202202
stat_file_path: DPPath | None = None,
203203
compute_or_load_out_stat: bool = True,
204+
preset_observed_type: list[str] | None = None,
204205
) -> None:
205206
"""Compute or load the statistics parameters of the model,
206207
such as mean and standard deviation of descriptors or the energy bias of the fitting net.
@@ -227,6 +228,10 @@ def compute_or_load_stat(
227228
if compute_or_load_out_stat:
228229
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
229230

231+
self._collect_and_set_observed_type(
232+
wrapped_sampler, stat_file_path, preset_observed_type
233+
)
234+
230235
def change_type_map(
231236
self, type_map: list[str], model_with_new_type_stat: Any | None = None
232237
) -> None:

deepmd/dpmodel/atomic_model/linear_atomic_model.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,7 @@ def compute_or_load_stat(
349349
sampled_func: Callable[[], list[dict]],
350350
stat_file_path: DPPath | None = None,
351351
compute_or_load_out_stat: bool = True,
352+
preset_observed_type: list[str] | None = None,
352353
) -> None:
353354
"""Compute or load the statistics parameters of the model.
354355
@@ -364,9 +365,21 @@ def compute_or_load_stat(
364365
compute_or_load_out_stat : bool
365366
Whether to compute the output statistics.
366367
"""
368+
# Compute observed type once at parent level, then propagate to
369+
# sub-models via preset_observed_type to avoid redundant computation.
370+
obs_stat_path = stat_file_path
371+
if obs_stat_path is not None and self.type_map is not None:
372+
obs_stat_path = obs_stat_path / " ".join(self.type_map)
373+
self._collect_and_set_observed_type(
374+
sampled_func, obs_stat_path, preset_observed_type
375+
)
376+
367377
for md in self.models:
368378
md.compute_or_load_stat(
369-
sampled_func, stat_file_path, compute_or_load_out_stat=False
379+
sampled_func,
380+
stat_file_path,
381+
compute_or_load_out_stat=False,
382+
preset_observed_type=self._observed_type,
370383
)
371384

372385
if stat_file_path is not None and self.type_map is not None:

deepmd/dpmodel/atomic_model/pairtab_atomic_model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def compute_or_load_stat(
216216
sampled_func: Callable[[], list[dict]],
217217
stat_file_path: DPPath | None = None,
218218
compute_or_load_out_stat: bool = True,
219+
preset_observed_type: list[str] | None = None,
219220
) -> None:
220221
"""Compute or load the statistics parameters of the model.
221222
@@ -235,6 +236,15 @@ def compute_or_load_stat(
235236
wrapped_sampler = self._make_wrapped_sampler(sampled_func)
236237
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
237238

239+
if stat_file_path is not None and self.type_map is not None:
240+
stat_file_path /= " ".join(self.type_map)
241+
242+
self._collect_and_set_observed_type(
243+
sampled_func if callable(sampled_func) else lambda: sampled_func,
244+
stat_file_path,
245+
preset_observed_type,
246+
)
247+
238248
def forward_atomic(
239249
self,
240250
extended_coord: Array,

deepmd/dpmodel/infer/deep_eval.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@
5050
from deepmd.infer.deep_wfc import (
5151
DeepWFC,
5252
)
53+
from deepmd.utils.econf_embd import (
54+
sort_element_type,
55+
)
5356

5457
if TYPE_CHECKING:
5558
import ase.neighborlist
@@ -403,6 +406,31 @@ def get_model_def_script(self) -> dict:
403406
"""Get model definition script."""
404407
return json.loads(self.dp.get_model_def_script())
405408

409+
def get_observed_types(self) -> dict:
410+
"""Get observed types (elements) of the model during data statistics.
411+
412+
Returns
413+
-------
414+
dict
415+
A dictionary containing the information of observed type in the model:
416+
- 'type_num': the total number of observed types in this model.
417+
- 'observed_type': a list of the observed types in this model.
418+
"""
419+
# Try metadata first (from model_def_script)
420+
model_def_script = self.get_model_def_script()
421+
observed_type_list = model_def_script.get("info", {}).get("observed_type")
422+
if observed_type_list is not None:
423+
return {
424+
"type_num": len(observed_type_list),
425+
"observed_type": observed_type_list,
426+
}
427+
# Fallback: bias-based approach for old models
428+
observed_type_list = self.dp.get_observed_type_list()
429+
return {
430+
"type_num": len(observed_type_list),
431+
"observed_type": sort_element_type(observed_type_list),
432+
}
433+
406434
def get_model(self) -> "BaseModel":
407435
"""Get the dpmodel BaseModel.
408436

deepmd/dpmodel/model/make_model.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,8 @@ def get_out_bias(self) -> Array:
381381
def get_observed_type_list(self) -> list[str]:
382382
"""Get observed types (elements) of the model during data statistics.
383383
384+
Bias-based fallback for old models without metadata.
385+
384386
Returns
385387
-------
386388
list[str]
@@ -718,6 +720,7 @@ def compute_or_load_stat(
718720
self,
719721
sampled_func: Callable[[], Any],
720722
stat_file_path: DPPath | None = None,
723+
preset_observed_type: list[str] | None = None,
721724
) -> None:
722725
"""Compute or load the statistics parameters of the model.
723726
@@ -728,8 +731,12 @@ def compute_or_load_stat(
728731
data systems.
729732
stat_file_path
730733
The path to the stat file.
734+
preset_observed_type
735+
User-specified observed types that take highest priority.
731736
"""
732-
self.atomic_model.compute_or_load_stat(sampled_func, stat_file_path)
737+
self.atomic_model.compute_or_load_stat(
738+
sampled_func, stat_file_path, preset_observed_type=preset_observed_type
739+
)
733740

734741
def get_model_def_script(self) -> str:
735742
"""Get the model definition script."""

deepmd/dpmodel/utils/stat.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,62 @@
2929
log = logging.getLogger(__name__)
3030

3131

32+
def collect_observed_types(sampled: list[dict], type_map: list[str]) -> list[str]:
33+
"""Collect observed element types from sampled training data.
34+
35+
Parameters
36+
----------
37+
sampled : list[dict]
38+
Sampled data from different data systems. Each dict must contain
39+
``"atype"`` with shape ``[nframes, natoms]``.
40+
type_map : list[str]
41+
Mapping from type index to element symbol.
42+
43+
Returns
44+
-------
45+
list[str]
46+
Sorted list of observed element symbols.
47+
"""
48+
from deepmd.utils.econf_embd import (
49+
sort_element_type,
50+
)
51+
52+
observed_indices: set[int] = set()
53+
for system in sampled:
54+
atype = to_numpy_array(system["atype"]) # shape: [nframes, natoms]
55+
observed_indices.update(np.unique(atype).tolist())
56+
observed_types = [
57+
type_map[i] for i in sorted(observed_indices) if i < len(type_map)
58+
]
59+
return sort_element_type(observed_types)
60+
61+
62+
def _restore_observed_type_from_file(
63+
stat_file_path: DPPath | None,
64+
) -> list[str] | None:
65+
"""Try to load observed_type from stat file."""
66+
if stat_file_path is None:
67+
return None
68+
fp = stat_file_path / "observed_type"
69+
if fp.is_file():
70+
arr = fp.load_numpy()
71+
# Decode bytes back to str if stored as bytes (for h5py compatibility)
72+
return [x.decode() if isinstance(x, bytes) else x for x in arr.tolist()]
73+
return None
74+
75+
76+
def _save_observed_type_to_file(
77+
stat_file_path: DPPath | None, observed_type: list[str]
78+
) -> None:
79+
"""Save observed_type to stat file."""
80+
if stat_file_path is None:
81+
return
82+
stat_file_path.mkdir(exist_ok=True, parents=True)
83+
fp = stat_file_path / "observed_type"
84+
# Use bytes dtype for h5py compatibility (h5py cannot store Unicode strings)
85+
fp.save_numpy(np.array(observed_type, dtype="S"))
86+
87+
3288
def _restore_from_file(
3389
stat_file_path: DPPath,
3490
keys: list[str],

deepmd/entrypoints/show.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,13 @@ def show(
126126
)
127127
else:
128128
log.info("The observed types for this model: ")
129-
observed_types = model.get_observed_types()
129+
observed_type_list = model_params.get("info", {}).get("observed_type")
130+
if observed_type_list is not None:
131+
observed_types = {
132+
"type_num": len(observed_type_list),
133+
"observed_type": observed_type_list,
134+
}
135+
else:
136+
observed_types = model.get_observed_types()
130137
log.info(f"Number of observed types: {observed_types['type_num']} ")
131138
log.info(f"Observed types: {observed_types['observed_type']} ")

deepmd/pt/infer/deep_eval.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,14 @@ def get_observed_types(self) -> dict:
736736
- 'type_num': the total number of observed types in this model.
737737
- 'observed_type': a list of the observed types in this model.
738738
"""
739+
# Try metadata first (from model_def_script, already a dict)
740+
observed_type_list = self.model_def_script.get("info", {}).get("observed_type")
741+
if observed_type_list is not None:
742+
return {
743+
"type_num": len(observed_type_list),
744+
"observed_type": observed_type_list,
745+
}
746+
# Fallback: bias-based approach for old models
739747
observed_type_list = self.dp.model["Default"].get_observed_type_list()
740748
return {
741749
"type_num": len(observed_type_list),

deepmd/pt/model/atomic_model/base_atomic_model.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,45 @@ def __init__(
9090
self.rcond = rcond
9191
self.preset_out_bias = preset_out_bias
9292
self.data_stat_protect = data_stat_protect
93+
self._observed_type: list[str] | None = None
94+
95+
@property
96+
def observed_type(self) -> list[str] | None:
97+
"""Get the observed element type list from data statistics."""
98+
return self._observed_type
99+
100+
def _collect_and_set_observed_type(
101+
self,
102+
sampled_func: Callable[[], list[dict]],
103+
stat_file_path: "DPPath | None",
104+
preset_observed_type: list[str] | None,
105+
) -> None:
106+
"""Collect observed types with priority: preset > stat_file > compute.
107+
108+
Parameters
109+
----------
110+
sampled_func
111+
The lazy sampled function to get data frames.
112+
stat_file_path
113+
The path to the statistics files (should already include type_map suffix).
114+
preset_observed_type
115+
User-specified observed types that take highest priority.
116+
"""
117+
from deepmd.dpmodel.utils.stat import (
118+
_restore_observed_type_from_file,
119+
_save_observed_type_to_file,
120+
collect_observed_types,
121+
)
122+
123+
if preset_observed_type is not None:
124+
self._observed_type = preset_observed_type
125+
else:
126+
observed = _restore_observed_type_from_file(stat_file_path)
127+
if observed is None:
128+
sampled = sampled_func()
129+
observed = collect_observed_types(sampled, self.type_map)
130+
_save_observed_type_to_file(stat_file_path, observed)
131+
self._observed_type = observed
93132

94133
def init_out_stat(self) -> None:
95134
"""Initialize the output bias."""
@@ -376,6 +415,7 @@ def compute_or_load_stat(
376415
merged: Callable[[], list[dict]] | list[dict],
377416
stat_file_path: DPPath | None = None,
378417
compute_or_load_out_stat: bool = True,
418+
preset_observed_type: list[str] | None = None,
379419
) -> NoReturn:
380420
"""
381421
Compute or load the statistics parameters of the model,

0 commit comments

Comments
 (0)