From 129c9bc761b30962a616837ee729809bd78662fc Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 5 Mar 2026 00:17:17 +0800 Subject: [PATCH 1/5] first implement hessian like the pt backend --- deepmd/pt_expt/model/__init__.py | 4 + deepmd/pt_expt/model/ener_model.py | 18 +- deepmd/pt_expt/model/make_hessian_model.py | 232 ++++++++++++++++ .../consistent/model/test_ener_hessian.py | 252 ++++++++++++++++++ .../pt_expt/model/test_make_hessian_model.py | 240 +++++++++++++++++ 5 files changed, 745 insertions(+), 1 deletion(-) create mode 100644 deepmd/pt_expt/model/make_hessian_model.py create mode 100644 source/tests/consistent/model/test_ener_hessian.py create mode 100644 source/tests/pt_expt/model/test_make_hessian_model.py diff --git a/deepmd/pt_expt/model/__init__.py b/deepmd/pt_expt/model/__init__.py index da120091e0..3067b257b5 100644 --- a/deepmd/pt_expt/model/__init__.py +++ b/deepmd/pt_expt/model/__init__.py @@ -11,6 +11,9 @@ from .ener_model import ( EnergyModel, ) +from .make_hessian_model import ( + make_hessian_model, +) from .model import ( BaseModel, ) @@ -29,4 +32,5 @@ "EnergyModel", "PolarModel", "PropertyModel", + "make_hessian_model", ] diff --git a/deepmd/pt_expt/model/ener_model.py b/deepmd/pt_expt/model/ener_model.py index 271028d2ff..bc49ea5a01 100644 --- a/deepmd/pt_expt/model/ener_model.py +++ b/deepmd/pt_expt/model/ener_model.py @@ -15,6 +15,9 @@ DPModelCommon, ) +from .make_hessian_model import ( + make_hessian_model, +) from .make_model import ( make_model, ) @@ -34,6 +37,13 @@ def __init__( ) -> None: DPModelCommon.__init__(self) DPEnergyModel_.__init__(self, *args, **kwargs) + self._hessian_enabled = False + + def enable_hessian(self) -> None: + self.__class__ = make_hessian_model(type(self)) + self.hess_fitting_def = super(type(self), self).atomic_output_def() + self.requires_hessian("energy") + self._hessian_enabled = True def forward( self, @@ -44,7 +54,9 @@ def forward( aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, ) -> dict[str, torch.Tensor]: - model_ret = self.call_common( + # Use forward_common (not call_common) so that hessian toggle + # logic in make_hessian_model.CM.forward_common is applied. + model_ret = self.forward_common( coord, atype, box, @@ -63,6 +75,8 @@ def forward( model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-2) if "mask" in model_ret: model_predict["mask"] = model_ret["mask"] + if self._hessian_enabled: + model_predict["hessian"] = model_ret["energy_derv_r_derv_r"].squeeze(-3) return model_predict def forward_lower( @@ -115,6 +129,8 @@ def translated_output_def(self) -> dict[str, Any]: output_def["atom_virial"].squeeze(-2) if "mask" in out_def_data: output_def["mask"] = out_def_data["mask"] + if self._hessian_enabled: + output_def["hessian"] = out_def_data["energy_derv_r_derv_r"] return output_def def forward_lower_exportable( diff --git a/deepmd/pt_expt/model/make_hessian_model.py b/deepmd/pt_expt/model/make_hessian_model.py new file mode 100644 index 0000000000..3299490934 --- /dev/null +++ b/deepmd/pt_expt/model/make_hessian_model.py @@ -0,0 +1,232 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import math +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel import ( + get_hessian_name, +) +from deepmd.dpmodel.output_def import ( + FittingOutputDef, +) + + +def make_hessian_model(T_Model: type) -> type: + """Make a model that can compute Hessian. + + LIMITATION: only the hessian of ``forward_common`` is available. + + Parameters + ---------- + T_Model + The model. Should provide the ``forward_common`` and + ``atomic_output_def`` methods. + + Returns + ------- + The model computes hessian. + + """ + + class CM(T_Model): + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__( + *args, + **kwargs, + ) + self.hess_fitting_def = copy.deepcopy(super().atomic_output_def()) + + def requires_hessian( + self, + keys: str | list[str], + ) -> None: + """Set which output variable(s) requires hessian.""" + if isinstance(keys, str): + keys = [keys] + for kk in self.hess_fitting_def.keys(): + if kk in keys: + self.hess_fitting_def[kk].r_hessian = True + + def atomic_output_def(self) -> FittingOutputDef: + """Get the fitting output def.""" + return self.hess_fitting_def + + def forward_common( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + """Return model prediction. + + Parameters + ---------- + coord + The coordinates of the atoms. + shape: nf x (nloc x 3) + atype + The type of atoms. shape: nf x nloc + box + The simulation box. shape: nf x 9 + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda + do_atomic_virial + If calculate the atomic virial. + + Returns + ------- + ret_dict + The result dict of type dict[str,torch.Tensor]. + The keys are defined by the ``ModelOutputDef``. + + """ + vdef = self.atomic_output_def() + hess_keys = [kk for kk in vdef.keys() if vdef[kk].r_hessian] + # Temporarily disable r_hessian so that the base forward_common + # (which goes through communicate_extended_output) does not expect + # hessian keys in the lower-level output. Must remain disabled + # during _cal_hessian_all as well, since the wrapper also calls + # super().forward_common internally. + for kk in hess_keys: + vdef[kk].r_hessian = False + try: + ret = super().forward_common( + coord, + atype, + box=box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + if hess_keys: + hess = self._cal_hessian_all( + hess_keys, + coord, + atype, + box=box, + fparam=fparam, + aparam=aparam, + ) + ret.update(hess) + finally: + for kk in hess_keys: + vdef[kk].r_hessian = True + return ret + + def _cal_hessian_all( + self, + hess_keys: list[str], + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + nf, nloc = atype.shape + coord = coord.view([nf, (nloc * 3)]) + box = box.view([nf, 9]) if box is not None else None + fparam = fparam.view([nf, -1]) if fparam is not None else None + aparam = aparam.view([nf, nloc, -1]) if aparam is not None else None + fdef = self.atomic_output_def() + # result dict init by empty lists + res = {get_hessian_name(kk): [] for kk in hess_keys} + # loop over variable + for kk in hess_keys: + vdef = fdef[kk] + vshape = vdef.shape + vsize = math.prod(vdef.shape) + # loop over frames + for ii in range(nf): + icoord = coord[ii] + iatype = atype[ii] + ibox = box[ii] if box is not None else None + ifparam = fparam[ii] if fparam is not None else None + iaparam = aparam[ii] if aparam is not None else None + # loop over all components + for idx in range(vsize): + hess = self._cal_hessian_one_component( + idx, icoord, iatype, ibox, ifparam, iaparam + ) + res[get_hessian_name(kk)].append(hess) + res[get_hessian_name(kk)] = torch.stack(res[get_hessian_name(kk)]).view( + (nf, *vshape, nloc * 3, nloc * 3) + ) + return res + + def _cal_hessian_one_component( + self, + ci: int, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + ) -> torch.Tensor: + # coord, # (nloc x 3) + # atype, # nloc + # box: Optional[torch.Tensor] = None, # 9 + # fparam: Optional[torch.Tensor] = None, # nfp + # aparam: Optional[torch.Tensor] = None, # (nloc x nap) + wc = wrapper_class_forward_energy(self, ci, atype, box, fparam, aparam) + hess = torch.autograd.functional.hessian( + wc, + coord, + create_graph=self.training, + ) + return hess + + class wrapper_class_forward_energy: + def __init__( + self, + obj: CM, + ci: int, + atype: torch.Tensor, + box: torch.Tensor | None, + fparam: torch.Tensor | None, + aparam: torch.Tensor | None, + ) -> None: + self.atype, self.box, self.fparam, self.aparam = ( + atype, + box, + fparam, + aparam, + ) + self.ci = ci + self.obj = obj + + def __call__( + self, + xx: torch.Tensor, + ) -> torch.Tensor: + ci = self.ci + atype, box, fparam, aparam = ( + self.atype, + self.box, + self.fparam, + self.aparam, + ) + res = super(CM, self.obj).forward_common( + xx.unsqueeze(0), + atype.unsqueeze(0), + box.unsqueeze(0) if box is not None else None, + fparam.unsqueeze(0) if fparam is not None else None, + aparam.unsqueeze(0) if aparam is not None else None, + do_atomic_virial=False, + ) + er = res["energy_redu"][0].view([-1])[ci] + return er + + return CM diff --git a/source/tests/consistent/model/test_ener_hessian.py b/source/tests/consistent/model/test_ener_hessian.py new file mode 100644 index 0000000000..1181cac114 --- /dev/null +++ b/source/tests/consistent/model/test_ener_hessian.py @@ -0,0 +1,252 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, +) + +import numpy as np + +from deepmd.dpmodel.model.ener_model import EnergyModel as EnergyModelDP +from deepmd.dpmodel.model.model import get_model as get_model_dp +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) + +from ..common import ( + INSTALLED_PT, + INSTALLED_PT_EXPT, + CommonTest, +) +from .common import ( + ModelTest, +) + +if INSTALLED_PT: + from deepmd.pt.model.model import get_model as get_model_pt + from deepmd.pt.model.model.ener_model import EnergyModel as EnergyModelPT +else: + EnergyModelPT = None +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.model import EnergyModel as EnergyModelPTExpt +else: + EnergyModelPTExpt = None +from deepmd.utils.argcheck import ( + model_args, +) + + +class TestEnerHessian(CommonTest, ModelTest, unittest.TestCase): + @property + def data(self) -> dict: + return { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "sel": [20, 20], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [ + 3, + 6, + ], + "resnet_dt": False, + "axis_neuron": 2, + "precision": "float64", + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [ + 5, + 5, + ], + "resnet_dt": True, + "precision": "float64", + "seed": 1, + }, + } + + tf_class = None + dp_class = EnergyModelDP + pt_class = EnergyModelPT + pt_expt_class = EnergyModelPTExpt + jax_class = None + pd_class = None + args = model_args() + + @property + def skip_tf(self) -> bool: + return True + + @property + def skip_dp(self) -> bool: + return True + + @property + def skip_jax(self) -> bool: + return True + + @property + def skip_pd(self) -> bool: + return True + + @property + def skip_array_api_strict(self) -> bool: + return True + + def get_reference_backend(self): + """Get the reference backend. + + We need a reference backend that can compute hessian. + """ + if not self.skip_pt: + return self.RefBackend.PT + if not self.skip_pt_expt and self.pt_expt_class is not None: + return self.RefBackend.PT_EXPT + raise ValueError("No available reference") + + def pass_data_to_cls(self, cls, data) -> Any: + """Pass data to the class and enable hessian.""" + data = data.copy() + if cls is EnergyModelDP: + model = get_model_dp(data) + elif cls is EnergyModelPT: + model = get_model_pt(data) + model.atomic_model.out_bias.uniform_() + elif cls is EnergyModelPTExpt: + dp_model = get_model_dp(data) + model = EnergyModelPTExpt.deserialize(dp_model.serialize()) + else: + model = cls(**data) + model.enable_hessian() + return model + + def setUp(self) -> None: + CommonTest.setUp(self) + + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, -1, 3) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32).reshape(1, -1) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, 9) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + + # TF requires the atype to be sort + idx_map = np.argsort(self.atype.ravel()) + self.atype = self.atype[:, idx_map] + self.coords = self.coords[:, idx_map] + + def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: + raise NotImplementedError("no TF in this test") + + def eval_dp(self, dp_obj: Any) -> Any: + return self.eval_dp_model( + dp_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + return self.eval_pt_model( + pt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + return self.eval_pt_expt_model( + pt_expt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: + if backend in { + self.RefBackend.PT, + self.RefBackend.PT_EXPT, + }: + return ( + ret["energy"].ravel(), + ret["atom_energy"].ravel(), + ret["force"].ravel(), + ret["virial"].ravel(), + ret["hessian"].ravel(), + ) + raise ValueError(f"Unknown backend: {backend}") + + def _enable_hessian_on(self, obj): + """Enable hessian on a model object.""" + obj.enable_hessian() + return obj + + def test_pt_consistent_with_ref(self) -> None: + """Test PT consistent with reference, re-enabling hessian after deserialize.""" + if self.skip_pt: + self.skipTest("Unsupported backend") + ref_backend = self.get_reference_backend() + if ref_backend == self.RefBackend.PT: + self.skipTest("Reference is self") + ret1, data1 = self.get_reference_ret_serialization(ref_backend) + ret1 = self.extract_ret(ret1, ref_backend) + obj = self._enable_hessian_on(self.pt_class.deserialize(data1)) + ret2 = self.eval_pt(obj) + ret2 = self.extract_ret(ret2, self.RefBackend.PT) + for rr1, rr2 in zip(ret1, ret2, strict=True): + np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) + + def test_pt_expt_consistent_with_ref(self) -> None: + """Test pt_expt consistent with reference, re-enabling hessian after deserialize.""" + if self.skip_pt_expt or self.pt_expt_class is None: + self.skipTest("Unsupported backend") + ref_backend = self.get_reference_backend() + if ref_backend == self.RefBackend.PT_EXPT: + self.skipTest("Reference is self") + ret1, data1 = self.get_reference_ret_serialization(ref_backend) + ret1 = self.extract_ret(ret1, ref_backend) + obj = self._enable_hessian_on(self.pt_expt_class.deserialize(data1)) + ret2 = self.eval_pt_expt(obj) + ret2 = self.extract_ret(ret2, self.RefBackend.PT_EXPT) + for rr1, rr2 in zip(ret1, ret2, strict=True): + np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) + + def test_pt_self_consistent(self) -> None: + """Skip: hessian is a runtime flag, not preserved by serialize/deserialize.""" + self.skipTest("Hessian state is not serialized") + + def test_pt_expt_self_consistent(self) -> None: + """Skip: hessian is a runtime flag, not preserved by serialize/deserialize.""" + self.skipTest("Hessian state is not serialized") + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt_expt/model/test_make_hessian_model.py b/source/tests/pt_expt/model/test_make_hessian_model.py new file mode 100644 index 0000000000..7737c95786 --- /dev/null +++ b/source/tests/pt_expt/model/test_make_hessian_model.py @@ -0,0 +1,240 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor import DescrptSeA as DPDescrptSeA +from deepmd.dpmodel.fitting import InvarFitting as DPInvarFitting +from deepmd.dpmodel.model.ener_model import EnergyModel as DPEnergyModel +from deepmd.dpmodel.output_def import ( + OutputVariableCategory, +) +from deepmd.pt_expt.common import ( + to_torch_array, +) +from deepmd.pt_expt.model import ( + EnergyModel, + make_hessian_model, +) +from deepmd.pt_expt.utils import ( + env, +) + +from ...seed import ( + GLOBAL_SEED, +) + +dtype = torch.float64 + + +def to_numpy_array(xx): + if isinstance(xx, torch.Tensor): + return xx.detach().cpu().numpy() + return np.asarray(xx) + + +def finite_hessian(f, x, delta=1e-6): + in_shape = x.shape + assert len(in_shape) == 1 + y0 = f(x) + out_shape = y0.shape + res = np.empty(out_shape + in_shape + in_shape) + for iidx in np.ndindex(*in_shape): + for jidx in np.ndindex(*in_shape): + i0 = np.zeros(in_shape) + i1 = np.zeros(in_shape) + i2 = np.zeros(in_shape) + i3 = np.zeros(in_shape) + i0[iidx] += delta + i2[iidx] += delta + i1[iidx] -= delta + i3[iidx] -= delta + i0[jidx] += delta + i1[jidx] += delta + i2[jidx] -= delta + i3[jidx] -= delta + y0 = f(x + i0) + y1 = f(x + i1) + y2 = f(x + i2) + y3 = f(x + i3) + res[(Ellipsis, *iidx, *jidx)] = (y0 + y3 - y1 - y2) / (4 * delta**2.0) + return res + + +class HessianTest: + def test( + self, + ) -> None: + # setup test case + places = 6 + delta = 1e-3 + natoms = self.nloc + nf = self.nf + nv = self.nv + generator = torch.Generator(device=env.DEVICE).manual_seed(GLOBAL_SEED) + cell0 = torch.rand([3, 3], dtype=dtype, device=env.DEVICE, generator=generator) + cell0 = 1.0 * (cell0 + cell0.T) + 5.0 * torch.eye(3, device=env.DEVICE) + cell1 = torch.rand([3, 3], dtype=dtype, device=env.DEVICE, generator=generator) + cell1 = 1.0 * (cell1 + cell1.T) + 5.0 * torch.eye(3, device=env.DEVICE) + cell = torch.stack([cell0, cell1]) + coord = torch.rand( + [nf, natoms, 3], dtype=dtype, device=env.DEVICE, generator=generator + ) + coord = torch.matmul(coord, cell) + cell = cell.view([nf, 9]) + coord = coord.view([nf, natoms * 3]) + atype = ( + torch.stack( + [ + torch.IntTensor([0, 0, 1]), + torch.IntTensor([1, 0, 1]), + ] + ) + .view([nf, natoms]) + .to(env.DEVICE) + ) + nfp, nap = 2, 3 + fparam = torch.rand( + [nf, nfp], dtype=dtype, device=env.DEVICE, generator=generator + ) + aparam = torch.rand( + [nf, natoms, nap], dtype=dtype, device=env.DEVICE, generator=generator + ) + # forward hess and value models + # pt_expt requires coord to have requires_grad=True for autograd-based + # force/virial computation in forward_common + coord = coord.requires_grad_(True) + ret_dict0 = self.model_hess.forward_common( + coord, atype, box=cell, fparam=fparam, aparam=aparam + ) + ret_dict1 = self.model_valu.forward_common( + coord, atype, box=cell, fparam=fparam, aparam=aparam + ) + # compare hess and value models + torch.testing.assert_close(ret_dict0["energy"], ret_dict1["energy"]) + ana_hess = ret_dict0["energy_derv_r_derv_r"] + + # compute finite difference + fnt_hess = [] + for ii in range(nf): + + def np_infer( + xx, + ): + xx_t = to_torch_array(xx).unsqueeze(0).requires_grad_(True) + ret = self.model_valu.forward_common( + xx_t, + atype[ii].unsqueeze(0), + box=cell[ii].unsqueeze(0), + fparam=fparam[ii].unsqueeze(0), + aparam=aparam[ii].unsqueeze(0), + ) + # detach + ret = {kk: to_numpy_array(ret[kk]) for kk in ret} + return ret + + def ff(xx): + return np_infer(xx)["energy_redu"] + + xx = to_numpy_array(coord[ii]) + fnt_hess.append(finite_hessian(ff, xx, delta=delta).squeeze()) + + # compare finite difference with autodiff + fnt_hess = np.stack(fnt_hess).reshape([nf, nv, natoms * 3, natoms * 3]) + np.testing.assert_almost_equal( + fnt_hess, to_numpy_array(ana_hess), decimal=places + ) + + +class TestDPModel(unittest.TestCase, HessianTest): + def setUp(self) -> None: + torch.manual_seed(2) + self.nf = 2 + self.nloc = 3 + self.rcut = 4.0 + self.rcut_smth = 3.0 + self.sel = [10, 10] + self.nt = 2 + self.nv = 2 + type_map = ["foo", "bar"] + # Build dpmodel first, then deserialize into pt_expt + ds_dp = DPDescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + neuron=[2, 4, 8], + axis_neuron=2, + ) + ft_dp = DPInvarFitting( + "energy", + self.nt, + ds_dp.get_dim_out(), + self.nv, + mixed_types=ds_dp.mixed_types(), + numb_fparam=2, + numb_aparam=3, + neuron=[4, 4, 4], + ) + md_dp = DPEnergyModel(ds_dp, ft_dp, type_map=type_map) + serialized = md_dp.serialize() + # Create hessian model via make_hessian_model + HessEnergyModel = make_hessian_model(EnergyModel) + self.model_hess = HessEnergyModel.deserialize(serialized).to(env.DEVICE) + self.model_hess.requires_hessian("energy") + # Create value model (no hessian) + self.model_valu = EnergyModel.deserialize(serialized).to(env.DEVICE) + + def test_output_def(self) -> None: + self.assertTrue(self.model_hess.atomic_output_def()["energy"].r_hessian) + self.assertFalse(self.model_valu.atomic_output_def()["energy"].r_hessian) + self.assertTrue(self.model_hess.model_output_def()["energy"].r_hessian) + self.assertEqual( + self.model_hess.model_output_def()["energy_derv_r_derv_r"].category, + OutputVariableCategory.DERV_R_DERV_R, + ) + + +class TestEnableHessian(unittest.TestCase, HessianTest): + """Test hessian via enable_hessian() method.""" + + def setUp(self) -> None: + torch.manual_seed(2) + self.nf = 2 + self.nloc = 3 + self.rcut = 4.0 + self.rcut_smth = 3.0 + self.sel = [10, 10] + self.nt = 2 + self.nv = 1 + type_map = ["foo", "bar"] + ds_dp = DPDescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + neuron=[2, 4, 8], + axis_neuron=2, + ) + ft_dp = DPInvarFitting( + "energy", + self.nt, + ds_dp.get_dim_out(), + self.nv, + mixed_types=ds_dp.mixed_types(), + numb_fparam=2, + numb_aparam=3, + neuron=[4, 4, 4], + ) + md_dp = DPEnergyModel(ds_dp, ft_dp, type_map=type_map) + serialized = md_dp.serialize() + self.model_hess = EnergyModel.deserialize(serialized).to(env.DEVICE) + self.model_hess.enable_hessian() + self.model_valu = EnergyModel.deserialize(serialized).to(env.DEVICE) + + def test_output_def(self) -> None: + self.assertTrue(self.model_hess.atomic_output_def()["energy"].r_hessian) + self.assertFalse(self.model_valu.atomic_output_def()["energy"].r_hessian) + + +if __name__ == "__main__": + unittest.main() From 61731868725cd6427659c429ee94f72902329e6a Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 5 Mar 2026 14:29:26 +0800 Subject: [PATCH 2/5] implement following jax --- deepmd/pt_expt/model/ener_model.py | 8 +- deepmd/pt_expt/model/make_hessian_model.py | 187 +---------------- deepmd/pt_expt/model/make_model.py | 160 +++++++++++++- .../consistent/model/test_ener_hessian.py | 51 ++++- ...an_model.py => test_ener_hessian_model.py} | 198 ++++++++---------- 5 files changed, 305 insertions(+), 299 deletions(-) rename source/tests/pt_expt/model/{test_make_hessian_model.py => test_ener_hessian_model.py} (61%) diff --git a/deepmd/pt_expt/model/ener_model.py b/deepmd/pt_expt/model/ener_model.py index bc49ea5a01..07cdfc5a54 100644 --- a/deepmd/pt_expt/model/ener_model.py +++ b/deepmd/pt_expt/model/ener_model.py @@ -54,9 +54,7 @@ def forward( aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, ) -> dict[str, torch.Tensor]: - # Use forward_common (not call_common) so that hessian toggle - # logic in make_hessian_model.CM.forward_common is applied. - model_ret = self.forward_common( + model_ret = self.call_common( coord, atype, box, @@ -75,7 +73,7 @@ def forward( model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-2) if "mask" in model_ret: model_predict["mask"] = model_ret["mask"] - if self._hessian_enabled: + if self.atomic_output_def()["energy"].r_hessian: model_predict["hessian"] = model_ret["energy_derv_r_derv_r"].squeeze(-3) return model_predict @@ -129,7 +127,7 @@ def translated_output_def(self) -> dict[str, Any]: output_def["atom_virial"].squeeze(-2) if "mask" in out_def_data: output_def["mask"] = out_def_data["mask"] - if self._hessian_enabled: + if self.atomic_output_def()["energy"].r_hessian: output_def["hessian"] = out_def_data["energy_derv_r_derv_r"] return output_def diff --git a/deepmd/pt_expt/model/make_hessian_model.py b/deepmd/pt_expt/model/make_hessian_model.py index 3299490934..10e46ebfc8 100644 --- a/deepmd/pt_expt/model/make_hessian_model.py +++ b/deepmd/pt_expt/model/make_hessian_model.py @@ -1,15 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import copy -import math from typing import ( Any, ) -import torch - -from deepmd.dpmodel import ( - get_hessian_name, -) from deepmd.dpmodel.output_def import ( FittingOutputDef, ) @@ -18,17 +12,20 @@ def make_hessian_model(T_Model: type) -> type: """Make a model that can compute Hessian. - LIMITATION: only the hessian of ``forward_common`` is available. + With the JAX-mirrored approach, hessian is computed in + ``forward_common_atomic`` (in make_model.py) on extended coordinates. + This wrapper only needs to override ``atomic_output_def()`` to set + ``r_hessian=True``, and ``communicate_extended_output`` in dpmodel + naturally maps it from nall to nloc. Parameters ---------- T_Model - The model. Should provide the ``forward_common`` and - ``atomic_output_def`` methods. + The model. Should provide the ``atomic_output_def`` method. Returns ------- - The model computes hessian. + The model that computes hessian. """ @@ -59,174 +56,4 @@ def atomic_output_def(self) -> FittingOutputDef: """Get the fitting output def.""" return self.hess_fitting_def - def forward_common( - self, - coord: torch.Tensor, - atype: torch.Tensor, - box: torch.Tensor | None = None, - fparam: torch.Tensor | None = None, - aparam: torch.Tensor | None = None, - do_atomic_virial: bool = False, - ) -> dict[str, torch.Tensor]: - """Return model prediction. - - Parameters - ---------- - coord - The coordinates of the atoms. - shape: nf x (nloc x 3) - atype - The type of atoms. shape: nf x nloc - box - The simulation box. shape: nf x 9 - fparam - frame parameter. nf x ndf - aparam - atomic parameter. nf x nloc x nda - do_atomic_virial - If calculate the atomic virial. - - Returns - ------- - ret_dict - The result dict of type dict[str,torch.Tensor]. - The keys are defined by the ``ModelOutputDef``. - - """ - vdef = self.atomic_output_def() - hess_keys = [kk for kk in vdef.keys() if vdef[kk].r_hessian] - # Temporarily disable r_hessian so that the base forward_common - # (which goes through communicate_extended_output) does not expect - # hessian keys in the lower-level output. Must remain disabled - # during _cal_hessian_all as well, since the wrapper also calls - # super().forward_common internally. - for kk in hess_keys: - vdef[kk].r_hessian = False - try: - ret = super().forward_common( - coord, - atype, - box=box, - fparam=fparam, - aparam=aparam, - do_atomic_virial=do_atomic_virial, - ) - if hess_keys: - hess = self._cal_hessian_all( - hess_keys, - coord, - atype, - box=box, - fparam=fparam, - aparam=aparam, - ) - ret.update(hess) - finally: - for kk in hess_keys: - vdef[kk].r_hessian = True - return ret - - def _cal_hessian_all( - self, - hess_keys: list[str], - coord: torch.Tensor, - atype: torch.Tensor, - box: torch.Tensor | None = None, - fparam: torch.Tensor | None = None, - aparam: torch.Tensor | None = None, - ) -> dict[str, torch.Tensor]: - nf, nloc = atype.shape - coord = coord.view([nf, (nloc * 3)]) - box = box.view([nf, 9]) if box is not None else None - fparam = fparam.view([nf, -1]) if fparam is not None else None - aparam = aparam.view([nf, nloc, -1]) if aparam is not None else None - fdef = self.atomic_output_def() - # result dict init by empty lists - res = {get_hessian_name(kk): [] for kk in hess_keys} - # loop over variable - for kk in hess_keys: - vdef = fdef[kk] - vshape = vdef.shape - vsize = math.prod(vdef.shape) - # loop over frames - for ii in range(nf): - icoord = coord[ii] - iatype = atype[ii] - ibox = box[ii] if box is not None else None - ifparam = fparam[ii] if fparam is not None else None - iaparam = aparam[ii] if aparam is not None else None - # loop over all components - for idx in range(vsize): - hess = self._cal_hessian_one_component( - idx, icoord, iatype, ibox, ifparam, iaparam - ) - res[get_hessian_name(kk)].append(hess) - res[get_hessian_name(kk)] = torch.stack(res[get_hessian_name(kk)]).view( - (nf, *vshape, nloc * 3, nloc * 3) - ) - return res - - def _cal_hessian_one_component( - self, - ci: int, - coord: torch.Tensor, - atype: torch.Tensor, - box: torch.Tensor | None = None, - fparam: torch.Tensor | None = None, - aparam: torch.Tensor | None = None, - ) -> torch.Tensor: - # coord, # (nloc x 3) - # atype, # nloc - # box: Optional[torch.Tensor] = None, # 9 - # fparam: Optional[torch.Tensor] = None, # nfp - # aparam: Optional[torch.Tensor] = None, # (nloc x nap) - wc = wrapper_class_forward_energy(self, ci, atype, box, fparam, aparam) - hess = torch.autograd.functional.hessian( - wc, - coord, - create_graph=self.training, - ) - return hess - - class wrapper_class_forward_energy: - def __init__( - self, - obj: CM, - ci: int, - atype: torch.Tensor, - box: torch.Tensor | None, - fparam: torch.Tensor | None, - aparam: torch.Tensor | None, - ) -> None: - self.atype, self.box, self.fparam, self.aparam = ( - atype, - box, - fparam, - aparam, - ) - self.ci = ci - self.obj = obj - - def __call__( - self, - xx: torch.Tensor, - ) -> torch.Tensor: - ci = self.ci - atype, box, fparam, aparam = ( - self.atype, - self.box, - self.fparam, - self.aparam, - ) - res = super(CM, self.obj).forward_common( - xx.unsqueeze(0), - atype.unsqueeze(0), - box.unsqueeze(0) if box is not None else None, - fparam.unsqueeze(0) if fparam is not None else None, - aparam.unsqueeze(0) if aparam is not None else None, - do_atomic_virial=False, - ) - er = res["energy_redu"][0].view([-1])[ci] - return er - return CM diff --git a/deepmd/pt_expt/model/make_model.py b/deepmd/pt_expt/model/make_model.py index 56cabafe81..8eda3472e6 100644 --- a/deepmd/pt_expt/model/make_model.py +++ b/deepmd/pt_expt/model/make_model.py @@ -1,14 +1,21 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import math from typing import ( Any, ) import torch +from deepmd.dpmodel import ( + get_hessian_name, +) from deepmd.dpmodel.atomic_model.base_atomic_model import ( BaseAtomicModel, ) from deepmd.dpmodel.model.make_model import make_model as make_model_dp +from deepmd.dpmodel.output_def import ( + OutputVariableDef, +) from deepmd.pt_expt.common import ( torch_module, ) @@ -18,6 +25,136 @@ ) +def _cal_hessian_ext( + model: Any, + kk: str, + vdef: OutputVariableDef, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None, + fparam: torch.Tensor | None, + aparam: torch.Tensor | None, + create_graph: bool = False, +) -> torch.Tensor: + """Compute hessian of reduced output w.r.t. extended coordinates. + + Mirrors the JAX approach: compute hessian on extended coordinates, + then let communicate_extended_output map nall->nloc. + + Parameters + ---------- + model + The model (CM instance). Must have ``atomic_model.forward_common_atomic``. + kk + The output key (e.g. "energy"). + vdef + The output variable definition. + extended_coord + Extended coordinates. Shape: [nf, nall, 3]. + extended_atype + Extended atom types. Shape: [nf, nall]. + nlist + Neighbor list. Shape: [nf, nloc, nsel]. + mapping + Mapping from extended to local. Shape: [nf, nall] or None. + fparam + Frame parameters. Shape: [nf, nfp] or None. + aparam + Atomic parameters. Shape: [nf, nloc, nap] or None. + create_graph + Whether to create graph for higher-order derivatives. + + Returns + ------- + torch.Tensor + Hessian on extended coordinates. Shape: [nf, *def, nall, 3, nall, 3]. + """ + nf, nall, _ = extended_coord.shape + vsize = math.prod(vdef.shape) + coord_flat = extended_coord.reshape(nf, nall * 3) + hessians = [] + for ii in range(nf): + for ci in range(vsize): + wrapper = _WrapperForwardEnergy( + model, + kk, + ci, + nall, + extended_atype[ii], + nlist[ii], + mapping[ii] if mapping is not None else None, + fparam[ii] if fparam is not None else None, + aparam[ii] if aparam is not None else None, + ) + hess = torch.autograd.functional.hessian( + wrapper, + coord_flat[ii], + create_graph=create_graph, + ) + hessians.append(hess) + # [nf * vsize, nall*3, nall*3] -> [nf, *vshape, nall, 3, nall, 3] + result = torch.stack(hessians).reshape(nf, *vdef.shape, nall, 3, nall, 3) + return result + + +class _WrapperForwardEnergy: + """Callable wrapper for torch.autograd.functional.hessian. + + Given flattened extended coordinates, recomputes the reduced energy + for one frame and one output component. + """ + + def __init__( + self, + model: Any, + kk: str, + ci: int, + nall: int, + atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None, + fparam: torch.Tensor | None, + aparam: torch.Tensor | None, + ) -> None: + self.model = model + self.kk = kk + self.ci = ci + self.nall = nall + self.atype = atype + self.nlist = nlist + self.mapping = mapping + self.fparam = fparam + self.aparam = aparam + + def __call__(self, coord_flat: torch.Tensor) -> torch.Tensor: + """Compute scalar reduced energy for one frame, one component. + + Parameters + ---------- + coord_flat + Flattened extended coordinates for one frame. Shape: [nall * 3]. + + Returns + ------- + torch.Tensor + Scalar energy component. + """ + cc_3d = coord_flat.reshape(1, self.nall, 3) + atomic_ret = self.model.atomic_model.forward_common_atomic( + cc_3d, + self.atype.unsqueeze(0), + self.nlist.unsqueeze(0), + mapping=self.mapping.unsqueeze(0) if self.mapping is not None else None, + fparam=self.fparam.unsqueeze(0) if self.fparam is not None else None, + aparam=self.aparam.unsqueeze(0) if self.aparam is not None else None, + ) + # atomic_ret[kk]: [1, nloc, *def] + atom_energy = atomic_ret[self.kk][0] # [nloc, *def] + energy_redu = atom_energy.sum(dim=0).reshape(-1)[self.ci] + return energy_redu + + def make_model( T_AtomicModel: type[BaseAtomicModel], T_Bases: tuple[type, ...] = (), @@ -81,7 +218,7 @@ def forward_common_atomic( fparam=fparam, aparam=aparam, ) - return fit_output_to_model_output( + model_ret = fit_output_to_model_output( atomic_ret, self.atomic_output_def(), extended_coord, @@ -89,5 +226,26 @@ def forward_common_atomic( create_graph=self.training, mask=atomic_ret.get("mask"), ) + # Hessian computation (mirrors JAX's forward_common_atomic). + # Produces hessian on extended coords [nf, *def, nall, 3, nall, 3], + # then communicate_extended_output maps it to nloc x nloc. + aod = self.atomic_output_def() + for kk in aod.keys(): + vdef = aod[kk] + if vdef.reducible and vdef.r_hessian: + kk_hess = get_hessian_name(kk) + model_ret[kk_hess] = _cal_hessian_ext( + self, + kk, + vdef, + extended_coord, + extended_atype, + nlist, + mapping, + fparam, + aparam, + create_graph=self.training, + ) + return model_ret return CM diff --git a/source/tests/consistent/model/test_ener_hessian.py b/source/tests/consistent/model/test_ener_hessian.py index 1181cac114..fa1bf19a51 100644 --- a/source/tests/consistent/model/test_ener_hessian.py +++ b/source/tests/consistent/model/test_ener_hessian.py @@ -13,6 +13,7 @@ ) from ..common import ( + INSTALLED_JAX, INSTALLED_PT, INSTALLED_PT_EXPT, CommonTest, @@ -30,6 +31,11 @@ from deepmd.pt_expt.model import EnergyModel as EnergyModelPTExpt else: EnergyModelPTExpt = None +if INSTALLED_JAX: + from deepmd.jax.model.ener_model import EnergyModel as EnergyModelJAX + from deepmd.jax.model.model import get_model as get_model_jax +else: + EnergyModelJAX = None from deepmd.utils.argcheck import ( model_args, ) @@ -70,7 +76,7 @@ def data(self) -> dict: dp_class = EnergyModelDP pt_class = EnergyModelPT pt_expt_class = EnergyModelPTExpt - jax_class = None + jax_class = EnergyModelJAX pd_class = None args = model_args() @@ -83,12 +89,12 @@ def skip_dp(self) -> bool: return True @property - def skip_jax(self) -> bool: + def skip_pd(self) -> bool: return True @property - def skip_pd(self) -> bool: - return True + def skip_jax(self) -> bool: + return not INSTALLED_JAX @property def skip_array_api_strict(self) -> bool: @@ -103,6 +109,8 @@ def get_reference_backend(self): return self.RefBackend.PT if not self.skip_pt_expt and self.pt_expt_class is not None: return self.RefBackend.PT_EXPT + if not self.skip_jax: + return self.RefBackend.JAX raise ValueError("No available reference") def pass_data_to_cls(self, cls, data) -> Any: @@ -116,6 +124,8 @@ def pass_data_to_cls(self, cls, data) -> Any: elif cls is EnergyModelPTExpt: dp_model = get_model_dp(data) model = EnergyModelPTExpt.deserialize(dp_model.serialize()) + elif cls is EnergyModelJAX: + model = get_model_jax(data) else: model = cls(**data) model.enable_hessian() @@ -190,10 +200,20 @@ def eval_pt_expt(self, pt_expt_obj: Any) -> Any: self.box, ) + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_model( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: if backend in { self.RefBackend.PT, self.RefBackend.PT_EXPT, + self.RefBackend.JAX, }: return ( ret["energy"].ravel(), @@ -239,6 +259,21 @@ def test_pt_expt_consistent_with_ref(self) -> None: for rr1, rr2 in zip(ret1, ret2, strict=True): np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) + def test_jax_consistent_with_ref(self) -> None: + """Test JAX consistent with reference, re-enabling hessian after deserialize.""" + if self.skip_jax or self.jax_class is None: + self.skipTest("Unsupported backend") + ref_backend = self.get_reference_backend() + if ref_backend == self.RefBackend.JAX: + self.skipTest("Reference is self") + ret1, data1 = self.get_reference_ret_serialization(ref_backend) + ret1 = self.extract_ret(ret1, ref_backend) + obj = self._enable_hessian_on(self.jax_class.deserialize(data1)) + ret2 = self.eval_jax(obj) + ret2 = self.extract_ret(ret2, self.RefBackend.JAX) + for rr1, rr2 in zip(ret1, ret2, strict=True): + np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) + def test_pt_self_consistent(self) -> None: """Skip: hessian is a runtime flag, not preserved by serialize/deserialize.""" self.skipTest("Hessian state is not serialized") @@ -247,6 +282,14 @@ def test_pt_expt_self_consistent(self) -> None: """Skip: hessian is a runtime flag, not preserved by serialize/deserialize.""" self.skipTest("Hessian state is not serialized") + def test_jax_self_consistent(self) -> None: + """Skip: hessian is a runtime flag, not preserved by serialize/deserialize.""" + self.skipTest("Hessian state is not serialized") + + def test_dp_self_consistent(self) -> None: + """Skip: hessian is a runtime flag, not preserved by serialize/deserialize.""" + self.skipTest("Hessian state is not serialized") + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pt_expt/model/test_make_hessian_model.py b/source/tests/pt_expt/model/test_ener_hessian_model.py similarity index 61% rename from source/tests/pt_expt/model/test_make_hessian_model.py rename to source/tests/pt_expt/model/test_ener_hessian_model.py index 7737c95786..8cd3a8ddbd 100644 --- a/source/tests/pt_expt/model/test_make_hessian_model.py +++ b/source/tests/pt_expt/model/test_ener_hessian_model.py @@ -1,7 +1,21 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for pt_expt hessian model. + +Verifies that the autograd-based hessian (second derivative of energy +w.r.t. coordinates) produced by ``EnergyModel.enable_hessian()`` matches +a central finite-difference reference. + +Tested via ``forward()``, the user-facing call interface. + +Parametrized over ``nv`` (number of energy components) to cover both +single-component (nv=1) and multi-component (nv=2) outputs. +""" + +import typing import unittest import numpy as np +import pytest import torch from deepmd.dpmodel.descriptor import DescrptSeA as DPDescrptSeA @@ -15,7 +29,6 @@ ) from deepmd.pt_expt.model import ( EnergyModel, - make_hessian_model, ) from deepmd.pt_expt.utils import ( env, @@ -35,6 +48,11 @@ def to_numpy_array(xx): def finite_hessian(f, x, delta=1e-6): + """Compute hessian by central finite difference. + + Uses the 4-point stencil: + (f(+d,+d) + f(-d,-d) - f(+d,-d) - f(-d,+d)) / (4*d^2). + """ in_shape = x.shape assert len(in_shape) == 1 y0 = f(x) @@ -62,16 +80,59 @@ def finite_hessian(f, x, delta=1e-6): return res -class HessianTest: - def test( - self, - ) -> None: - # setup test case +class TestHessianModel: + """Test ``EnergyModel.enable_hessian()`` against finite-difference hessian. + + Checks: + 1. Energy from the hessian-enabled model matches the plain model. + 2. The analytical hessian agrees with the finite-difference hessian + to 6 decimal places. + 3. The output definitions correctly reflect ``r_hessian=True`` and + contain the ``DERV_R_DERV_R`` category. + """ + + nf = 2 + nloc = 3 + rcut = 4.0 + rcut_smth = 3.0 + sel: typing.ClassVar[list[int]] = [10, 10] + nt = 2 + + def _build_models(self, nv: int) -> None: + """Build hessian-enabled and plain models with the given nv.""" + torch.manual_seed(2) + type_map = ["foo", "bar"] + ds_dp = DPDescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + neuron=[2, 4, 8], + axis_neuron=2, + ) + ft_dp = DPInvarFitting( + "energy", + self.nt, + ds_dp.get_dim_out(), + nv, + mixed_types=ds_dp.mixed_types(), + numb_fparam=2, + numb_aparam=3, + neuron=[4, 4, 4], + ) + md_dp = DPEnergyModel(ds_dp, ft_dp, type_map=type_map) + serialized = md_dp.serialize() + self.model_hess = EnergyModel.deserialize(serialized).to(env.DEVICE) + self.model_hess.enable_hessian() + self.model_valu = EnergyModel.deserialize(serialized).to(env.DEVICE) + + @pytest.mark.parametrize("nv", [1, 2]) # number of energy components + def test_hessian(self, nv) -> None: + """Analytical hessian from forward() must match finite-difference hessian.""" + self._build_models(nv) places = 6 delta = 1e-3 natoms = self.nloc nf = self.nf - nv = self.nv generator = torch.Generator(device=env.DEVICE).manual_seed(GLOBAL_SEED) cell0 = torch.rand([3, 3], dtype=dtype, device=env.DEVICE, generator=generator) cell0 = 1.0 * (cell0 + cell0.T) + 5.0 * torch.eye(3, device=env.DEVICE) @@ -101,21 +162,21 @@ def test( aparam = torch.rand( [nf, natoms, nap], dtype=dtype, device=env.DEVICE, generator=generator ) - # forward hess and value models + # forward() is the user-facing call interface of EnergyModel. # pt_expt requires coord to have requires_grad=True for autograd-based - # force/virial computation in forward_common + # force/virial/hessian computation. coord = coord.requires_grad_(True) - ret_dict0 = self.model_hess.forward_common( + ret_dict0 = self.model_hess.forward( coord, atype, box=cell, fparam=fparam, aparam=aparam ) - ret_dict1 = self.model_valu.forward_common( + ret_dict1 = self.model_valu.forward( coord, atype, box=cell, fparam=fparam, aparam=aparam ) - # compare hess and value models + # energy from hessian model must match the plain model torch.testing.assert_close(ret_dict0["energy"], ret_dict1["energy"]) - ana_hess = ret_dict0["energy_derv_r_derv_r"] + ana_hess = ret_dict0["hessian"] - # compute finite difference + # compute finite-difference hessian as reference fnt_hess = [] for ii in range(nf): @@ -123,117 +184,36 @@ def np_infer( xx, ): xx_t = to_torch_array(xx).unsqueeze(0).requires_grad_(True) - ret = self.model_valu.forward_common( + ret = self.model_valu.forward( xx_t, atype[ii].unsqueeze(0), box=cell[ii].unsqueeze(0), fparam=fparam[ii].unsqueeze(0), aparam=aparam[ii].unsqueeze(0), ) - # detach ret = {kk: to_numpy_array(ret[kk]) for kk in ret} return ret def ff(xx): - return np_infer(xx)["energy_redu"] + return np_infer(xx)["energy"] xx = to_numpy_array(coord[ii]) - fnt_hess.append(finite_hessian(ff, xx, delta=delta).squeeze()) + fnt_hess.append(finite_hessian(ff, xx, delta=delta)) - # compare finite difference with autodiff - fnt_hess = np.stack(fnt_hess).reshape([nf, nv, natoms * 3, natoms * 3]) - np.testing.assert_almost_equal( - fnt_hess, to_numpy_array(ana_hess), decimal=places - ) - - -class TestDPModel(unittest.TestCase, HessianTest): - def setUp(self) -> None: - torch.manual_seed(2) - self.nf = 2 - self.nloc = 3 - self.rcut = 4.0 - self.rcut_smth = 3.0 - self.sel = [10, 10] - self.nt = 2 - self.nv = 2 - type_map = ["foo", "bar"] - # Build dpmodel first, then deserialize into pt_expt - ds_dp = DPDescrptSeA( - self.rcut, - self.rcut_smth, - self.sel, - neuron=[2, 4, 8], - axis_neuron=2, - ) - ft_dp = DPInvarFitting( - "energy", - self.nt, - ds_dp.get_dim_out(), - self.nv, - mixed_types=ds_dp.mixed_types(), - numb_fparam=2, - numb_aparam=3, - neuron=[4, 4, 4], - ) - md_dp = DPEnergyModel(ds_dp, ft_dp, type_map=type_map) - serialized = md_dp.serialize() - # Create hessian model via make_hessian_model - HessEnergyModel = make_hessian_model(EnergyModel) - self.model_hess = HessEnergyModel.deserialize(serialized).to(env.DEVICE) - self.model_hess.requires_hessian("energy") - # Create value model (no hessian) - self.model_valu = EnergyModel.deserialize(serialized).to(env.DEVICE) + ana_hess_np = to_numpy_array(ana_hess) + fnt_hess = np.stack(fnt_hess).reshape(ana_hess_np.shape) + np.testing.assert_almost_equal(fnt_hess, ana_hess_np, decimal=places) def test_output_def(self) -> None: - self.assertTrue(self.model_hess.atomic_output_def()["energy"].r_hessian) - self.assertFalse(self.model_valu.atomic_output_def()["energy"].r_hessian) - self.assertTrue(self.model_hess.model_output_def()["energy"].r_hessian) - self.assertEqual( - self.model_hess.model_output_def()["energy_derv_r_derv_r"].category, - OutputVariableCategory.DERV_R_DERV_R, - ) - - -class TestEnableHessian(unittest.TestCase, HessianTest): - """Test hessian via enable_hessian() method.""" - - def setUp(self) -> None: - torch.manual_seed(2) - self.nf = 2 - self.nloc = 3 - self.rcut = 4.0 - self.rcut_smth = 3.0 - self.sel = [10, 10] - self.nt = 2 - self.nv = 1 - type_map = ["foo", "bar"] - ds_dp = DPDescrptSeA( - self.rcut, - self.rcut_smth, - self.sel, - neuron=[2, 4, 8], - axis_neuron=2, + """Output defs: r_hessian flag and DERV_R_DERV_R category.""" + self._build_models(nv=1) + assert self.model_hess.atomic_output_def()["energy"].r_hessian + assert not self.model_valu.atomic_output_def()["energy"].r_hessian + assert self.model_hess.model_output_def()["energy"].r_hessian + assert ( + self.model_hess.model_output_def()["energy_derv_r_derv_r"].category + == OutputVariableCategory.DERV_R_DERV_R ) - ft_dp = DPInvarFitting( - "energy", - self.nt, - ds_dp.get_dim_out(), - self.nv, - mixed_types=ds_dp.mixed_types(), - numb_fparam=2, - numb_aparam=3, - neuron=[4, 4, 4], - ) - md_dp = DPEnergyModel(ds_dp, ft_dp, type_map=type_map) - serialized = md_dp.serialize() - self.model_hess = EnergyModel.deserialize(serialized).to(env.DEVICE) - self.model_hess.enable_hessian() - self.model_valu = EnergyModel.deserialize(serialized).to(env.DEVICE) - - def test_output_def(self) -> None: - self.assertTrue(self.model_hess.atomic_output_def()["energy"].r_hessian) - self.assertFalse(self.model_valu.atomic_output_def()["energy"].r_hessian) if __name__ == "__main__": From 7f04df800fac071e9f69b045b87683e66b4fb2d4 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 5 Mar 2026 23:01:58 +0800 Subject: [PATCH 3/5] reduce delta in finite diff --- source/tests/pt_expt/model/test_ener_hessian_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/pt_expt/model/test_ener_hessian_model.py b/source/tests/pt_expt/model/test_ener_hessian_model.py index 8cd3a8ddbd..a9f1b2a8ac 100644 --- a/source/tests/pt_expt/model/test_ener_hessian_model.py +++ b/source/tests/pt_expt/model/test_ener_hessian_model.py @@ -130,7 +130,7 @@ def test_hessian(self, nv) -> None: """Analytical hessian from forward() must match finite-difference hessian.""" self._build_models(nv) places = 6 - delta = 1e-3 + delta = 5e-4 natoms = self.nloc nf = self.nf generator = torch.Generator(device=env.DEVICE).manual_seed(GLOBAL_SEED) From 778a46757bedaf8eeaa4f84467ad7f415dd37818 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 5 Mar 2026 23:42:28 +0800 Subject: [PATCH 4/5] fix --- source/tests/pt_expt/model/test_ener_hessian_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/source/tests/pt_expt/model/test_ener_hessian_model.py b/source/tests/pt_expt/model/test_ener_hessian_model.py index a9f1b2a8ac..8e0e812ef7 100644 --- a/source/tests/pt_expt/model/test_ener_hessian_model.py +++ b/source/tests/pt_expt/model/test_ener_hessian_model.py @@ -182,6 +182,7 @@ def test_hessian(self, nv) -> None: def np_infer( xx, + ii=ii, ): xx_t = to_torch_array(xx).unsqueeze(0).requires_grad_(True) ret = self.model_valu.forward( From 0560d465fc43cbbf2f753736e8b9e757ac947cdf Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 5 Mar 2026 23:45:23 +0800 Subject: [PATCH 5/5] fix --- deepmd/pt_expt/model/ener_model.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/deepmd/pt_expt/model/ener_model.py b/deepmd/pt_expt/model/ener_model.py index 07cdfc5a54..0b10d30760 100644 --- a/deepmd/pt_expt/model/ener_model.py +++ b/deepmd/pt_expt/model/ener_model.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy from typing import ( Any, ) @@ -40,8 +41,12 @@ def __init__( self._hessian_enabled = False def enable_hessian(self) -> None: + if self._hessian_enabled: + return self.__class__ = make_hessian_model(type(self)) - self.hess_fitting_def = super(type(self), self).atomic_output_def() + self.hess_fitting_def = copy.deepcopy( + super(type(self), self).atomic_output_def() + ) self.requires_hessian("energy") self._hessian_enabled = True