Skip to content

Commit 19cc00a

Browse files
committed
feat(pt, dpmodel): use data stat for observed type
1 parent 65eea4b commit 19cc00a

12 files changed

Lines changed: 525 additions & 2 deletions

File tree

deepmd/dpmodel/model/base_model.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,20 @@ def get_model_def_script(self) -> str:
142142
"""Get the model definition script."""
143143
pass
144144

145+
def get_observed_type_list(self) -> list[str]:
146+
"""Get observed types from model metadata.
147+
148+
Returns empty list if not available.
149+
"""
150+
if self.model_def_script:
151+
import json
152+
153+
params = json.loads(self.model_def_script)
154+
observed = params.get("info", {}).get("observed_type")
155+
if observed is not None:
156+
return observed
157+
return []
158+
145159
def get_min_nbor_dist(self) -> float | None:
146160
"""Get the minimum distance between two atoms."""
147161
return self.min_nbor_dist

deepmd/dpmodel/utils/stat.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,62 @@
2626
log = logging.getLogger(__name__)
2727

2828

29+
def collect_observed_types(
30+
sampled: list[dict], type_map: list[str]
31+
) -> list[str]:
32+
"""Collect observed element types from sampled training data.
33+
34+
Parameters
35+
----------
36+
sampled : list[dict]
37+
Sampled data from different data systems. Each dict must contain
38+
``"atype"`` with shape ``[nframes, natoms]``.
39+
type_map : list[str]
40+
Mapping from type index to element symbol.
41+
42+
Returns
43+
-------
44+
list[str]
45+
Sorted list of observed element symbols.
46+
"""
47+
from deepmd.utils.econf_embd import (
48+
sort_element_type,
49+
)
50+
51+
observed_indices: set[int] = set()
52+
for system in sampled:
53+
atype = to_numpy_array(system["atype"]) # shape: [nframes, natoms]
54+
observed_indices.update(np.unique(atype).tolist())
55+
observed_types = [type_map[i] for i in sorted(observed_indices) if i < len(type_map)]
56+
return sort_element_type(observed_types)
57+
58+
59+
def _restore_observed_type_from_file(
60+
stat_file_path: DPPath | None,
61+
) -> list[str] | None:
62+
"""Try to load observed_type from stat file."""
63+
if stat_file_path is None:
64+
return None
65+
fp = stat_file_path / "observed_type"
66+
if fp.is_file():
67+
arr = fp.load_numpy()
68+
# Decode bytes back to str if stored as bytes (for h5py compatibility)
69+
return [x.decode() if isinstance(x, bytes) else x for x in arr.tolist()]
70+
return None
71+
72+
73+
def _save_observed_type_to_file(
74+
stat_file_path: DPPath | None, observed_type: list[str]
75+
) -> None:
76+
"""Save observed_type to stat file."""
77+
if stat_file_path is None:
78+
return
79+
stat_file_path.mkdir(exist_ok=True, parents=True)
80+
fp = stat_file_path / "observed_type"
81+
# Use bytes dtype for h5py compatibility (h5py cannot store Unicode strings)
82+
fp.save_numpy(np.array(observed_type, dtype="S"))
83+
84+
2985
def _restore_from_file(
3086
stat_file_path: DPPath,
3187
keys: list[str],

deepmd/entrypoints/show.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,13 @@ def show(
126126
)
127127
else:
128128
log.info("The observed types for this model: ")
129-
observed_types = model.get_observed_types()
129+
observed_type_list = model_params.get("info", {}).get("observed_type")
130+
if observed_type_list is not None:
131+
observed_types = {
132+
"type_num": len(observed_type_list),
133+
"observed_type": observed_type_list,
134+
}
135+
else:
136+
observed_types = model.get_observed_types()
130137
log.info(f"Number of observed types: {observed_types['type_num']} ")
131138
log.info(f"Observed types: {observed_types['observed_type']} ")

deepmd/pt/infer/deep_eval.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,16 @@ def get_observed_types(self) -> dict:
736736
- 'type_num': the total number of observed types in this model.
737737
- 'observed_type': a list of the observed types in this model.
738738
"""
739+
# Try metadata first (from model_def_script, already a dict)
740+
observed_type_list = (
741+
self.model_def_script.get("info", {}).get("observed_type")
742+
)
743+
if observed_type_list is not None:
744+
return {
745+
"type_num": len(observed_type_list),
746+
"observed_type": observed_type_list,
747+
}
748+
# Fallback: bias-based approach for old models
739749
observed_type_list = self.dp.model["Default"].get_observed_type_list()
740750
return {
741751
"type_num": len(observed_type_list),

deepmd/pt/model/atomic_model/base_atomic_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def __init__(
9090
self.rcond = rcond
9191
self.preset_out_bias = preset_out_bias
9292
self.data_stat_protect = data_stat_protect
93+
self._observed_type: list[str] | None = None
9394

9495
def init_out_stat(self) -> None:
9596
"""Initialize the output bias."""

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@
2020
from deepmd.pt.model.task.base_fitting import (
2121
BaseFitting,
2222
)
23+
from deepmd.pt.utils.stat import (
24+
_restore_observed_type_from_file,
25+
_save_observed_type_to_file,
26+
collect_observed_types,
27+
)
2328
from deepmd.utils.path import (
2429
DPPath,
2530
)
@@ -307,6 +312,7 @@ def compute_or_load_stat(
307312
sampled_func: Callable[[], list[dict]],
308313
stat_file_path: DPPath | None = None,
309314
compute_or_load_out_stat: bool = True,
315+
preset_observed_type: list[str] | None = None,
310316
) -> None:
311317
"""
312318
Compute or load the statistics parameters of the model,
@@ -358,6 +364,17 @@ def wrapped_sampler() -> list[dict]:
358364
if compute_or_load_out_stat:
359365
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
360366

367+
# Collect observed types with priority: preset > stat_file > compute
368+
if preset_observed_type is not None:
369+
self._observed_type = preset_observed_type
370+
else:
371+
observed = _restore_observed_type_from_file(stat_file_path)
372+
if observed is None:
373+
sampled = wrapped_sampler()
374+
observed = collect_observed_types(sampled, self.type_map)
375+
_save_observed_type_to_file(stat_file_path, observed)
376+
self._observed_type = observed
377+
361378
def compute_fitting_input_stat(
362379
self,
363380
sample_merged: Callable[[], list[dict]] | list[dict],

deepmd/pt/model/model/make_model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -587,9 +587,14 @@ def compute_or_load_stat(
587587
self,
588588
sampled_func: Callable[[], Any],
589589
stat_file_path: DPPath | None = None,
590+
preset_observed_type: list[str] | None = None,
590591
) -> None:
591592
"""Compute or load the statistics."""
592-
return self.atomic_model.compute_or_load_stat(sampled_func, stat_file_path)
593+
return self.atomic_model.compute_or_load_stat(
594+
sampled_func,
595+
stat_file_path,
596+
preset_observed_type=preset_observed_type,
597+
)
593598

594599
def get_sel(self) -> list[int]:
595600
"""Returns the number of selected atoms for each type."""

deepmd/pt/model/model/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def compute_or_load_stat(
3030
self,
3131
sampled_func: Any,
3232
stat_file_path: DPPath | None = None,
33+
preset_observed_type: list[str] | None = None,
3334
) -> NoReturn:
3435
"""
3536
Compute or load the statistics parameters of the model,

deepmd/pt/train/training.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import functools
3+
import json
34
import logging
45
import time
56
from collections.abc import (
@@ -288,6 +289,7 @@ def single_model_stat(
288289
_training_data: DpLoaderSet,
289290
_stat_file_path: str | None,
290291
finetune_has_new_type: bool = False,
292+
preset_observed_type: list[str] | None = None,
291293
) -> Callable[[], Any]:
292294
@functools.lru_cache
293295
def get_sample() -> Any:
@@ -302,6 +304,7 @@ def get_sample() -> Any:
302304
_model.compute_or_load_stat(
303305
sampled_func=get_sample,
304306
stat_file_path=_stat_file_path,
307+
preset_observed_type=preset_observed_type,
305308
)
306309
if isinstance(_stat_file_path, DPH5Path):
307310
_stat_file_path.root.close()
@@ -394,7 +397,16 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
394397
finetune_has_new_type=self.finetune_links["Default"].get_has_new_type()
395398
if self.finetune_links is not None
396399
else False,
400+
preset_observed_type=model_params.get("info", {}).get("observed_type"),
397401
)
402+
# Persist observed_type from stat into model_params and model_def_script
403+
if not resuming and self.rank == 0:
404+
observed = getattr(
405+
self.model.atomic_model, "_observed_type", None
406+
)
407+
if observed is not None:
408+
model_params.setdefault("info", {})["observed_type"] = observed
409+
self.model.model_def_script = json.dumps(model_params)
398410
(
399411
self.training_dataloader,
400412
self.training_data,
@@ -432,6 +444,11 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
432444
training_data[model_key].preload_and_modify_all_data_torch()
433445
if validation_data[model_key] is not None:
434446
validation_data[model_key].preload_and_modify_all_data_torch()
447+
_mt_user_observed = (
448+
model_params["model_dict"][model_key]
449+
.get("info", {})
450+
.get("observed_type")
451+
)
435452
self.get_sample_func[model_key] = single_model_stat(
436453
self.model[model_key],
437454
model_params["model_dict"][model_key].get("data_stat_nbatch", 10),
@@ -442,7 +459,22 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
442459
].get_has_new_type()
443460
if self.finetune_links is not None
444461
else False,
462+
preset_observed_type=_mt_user_observed,
445463
)
464+
# Persist observed_type into model_params and model_def_script
465+
if not resuming and self.rank == 0:
466+
observed = getattr(
467+
self.model[model_key].atomic_model,
468+
"_observed_type",
469+
None,
470+
)
471+
if observed is not None:
472+
model_params["model_dict"][model_key].setdefault(
473+
"info", {}
474+
)["observed_type"] = observed
475+
self.model[model_key].model_def_script = json.dumps(
476+
model_params["model_dict"][model_key]
477+
)
446478

447479
(
448480
self.training_dataloader[model_key],

deepmd/pt/utils/stat.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,19 @@
3535

3636
log = logging.getLogger(__name__)
3737

38+
# Re-export from dpmodel (backend-agnostic implementations)
39+
from deepmd.dpmodel.utils.stat import ( # noqa: E402
40+
_restore_observed_type_from_file,
41+
_save_observed_type_to_file,
42+
collect_observed_types,
43+
)
44+
45+
__all__ = [
46+
"collect_observed_types",
47+
"_restore_observed_type_from_file",
48+
"_save_observed_type_to_file",
49+
]
50+
3851

3952
def make_stat_input(
4053
datasets: list[Any], dataloaders: list[Any], nbatches: int

0 commit comments

Comments
 (0)