Skip to content

Commit 345d162

Browse files
wanghan-iapcm and Han Wang authored
feat(pt_expt): add eval_typeebd, eval_descriptor, eval_fitting_last_layer (#5391)
## Summary

- Add `eval_typeebd`, `eval_descriptor`, `eval_fitting_last_layer` to pt_expt's `DeepEval`, using the eager `_dpmodel` for diagnostic computation
- Add `get_dp_atomic_model()` API to the model hierarchy (make_model, FrozenModel, SpinModel) for clean access to the underlying `DPAtomicModel`
- Add `set_return_middle_output` to `GeneralFitting` with proper `FittingOutputDef` registration — `fitting_check_output` now evaluates `output_def()` dynamically
- Preserve `middle_output` through DipoleFitting/PolarFitting `call()` by modifying the results dict in-place
- Add cross-backend reference values for fparam_aparam model (descriptor + fit_ll)
- Remove .pte/.pt2 skips in `test_models.py` for `test_descriptor` and `test_fitting_last_layer`

## Limitations

- Spin models: `eval_descriptor` and `eval_fitting_last_layer` raise `NotImplementedError` (SpinModel preprocesses coords with virtual atoms)
- DPZBLModel: `get_dp_atomic_model()` returns `None` (LinearAtomicModel has no single descriptor/fitting_net)
- `model_check_output` still caches `output_def` at init (only `fitting_check_output` is dynamic) — not an issue since `middle_output` is consumed at fitting level
- `_middle_output_def()` added to all fitting subclasses but only tested for InvarFitting, DipoleFitting, PolarFitting (not PropertyFitting, DOSFitting)

## Test plan

- [x] 10 dpmodel middle_output tests (shape, toggle, output_def registration, decorator validation, dipole/polar passthrough)
- [x] 6 eval diagnostic tests (typeebd, descriptor, fitting_last_layer for se_e2_a and DPA1)
- [x] 4 get_dp_atomic_model tests (energy, ZBL, spin, frozen)
- [x] 3 spin model diagnostic tests (typeebd works, descriptor/fitting raises)
- [x] 2 ASE neighbor list consistency tests
- [x] Cross-backend consistency via fparam_aparam reference values in test_models.py

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * Inference APIs now expose type-embedding, descriptor, and fitting-last-layer outputs; models may return intermediate fitting-network outputs and higher-level fitters propagate them.
  * New public accessor to retrieve an underlying atomic model from wrapped models.
* **Refactor**
  * Consolidated input preparation and evaluation flow; multi-output handling and output-definition validation unified.
* **Tests**
  * Large expansion of unit and integration tests and test data for new outputs, toggles, determinism, and model variants.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent baab3e8 commit 345d162

15 files changed

Lines changed: 2195 additions & 18 deletions

deepmd/dpmodel/fitting/dipole_fitting.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ def output_def(self) -> FittingOutputDef:
198198
r_differentiable=self.r_differentiable,
199199
c_differentiable=self.c_differentiable,
200200
),
201+
*self._middle_output_def(),
201202
]
202203
)
203204

@@ -239,15 +240,14 @@ def call(
239240
nframes, nloc, _ = descriptor.shape
240241
assert gr is not None, "Must provide the rotation matrix for dipole fitting."
241242
# (nframes, nloc, m1)
242-
out = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
243-
self.var_name
244-
]
243+
results = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)
244+
out = results[self.var_name]
245245
# (nframes * nloc, 1, m1)
246246
out = xp.reshape(out, (-1, 1, self.embedding_width))
247247
# (nframes * nloc, m1, 3)
248248
gr = xp.reshape(gr, (nframes * nloc, -1, 3))
249249
# (nframes, nloc, 3)
250250
# out = np.einsum("bim,bmj->bij", out, gr).squeeze(-2).reshape(nframes, nloc, 3)
251251
out = out @ gr
252-
out = xp.reshape(out, (nframes, nloc, 3))
253-
return {self.var_name: out}
252+
results[self.var_name] = xp.reshape(out, (nframes, nloc, 3))
253+
return results

deepmd/dpmodel/fitting/dos_fitting.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def output_def(self) -> FittingOutputDef:
8989
r_differentiable=False,
9090
c_differentiable=False,
9191
),
92+
*self._middle_output_def(),
9293
]
9394
)
9495

deepmd/dpmodel/fitting/general_fitting.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
get_xp_precision,
2525
to_numpy_array,
2626
)
27+
from deepmd.dpmodel.output_def import (
28+
OutputVariableDef,
29+
)
2730
from deepmd.dpmodel.utils import (
2831
AtomExcludeMask,
2932
FittingNet,
@@ -168,6 +171,7 @@ def __init__(
168171
if self.spin is not None:
169172
raise NotImplementedError("spin is not supported")
170173
self.remove_vaccum_contribution = remove_vaccum_contribution
174+
self.eval_return_middle_output = False
171175

172176
net_dim_out = self._net_out_dim()
173177
# init constants
@@ -424,6 +428,39 @@ def get_default_fparam(self) -> list[float] | None:
424428
"""Get the default frame parameters."""
425429
return self.default_fparam
426430

431+
def set_return_middle_output(self, enable: bool) -> None:
432+
"""Enable or disable returning the middle (pre-last-layer) output.
433+
434+
When enabled, the fitting network's ``call`` method will include
435+
a ``"middle_output"`` key in the returned dict, containing the
436+
hidden-layer activations before the final linear layer. Shape:
437+
``[nframes, nloc, neuron[-1]]``.
438+
439+
Raises
440+
------
441+
ValueError
442+
If ``enable`` is True but ``neuron`` is empty (no hidden layers).
443+
"""
444+
if enable and len(self.neuron) == 0:
445+
raise ValueError(
446+
"middle_output requires at least one hidden layer (neuron=[])"
447+
)
448+
self.eval_return_middle_output = enable
449+
450+
def _middle_output_def(self) -> list[OutputVariableDef]:
451+
"""Return extra OutputVariableDefs for middle_output when enabled."""
452+
if self.eval_return_middle_output and len(self.neuron) > 0:
453+
return [
454+
OutputVariableDef(
455+
"middle_output",
456+
[self.neuron[-1]],
457+
reducible=False,
458+
r_differentiable=False,
459+
c_differentiable=False,
460+
),
461+
]
462+
return []
463+
427464
def get_sel_type(self) -> list[int]:
428465
"""Get the selected atom types of this model.
429466
@@ -690,6 +727,12 @@ def _call_common(
690727
dtype=get_xp_precision(xp, self.precision),
691728
device=array_api_compat.device(descriptor),
692729
)
730+
if self.eval_return_middle_output and len(self.neuron) > 0:
731+
middle_outs = xp.zeros(
732+
[nf, nloc, self.neuron[-1]],
733+
dtype=get_xp_precision(xp, self.precision),
734+
device=array_api_compat.device(descriptor),
735+
)
693736
for type_i in range(self.ntypes):
694737
mask = xp.tile(
695738
xp.reshape((atype == type_i), (nf, nloc, 1)), (1, 1, net_dim_out)
@@ -705,10 +748,20 @@ def _call_common(
705748
mask, atom_property, xp.zeros_like(atom_property)
706749
)
707750
outs = outs + atom_property # Shape is [nframes, natoms[0], 1]
751+
if self.eval_return_middle_output and len(self.neuron) > 0:
752+
mid = self.nets[(type_i,)].call_until_last(xx)
753+
mid_mask = xp.tile(
754+
xp.reshape((atype == type_i), (nf, nloc, 1)),
755+
(1, 1, self.neuron[-1]),
756+
)
757+
mid = xp.where(mid_mask, mid, xp.zeros_like(mid))
758+
middle_outs = middle_outs + mid
708759
else:
709760
outs = self.nets[()](xx)
710761
if xx_zeros is not None:
711762
outs -= self.nets[()](xx_zeros)
763+
if self.eval_return_middle_output and len(self.neuron) > 0:
764+
middle_outs = self.nets[()].call_until_last(xx)
712765
outs += xp.reshape(
713766
xp.take(
714767
xp.astype(self.bias_atom_e[...], outs.dtype),
@@ -723,4 +776,6 @@ def _call_common(
723776
# nf x nloc x nod
724777
outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs))
725778
results[self.var_name] = outs
779+
if self.eval_return_middle_output and len(self.neuron) > 0:
780+
results["middle_output"] = middle_outs
726781
return results

deepmd/dpmodel/fitting/invar_fitting.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ def output_def(self) -> FittingOutputDef:
210210
r_differentiable=True,
211211
c_differentiable=True,
212212
),
213+
*self._middle_output_def(),
213214
]
214215
)
215216

deepmd/dpmodel/fitting/polarizability_fitting.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ def output_def(self) -> FittingOutputDef:
250250
r_differentiable=False,
251251
c_differentiable=False,
252252
),
253+
*self._middle_output_def(),
253254
]
254255
)
255256

@@ -326,9 +327,8 @@ def call(
326327
"Must provide the rotation matrix for polarizability fitting."
327328
)
328329
# (nframes, nloc, _net_out_dim)
329-
out = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
330-
self.var_name
331-
]
330+
results = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)
331+
out = results.pop(self.var_name)
332332
# out = out * self.scale[atype, ...]
333333
scale_atype = xp.reshape(
334334
xp.take(xp.astype(self.scale, out.dtype), xp.reshape(atype, (-1,)), axis=0),
@@ -371,4 +371,5 @@ def call(
371371
# (nframes, nloc, 3, 3)
372372
bias = bias[..., None] * eye
373373
out = out + bias
374-
return {"polarizability": out}
374+
results["polarizability"] = out
375+
return results

deepmd/dpmodel/fitting/property_fitting.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def output_def(self) -> FittingOutputDef:
129129
c_differentiable=False,
130130
intensive=self.intensive,
131131
),
132+
*self._middle_output_def(),
132133
]
133134
)
134135

deepmd/dpmodel/model/frozen.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
from typing import (
3+
TYPE_CHECKING,
34
Any,
45
NoReturn,
56
)
67

8+
if TYPE_CHECKING:
9+
from deepmd.dpmodel.atomic_model.dp_atomic_model import (
10+
DPAtomicModel,
11+
)
12+
713
from deepmd.dpmodel.common import (
814
NativeOP,
915
)
@@ -131,6 +137,10 @@ def get_observed_type_list(self) -> list[str]:
131137
"""Get observed types (elements) of the model during data statistics."""
132138
return self.model.get_observed_type_list()
133139

140+
def get_dp_atomic_model(self) -> "DPAtomicModel | None":
141+
"""Get the underlying DPAtomicModel by delegating to the inner model."""
142+
return self.model.get_dp_atomic_model()
143+
134144
def serialize(self) -> dict:
135145
"""Serialize the model.
136146

deepmd/dpmodel/model/make_model.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,15 @@
33
Callable,
44
)
55
from typing import (
6+
TYPE_CHECKING,
67
Any,
78
)
89

10+
if TYPE_CHECKING:
11+
from deepmd.dpmodel.atomic_model.dp_atomic_model import (
12+
DPAtomicModel,
13+
)
14+
915
import array_api_compat
1016
import numpy as np
1117

@@ -704,6 +710,21 @@ def is_aparam_nall(self) -> bool:
704710
"""
705711
return self.atomic_model.is_aparam_nall()
706712

713+
def get_dp_atomic_model(self) -> "DPAtomicModel | None":
714+
"""Get the underlying DPAtomicModel with descriptor and fitting_net.
715+
716+
Returns the ``atomic_model`` if it is a ``DPAtomicModel`` instance
717+
(i.e. has both ``descriptor`` and ``fitting_net``). Returns ``None``
718+
for composite atomic models such as ``LinearEnergyAtomicModel``.
719+
"""
720+
from deepmd.dpmodel.atomic_model.dp_atomic_model import (
721+
DPAtomicModel,
722+
)
723+
724+
if isinstance(self.atomic_model, DPAtomicModel):
725+
return self.atomic_model
726+
return None
727+
707728
def get_rcut(self) -> float:
708729
"""Get the cut-off radius."""
709730
return self.atomic_model.get_rcut()

deepmd/dpmodel/model/spin_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,10 @@ def __getattr__(self, name: str) -> Any:
547547
raise AttributeError(name)
548548
return getattr(self.backbone_model, name)
549549

550+
def get_dp_atomic_model(self) -> "DPAtomicModel | None":
551+
"""Get the underlying DPAtomicModel by delegating to the backbone model."""
552+
return self.backbone_model.get_dp_atomic_model()
553+
550554
def serialize(self) -> dict:
551555
return {
552556
"type": "spin_ener",

deepmd/dpmodel/output_def.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,9 @@ def __call__(
102102
**kwargs: Any,
103103
) -> Any:
104104
ret = cls.__call__(self, *args, **kwargs)
105-
for kk in self.md.keys():
106-
dd = self.md[kk]
105+
md = self.output_def()
106+
for kk in md.keys():
107+
dd = md[kk]
107108
check_var(ret[kk], dd)
108109
return ret
109110

0 commit comments

Comments
 (0)