Commit 314d03a

wanghan-iapcm and Han Wang authored
feat(pt_expt): add dipole, polar, dos, property and dp-zbl models with cross-backend consistency tests (#5260)
## Summary

- Add 5 new pt_expt model types: DipoleModel, PolarModel, DOSModel, PropertyModel, and DPZBLModel, completing pt_expt's model coverage to parity with pt
- Refactor dpmodel base architecture so that pt_expt models inherit directly from dpmodel via `make_model()`, removing the intermediate pt_expt atomic model layer
- Consolidate scattered model-level methods (`get_out_bias`, `set_out_bias`, `get_observed_type_list`, `compute_or_load_stat`) into shared dpmodel base classes
- Move `compute_fitting_input_stat` for `set-by-statistic` mode from model-level `change_out_bias` to training-level `model_change_out_bias` (pt and pd backends), keeping the `change_out_bias` logic focused on bias only (copied from #5266)
- Fix array-api-compat violations in `general_fitting.change_type_map` (bare `np.zeros`/`np.ones`/`np.concatenate` → `xp` equivalents with device)
- Fix dpmodel `change_type_map` not forwarding `model_with_new_type_stat` through the call chain
- Add comprehensive cross-backend consistency tests for all model types (dp vs pt vs pt_expt), covering: model output, serialization round-trip, `change_out_bias`, `change_type_map`, `compute_or_load_stat`, model API methods.

## Changes

### New pt_expt models

- `deepmd/pt_expt/model/dipole_model.py`
- `deepmd/pt_expt/model/polar_model.py`
- `deepmd/pt_expt/model/dos_model.py`
- `deepmd/pt_expt/model/property_model.py`
- `deepmd/pt_expt/model/dp_zbl_model.py`

### Architecture refactoring

- Remove `deepmd/pt_expt/atomic_model/` layer: models now wrap dpmodel atomic models directly
- Clean up `BaseModel`: remove concrete methods/data, add plugin registry
- Refactor `make_model` so backends (dp, pt_expt) inherit shared model logic from dpmodel
- Consolidate `get_out_bias`/`set_out_bias` into `base_atomic_model.py`
- Add `get_observed_type_list` to abstract API and implement in dpmodel, pt, pd
- Move fitting input stat update to `model_change_out_bias` in pt/pd training code (#5266)

### Bug fixes

- `general_fitting.change_type_map`: use array-api-compat ops instead of bare numpy (breaks pt_expt)
- `make_model.change_type_map`: properly forward `model_with_new_type_stat` to atomic model
- `stat.py`: fix in-place mutation issue

### Tests

- New cross-backend consistency tests: `test_dipole.py`, `test_polar.py`, `test_dos.py`, `test_property.py` and `test_zbl_ener.py` (~1400 lines each)
- Expanded `test_ener.py` with pt_expt and full model API coverage
- New pt_expt unit tests: `test_dipole_model.py`, `test_polar_model.py`, `test_dos_model.py`, `test_property_model.py`, `test_dp_zbl_model.py`
- Added `test_get_model_def_script`, `test_get_min_nbor_dist`, `test_set_case_embd` across all 6 model test files
- Moved atomic model output stat tests from pt_expt to dpmodel
- Added `model_change_out_bias` tests in pt/pd training tests (#5266)

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * File-backed compute/load for descriptor and fitting statistics; new compute-or-load stat APIs, get/set output-bias, and observed-type discovery.
  * Exportable/traceable lower-level inference paths for dipole, dos, polar, property, zbl, and energy models.
* **Refactor**
  * Model factory and generated models support extensible base-class composition and unified fitting backend wiring.
* **Tests**
  * Large expansion of cross-backend (DP/PT/PT-EXPT) parity, statistics, bias, and exportability tests.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent f0a966b commit 314d03a

64 files changed

Lines changed: 10586 additions & 1800 deletions

deepmd/dpmodel/atomic_model/base_atomic_model.py

Lines changed: 62 additions & 0 deletions
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+import functools
 import math
 from collections.abc import (
     Callable,
@@ -52,13 +53,15 @@ def __init__(
         pair_exclude_types: list[tuple[int, int]] = [],
         rcond: float | None = None,
         preset_out_bias: dict[str, Array] | None = None,
+        data_stat_protect: float = 1e-2,
     ) -> None:
         super().__init__()
         self.type_map = type_map
         self.reinit_atom_exclude(atom_exclude_types)
         self.reinit_pair_exclude(pair_exclude_types)
         self.rcond = rcond
         self.preset_out_bias = preset_out_bias
+        self.data_stat_protect = data_stat_protect
 
     def init_out_stat(self) -> None:
         """Initialize the output bias."""
@@ -77,6 +80,14 @@ def init_out_stat(self) -> None:
         self.out_bias = out_bias_data
         self.out_std = out_std_data
 
+    def get_out_bias(self) -> Array:
+        """Get the output bias."""
+        return self.out_bias
+
+    def set_out_bias(self, out_bias: Array) -> None:
+        """Set the output bias."""
+        self.out_bias = out_bias
+
     def __setitem__(self, key: str, value: Array) -> None:
         if key in ["out_bias"]:
             self.out_bias = value
@@ -287,6 +298,57 @@ def compute_or_load_out_stat(
             bias_adjust_mode="set-by-statistic",
         )
 
+    def _make_wrapped_sampler(
+        self,
+        sampled_func: Callable[[], list[dict]],
+    ) -> Callable[[], list[dict]]:
+        """Wrap the sampled function with exclusion types and default fparam.
+
+        The returned callable is cached so that the sampling (which may be
+        expensive) is performed at most once.
+
+        Parameters
+        ----------
+        sampled_func
+            The lazy sampled function to get data frames from different data
+            systems.
+
+        Returns
+        -------
+        Callable[[], list[dict]]
+            A cached wrapper around *sampled_func* that additionally sets
+            ``pair_exclude_types``, ``atom_exclude_types`` and default
+            ``fparam`` on every sample dict when applicable.
+        """
+
+        @functools.lru_cache
+        def wrapped_sampler() -> list[dict]:
+            sampled = sampled_func()
+            if self.pair_excl is not None:
+                pair_exclude_types = self.pair_excl.get_exclude_types()
+                for sample in sampled:
+                    sample["pair_exclude_types"] = list(pair_exclude_types)
+            if self.atom_excl is not None:
+                atom_exclude_types = self.atom_excl.get_exclude_types()
+                for sample in sampled:
+                    sample["atom_exclude_types"] = list(atom_exclude_types)
+            if (
+                "find_fparam" not in sampled[0]
+                and "fparam" not in sampled[0]
+                and self.has_default_fparam()
+            ):
+                default_fparam = self.get_default_fparam()
+                if default_fparam is not None:
+                    default_fparam_np = np.array(default_fparam)
+                    for sample in sampled:
+                        nframe = sample["atype"].shape[0]
+                        sample["fparam"] = np.tile(
+                            default_fparam_np.reshape(1, -1), (nframe, 1)
+                        )
+            return sampled
+
+        return wrapped_sampler
+
     def change_out_bias(
         self,
         sample_merged: Callable[[], list[dict]] | list[dict],
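
The caching behaviour described in the `_make_wrapped_sampler` docstring can be illustrated in isolation. The following is a minimal sketch, not the project's code: `make_wrapped_sampler`, `expensive_sampler` and the `calls` counter are hypothetical names introduced here, and only the atom-exclusion branch is reproduced.

```python
import functools
from typing import Callable


def make_wrapped_sampler(
    sampled_func: Callable[[], list[dict]],
    atom_exclude_types: list[int],
) -> Callable[[], list[dict]]:
    """Hypothetical stand-alone analogue of `_make_wrapped_sampler`."""

    @functools.lru_cache
    def wrapped_sampler() -> list[dict]:
        # The (possibly expensive) sampling runs only on the first call.
        sampled = sampled_func()
        for sample in sampled:
            sample["atom_exclude_types"] = list(atom_exclude_types)
        return sampled

    return wrapped_sampler


calls = {"n": 0}  # counts how often the underlying sampler really runs


def expensive_sampler() -> list[dict]:
    calls["n"] += 1
    return [{"atype": [0, 1, 1]}]


sampler = make_wrapped_sampler(expensive_sampler, atom_exclude_types=[1])
first = sampler()   # triggers expensive_sampler()
second = sampler()  # served from the cache
```

Because the cache stores the list itself, every caller receives the same object. Any consumer that mutates the sample dicts in place therefore changes what later consumers see, which is worth keeping in mind when one wrapped sampler is shared between descriptor, fitting and output statistics.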

deepmd/dpmodel/atomic_model/dp_atomic_model.py

Lines changed: 74 additions & 14 deletions
@@ -1,4 +1,7 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+from collections.abc import (
+    Callable,
+)
 from typing import (
     Any,
 )
@@ -15,6 +18,9 @@
 from deepmd.dpmodel.output_def import (
     FittingOutputDef,
 )
+from deepmd.utils.path import (
+    DPPath,
+)
 from deepmd.utils.version import (
     check_version_compatibility,
 )
@@ -62,17 +68,16 @@ def __init__(
         **kwargs: Any,
     ) -> None:
         super().__init__(type_map, **kwargs)
-        self.type_map = type_map
         self.descriptor = descriptor
-        self.fitting = fitting
-        if hasattr(self.fitting, "reinit_exclude"):
-            self.fitting.reinit_exclude(self.atom_exclude_types)
+        self.fitting_net = fitting
+        if hasattr(self.fitting_net, "reinit_exclude"):
+            self.fitting_net.reinit_exclude(self.atom_exclude_types)
         self.type_map = type_map
         super().init_out_stat()
 
     def fitting_output_def(self) -> FittingOutputDef:
         """Get the output def of the fitting net."""
-        return self.fitting.output_def()
+        return self.fitting_net.output_def()
 
     def get_rcut(self) -> float:
         """Get the cut-off radius."""
@@ -87,7 +92,7 @@ def set_case_embd(self, case_idx: int) -> None:
         Set the case embedding of this atomic model by the given case_idx,
         typically concatenated with the output of the descriptor and fed into the fitting net.
         """
-        self.fitting.set_case_embd(case_idx)
+        self.fitting_net.set_case_embd(case_idx)
 
     def mixed_types(self) -> bool:
         """If true, the model
@@ -180,7 +185,7 @@ def forward_atomic(
             nlist,
             mapping=mapping,
         )
-        ret = self.fitting(
+        ret = self.fitting_net(
             descriptor,
             atype,
             gr=rot_mat,
@@ -191,6 +196,37 @@ def forward_atomic(
         )
         return ret
 
+    def compute_or_load_stat(
+        self,
+        sampled_func: Callable[[], list[dict]],
+        stat_file_path: DPPath | None = None,
+        compute_or_load_out_stat: bool = True,
+    ) -> None:
+        """Compute or load the statistics parameters of the model,
+        such as mean and standard deviation of descriptors or the energy bias of the fitting net.
+
+        Parameters
+        ----------
+        sampled_func
+            The lazy sampled function to get data frames from different data systems.
+        stat_file_path
+            The path to the stat file.
+        compute_or_load_out_stat : bool
+            Whether to compute the output statistics.
+            If False, it will only compute the input statistics
+            (e.g. mean and standard deviation of descriptors).
+        """
+        if stat_file_path is not None and self.type_map is not None:
+            stat_file_path /= " ".join(self.type_map)
+
+        wrapped_sampler = self._make_wrapped_sampler(sampled_func)
+        self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
+        self.fitting_net.compute_input_stats(
+            wrapped_sampler, stat_file_path=stat_file_path
+        )
+        if compute_or_load_out_stat:
+            self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
+
     def change_type_map(
         self, type_map: list[str], model_with_new_type_stat: Any | None = None
     ) -> None:
@@ -207,7 +243,31 @@ def change_type_map(
             if model_with_new_type_stat is not None
             else None,
         )
-        self.fitting.change_type_map(type_map=type_map)
+        self.fitting_net.change_type_map(type_map=type_map)
+
+    def compute_fitting_input_stat(
+        self,
+        sample_merged: Callable[[], list[dict]] | list[dict],
+        stat_file_path: DPPath | None = None,
+    ) -> None:
+        """Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.
+
+        Parameters
+        ----------
+        sample_merged : Union[Callable[[], list[dict]], list[dict]]
+            - list[dict]: A list of data samples from various data systems.
+                Each element, ``merged[i]``, is a data dictionary containing
+                ``keys``: ``np.ndarray`` originating from the ``i``-th data system.
+            - Callable[[], list[dict]]: A lazy function that returns data samples
+                in the above format only when needed.
+        stat_file_path : Optional[DPPath]
+            The path to the stat file.
+        """
+        self.fitting_net.compute_input_stats(
+            sample_merged,
+            protection=self.data_stat_protect,
+            stat_file_path=stat_file_path,
+        )
 
     def serialize(self) -> dict:
         dd = super().serialize()
@@ -218,7 +278,7 @@ def serialize(self) -> dict:
                 "@version": 2,
                 "type_map": self.type_map,
                 "descriptor": self.descriptor.serialize(),
-                "fitting": self.fitting.serialize(),
+                "fitting": self.fitting_net.serialize(),
             }
         )
         return dd
@@ -244,19 +304,19 @@ def deserialize(cls, data: dict[str, Any]) -> "DPAtomicModel":
 
     def get_dim_fparam(self) -> int:
         """Get the number (dimension) of frame parameters of this atomic model."""
-        return self.fitting.get_dim_fparam()
+        return self.fitting_net.get_dim_fparam()
 
     def get_dim_aparam(self) -> int:
         """Get the number (dimension) of atomic parameters of this atomic model."""
-        return self.fitting.get_dim_aparam()
+        return self.fitting_net.get_dim_aparam()
 
     def has_default_fparam(self) -> bool:
         """Check if the model has default frame parameters."""
-        return self.fitting.has_default_fparam()
+        return self.fitting_net.has_default_fparam()
 
     def get_default_fparam(self) -> list[float] | None:
         """Get the default frame parameters."""
-        return self.fitting.get_default_fparam()
+        return self.fitting_net.get_default_fparam()
 
     def get_sel_type(self) -> list[int]:
         """Get the selected atom types of this model.
@@ -265,7 +325,7 @@ def get_sel_type(self) -> list[int]:
         to the result of the model.
         If returning an empty list, all atom types are selected.
         """
-        return self.fitting.get_sel_type()
+        return self.fitting_net.get_sel_type()
 
     def is_aparam_nall(self) -> bool:
         """Check whether the shape of atomic parameters is (nframes, nall, ndim).
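
The guard at the top of `compute_or_load_stat` appends a per-type-map subdirectory to the stat path with `/=`. The sketch below shows how that idiom behaves, using `pathlib.PurePosixPath` as a stand-in; whether `DPPath` behaves identically (a `/` operator that returns a new path object rather than mutating the original) is an assumption here, and `stat_dir_for` is a hypothetical helper name.

```python
from pathlib import PurePosixPath


def stat_dir_for(stat_file_path, type_map):
    """Mimic the guard at the top of `compute_or_load_stat` (illustrative only)."""
    if stat_file_path is not None and type_map is not None:
        # `/=` rebinds the local name to a new path object; with pathlib the
        # caller's object is left untouched.
        stat_file_path /= " ".join(type_map)
    return stat_file_path


root = PurePosixPath("stat_files")
sub = stat_dir_for(root, ["O", "H"])   # type-map-specific subdirectory
none_case = stat_dir_for(None, ["O"])  # no stat file requested
```

Under that assumption, the statistics for a model with type map `["O", "H"]` end up under a directory named `O H`, so stat files computed for different type maps do not collide.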

deepmd/dpmodel/atomic_model/linear_atomic_model.py

Lines changed: 39 additions & 1 deletion
@@ -1,4 +1,7 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+from collections.abc import (
+    Callable,
+)
 from typing import (
     Any,
 )
@@ -17,6 +20,9 @@
 from deepmd.env import (
     GLOBAL_NP_FLOAT_PRECISION,
 )
+from deepmd.utils.path import (
+    DPPath,
+)
 from deepmd.utils.version import (
     check_version_compatibility,
 )
@@ -338,6 +344,38 @@ def deserialize(cls, data: dict) -> "LinearEnergyAtomicModel":
         data["models"] = models
         return super().deserialize(data)
 
+    def compute_or_load_stat(
+        self,
+        sampled_func: Callable[[], list[dict]],
+        stat_file_path: DPPath | None = None,
+        compute_or_load_out_stat: bool = True,
+    ) -> None:
+        """Compute or load the statistics parameters of the model.
+
+        For LinearEnergyAtomicModel, this first computes input stats for each
+        sub-model (without output stats), then computes its own output stats.
+
+        Parameters
+        ----------
+        sampled_func
+            The lazy sampled function to get data frames from different data systems.
+        stat_file_path
+            The path to the stat file.
+        compute_or_load_out_stat : bool
+            Whether to compute the output statistics.
+        """
+        for md in self.models:
+            md.compute_or_load_stat(
+                sampled_func, stat_file_path, compute_or_load_out_stat=False
+            )
+
+        if stat_file_path is not None and self.type_map is not None:
+            stat_file_path /= " ".join(self.type_map)
+
+        if compute_or_load_out_stat:
+            wrapped_sampler = self._make_wrapped_sampler(sampled_func)
+            self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
+
     def _compute_weight(
         self,
         extended_coord: Array,
@@ -523,4 +561,4 @@ def _compute_weight(
         # to handle masked atoms
         coef = xp.where(sigma != 0, coef, xp.zeros_like(coef))
         self.zbl_weight = coef
-        return [1 - xp.expand_dims(coef, -1), xp.expand_dims(coef, -1)]
+        return [1 - xp.expand_dims(coef, axis=-1), xp.expand_dims(coef, axis=-1)]
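
The two-stage flow in `LinearEnergyAtomicModel.compute_or_load_stat` (sub-models compute input statistics only, then the combined model computes its own output statistics) can be sketched with toy classes. `ToySubModel`, `ToyLinearModel` and their `log` attributes are hypothetical names used only for illustration; no real statistics are computed.

```python
class ToySubModel:
    """Stand-in for a sub-model; records which statistics it was asked for."""

    def __init__(self) -> None:
        self.log: list[str] = []

    def compute_or_load_stat(
        self, sampled_func, stat_file_path=None, compute_or_load_out_stat=True
    ) -> None:
        self.log.append("input")
        if compute_or_load_out_stat:
            self.log.append("output")


class ToyLinearModel:
    """Stand-in for the linear model that owns several sub-models."""

    def __init__(self, models: list[ToySubModel]) -> None:
        self.models = models
        self.log: list[str] = []

    def compute_or_load_stat(
        self, sampled_func, stat_file_path=None, compute_or_load_out_stat=True
    ) -> None:
        # Stage 1: every sub-model computes input stats only.
        for md in self.models:
            md.compute_or_load_stat(
                sampled_func, stat_file_path, compute_or_load_out_stat=False
            )
        # Stage 2: output stats are computed once, on the combined model.
        if compute_or_load_out_stat:
            self.log.append("output")


subs = [ToySubModel(), ToySubModel()]
linear = ToyLinearModel(subs)
linear.compute_or_load_stat(lambda: [{"atype": [0]}])
```

In this sketch the output statistics live only on the combined model, which mirrors the docstring's statement that sub-models are processed "without output stats".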
