Skip to content
Merged
10 changes: 5 additions & 5 deletions deepmd/dpmodel/fitting/dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def output_def(self) -> FittingOutputDef:
r_differentiable=self.r_differentiable,
c_differentiable=self.c_differentiable,
),
*self._middle_output_def(),
]
)

Expand Down Expand Up @@ -239,15 +240,14 @@ def call(
nframes, nloc, _ = descriptor.shape
assert gr is not None, "Must provide the rotation matrix for dipole fitting."
# (nframes, nloc, m1)
out = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
self.var_name
]
results = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)
out = results[self.var_name]
# (nframes * nloc, 1, m1)
out = xp.reshape(out, (-1, 1, self.embedding_width))
# (nframes * nloc, m1, 3)
gr = xp.reshape(gr, (nframes * nloc, -1, 3))
# (nframes, nloc, 3)
# out = np.einsum("bim,bmj->bij", out, gr).squeeze(-2).reshape(nframes, nloc, 3)
out = out @ gr
out = xp.reshape(out, (nframes, nloc, 3))
return {self.var_name: out}
results[self.var_name] = xp.reshape(out, (nframes, nloc, 3))
return results
1 change: 1 addition & 0 deletions deepmd/dpmodel/fitting/dos_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def output_def(self) -> FittingOutputDef:
r_differentiable=False,
c_differentiable=False,
),
*self._middle_output_def(),
]
)

Expand Down
46 changes: 46 additions & 0 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
get_xp_precision,
to_numpy_array,
)
from deepmd.dpmodel.output_def import (
OutputVariableDef,
)
from deepmd.dpmodel.utils import (
AtomExcludeMask,
FittingNet,
Expand Down Expand Up @@ -168,6 +171,7 @@ def __init__(
if self.spin is not None:
raise NotImplementedError("spin is not supported")
self.remove_vaccum_contribution = remove_vaccum_contribution
self.eval_return_middle_output = False

net_dim_out = self._net_out_dim()
# init constants
Expand Down Expand Up @@ -424,6 +428,30 @@ def get_default_fparam(self) -> list[float] | None:
"""Get the default frame parameters."""
return self.default_fparam

def set_return_middle_output(self, enable: bool) -> None:
"""Enable or disable returning the middle (pre-last-layer) output.

When enabled, the fitting network's ``call`` method will include
a ``"middle_output"`` key in the returned dict, containing the
hidden-layer activations before the final linear layer. Shape:
``[nframes, nloc, neuron[-1]]``.
"""
self.eval_return_middle_output = enable

def _middle_output_def(self) -> list[OutputVariableDef]:
"""Return extra OutputVariableDefs for middle_output when enabled."""
if self.eval_return_middle_output:
return [
OutputVariableDef(
"middle_output",
[self.neuron[-1]],
Comment thread
wanghan-iapcm marked this conversation as resolved.
Comment thread
coderabbitai[bot] marked this conversation as resolved.
reducible=False,
r_differentiable=False,
c_differentiable=False,
),
]
return []

def get_sel_type(self) -> list[int]:
"""Get the selected atom types of this model.

Expand Down Expand Up @@ -690,6 +718,12 @@ def _call_common(
dtype=get_xp_precision(xp, self.precision),
device=array_api_compat.device(descriptor),
)
if self.eval_return_middle_output:
middle_outs = xp.zeros(
[nf, nloc, self.neuron[-1]],
dtype=get_xp_precision(xp, self.precision),
device=array_api_compat.device(descriptor),
)
for type_i in range(self.ntypes):
mask = xp.tile(
xp.reshape((atype == type_i), (nf, nloc, 1)), (1, 1, net_dim_out)
Expand All @@ -705,10 +739,20 @@ def _call_common(
mask, atom_property, xp.zeros_like(atom_property)
)
outs = outs + atom_property # Shape is [nframes, natoms[0], 1]
if self.eval_return_middle_output:
mid = self.nets[(type_i,)].call_until_last(xx)
mid_mask = xp.tile(
xp.reshape((atype == type_i), (nf, nloc, 1)),
(1, 1, self.neuron[-1]),
)
mid = xp.where(mid_mask, mid, xp.zeros_like(mid))
middle_outs = middle_outs + mid
else:
outs = self.nets[()](xx)
if xx_zeros is not None:
outs -= self.nets[()](xx_zeros)
if self.eval_return_middle_output:
middle_outs = self.nets[()].call_until_last(xx)
outs += xp.reshape(
xp.take(
xp.astype(self.bias_atom_e[...], outs.dtype),
Expand All @@ -723,4 +767,6 @@ def _call_common(
# nf x nloc x nod
outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs))
results[self.var_name] = outs
if self.eval_return_middle_output:
results["middle_output"] = middle_outs
return results
1 change: 1 addition & 0 deletions deepmd/dpmodel/fitting/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def output_def(self) -> FittingOutputDef:
r_differentiable=True,
c_differentiable=True,
),
*self._middle_output_def(),
]
)

Expand Down
9 changes: 5 additions & 4 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def output_def(self) -> FittingOutputDef:
r_differentiable=False,
c_differentiable=False,
),
*self._middle_output_def(),
]
)

Expand Down Expand Up @@ -326,9 +327,8 @@ def call(
"Must provide the rotation matrix for polarizability fitting."
)
# (nframes, nloc, _net_out_dim)
out = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
self.var_name
]
results = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)
out = results.pop(self.var_name)
# out = out * self.scale[atype, ...]
scale_atype = xp.reshape(
xp.take(xp.astype(self.scale, out.dtype), xp.reshape(atype, (-1,)), axis=0),
Expand Down Expand Up @@ -371,4 +371,5 @@ def call(
# (nframes, nloc, 3, 3)
bias = bias[..., None] * eye
out = out + bias
return {"polarizability": out}
results["polarizability"] = out
return results
1 change: 1 addition & 0 deletions deepmd/dpmodel/fitting/property_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def output_def(self) -> FittingOutputDef:
c_differentiable=False,
intensive=self.intensive,
),
*self._middle_output_def(),
]
)

Expand Down
10 changes: 10 additions & 0 deletions deepmd/dpmodel/model/frozen.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
TYPE_CHECKING,
Any,
NoReturn,
)

if TYPE_CHECKING:
from deepmd.dpmodel.atomic_model.dp_atomic_model import (
DPAtomicModel,
)

from deepmd.dpmodel.common import (
NativeOP,
)
Expand Down Expand Up @@ -131,6 +137,10 @@ def get_observed_type_list(self) -> list[str]:
"""Get observed types (elements) of the model during data statistics."""
return self.model.get_observed_type_list()

def get_dp_atomic_model(self) -> "DPAtomicModel | None":
"""Get the underlying DPAtomicModel by delegating to the inner model."""
return self.model.get_dp_atomic_model()

def serialize(self) -> dict:
"""Serialize the model.

Expand Down
21 changes: 21 additions & 0 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@
Callable,
)
from typing import (
TYPE_CHECKING,
Any,
)

if TYPE_CHECKING:
from deepmd.dpmodel.atomic_model.dp_atomic_model import (
DPAtomicModel,
)

import array_api_compat
import numpy as np

Expand Down Expand Up @@ -704,6 +710,21 @@ def is_aparam_nall(self) -> bool:
"""
return self.atomic_model.is_aparam_nall()

def get_dp_atomic_model(self) -> "DPAtomicModel | None":
"""Get the underlying DPAtomicModel with descriptor and fitting_net.

Returns the ``atomic_model`` if it is a ``DPAtomicModel`` instance
(i.e. has both ``descriptor`` and ``fitting_net``). Returns ``None``
for composite atomic models such as ``LinearEnergyAtomicModel``.
"""
from deepmd.dpmodel.atomic_model.dp_atomic_model import (
DPAtomicModel,
)

if isinstance(self.atomic_model, DPAtomicModel):
return self.atomic_model
return None

def get_rcut(self) -> float:
"""Get the cut-off radius."""
return self.atomic_model.get_rcut()
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/model/spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,10 @@ def __getattr__(self, name: str) -> Any:
raise AttributeError(name)
return getattr(self.backbone_model, name)

def get_dp_atomic_model(self) -> "DPAtomicModel | None":
"""Get the underlying DPAtomicModel by delegating to the backbone model."""
return self.backbone_model.get_dp_atomic_model()

def serialize(self) -> dict:
return {
"type": "spin_ener",
Expand Down
5 changes: 3 additions & 2 deletions deepmd/dpmodel/output_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,9 @@ def __call__(
**kwargs: Any,
) -> Any:
ret = cls.__call__(self, *args, **kwargs)
for kk in self.md.keys():
dd = self.md[kk]
md = self.output_def()
for kk in md.keys():
dd = md[kk]
check_var(ret[kk], dd)
return ret

Expand Down
Loading
Loading