diff --git a/deepmd/dpmodel/model/make_hessian_model.py b/deepmd/dpmodel/model/make_hessian_model.py new file mode 100644 index 0000000000..10e46ebfc8 --- /dev/null +++ b/deepmd/dpmodel/model/make_hessian_model.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +from typing import ( + Any, +) + +from deepmd.dpmodel.output_def import ( + FittingOutputDef, +) + + +def make_hessian_model(T_Model: type) -> type: + """Make a model that can compute Hessian. + + With the JAX-mirrored approach, hessian is computed in + ``forward_common_atomic`` (in make_model.py) on extended coordinates. + This wrapper only needs to override ``atomic_output_def()`` to set + ``r_hessian=True``, and ``communicate_extended_output`` in dpmodel + naturally maps it from nall to nloc. + + Parameters + ---------- + T_Model + The model. Should provide the ``atomic_output_def`` method. + + Returns + ------- + The model that computes hessian. + + """ + + class CM(T_Model): + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__( + *args, + **kwargs, + ) + self.hess_fitting_def = copy.deepcopy(super().atomic_output_def()) + + def requires_hessian( + self, + keys: str | list[str], + ) -> None: + """Set which output variable(s) requires hessian.""" + if isinstance(keys, str): + keys = [keys] + for kk in self.hess_fitting_def.keys(): + if kk in keys: + self.hess_fitting_def[kk].r_hessian = True + + def atomic_output_def(self) -> FittingOutputDef: + """Get the fitting output def.""" + return self.hess_fitting_def + + return CM diff --git a/deepmd/pt_expt/model/__init__.py b/deepmd/pt_expt/model/__init__.py index 7b3f7cdeab..7197e39634 100644 --- a/deepmd/pt_expt/model/__init__.py +++ b/deepmd/pt_expt/model/__init__.py @@ -1,4 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.model.make_hessian_model import ( + make_hessian_model, +) + from .dipole_model import ( DipoleModel, ) @@ -33,4 +37,5 @@ "PolarModel", 
"PropertyModel", "get_model", + "make_hessian_model", ] diff --git a/deepmd/pt_expt/model/ener_model.py b/deepmd/pt_expt/model/ener_model.py index 53b22ab705..beb91c4ec4 100644 --- a/deepmd/pt_expt/model/ener_model.py +++ b/deepmd/pt_expt/model/ener_model.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy from typing import ( Any, ) @@ -14,6 +15,9 @@ from deepmd.dpmodel.model.dp_model import ( DPModelCommon, ) +from deepmd.dpmodel.model.make_hessian_model import ( + make_hessian_model, +) from .make_model import ( make_model, @@ -34,6 +38,17 @@ def __init__( ) -> None: DPModelCommon.__init__(self) DPEnergyModel_.__init__(self, *args, **kwargs) + self._hessian_enabled = False + + def enable_hessian(self) -> None: + if self._hessian_enabled: + return + self.__class__ = make_hessian_model(type(self)) + self.hess_fitting_def = copy.deepcopy( + super(type(self), self).atomic_output_def() + ) + self.requires_hessian("energy") + self._hessian_enabled = True def forward( self, @@ -63,6 +78,8 @@ def forward( model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-2) if "mask" in model_ret: model_predict["mask"] = model_ret["mask"] + if self.atomic_output_def()["energy"].r_hessian: + model_predict["hessian"] = model_ret["energy_derv_r_derv_r"].squeeze(-3) return model_predict def forward_lower( @@ -115,6 +132,8 @@ def translated_output_def(self) -> dict[str, Any]: output_def["atom_virial"].squeeze(-2) if "mask" in out_def_data: output_def["mask"] = out_def_data["mask"] + if self.atomic_output_def()["energy"].r_hessian: + output_def["hessian"] = out_def_data["energy_derv_r_derv_r"] return output_def def forward_lower_exportable( diff --git a/deepmd/pt_expt/model/make_model.py b/deepmd/pt_expt/model/make_model.py index 4baf3f5c7a..02e67d0f27 100644 --- a/deepmd/pt_expt/model/make_model.py +++ b/deepmd/pt_expt/model/make_model.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import math from typing import ( Any, ) @@ -8,10 
+9,16 @@ make_fx, ) +from deepmd.dpmodel import ( + get_hessian_name, +) from deepmd.dpmodel.atomic_model.base_atomic_model import ( BaseAtomicModel, ) from deepmd.dpmodel.model.make_model import make_model as make_model_dp +from deepmd.dpmodel.output_def import ( + OutputVariableDef, +) from deepmd.pt_expt.common import ( torch_module, ) @@ -21,6 +28,136 @@ ) +def _cal_hessian_ext( + model: Any, + kk: str, + vdef: OutputVariableDef, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None, + fparam: torch.Tensor | None, + aparam: torch.Tensor | None, + create_graph: bool = False, +) -> torch.Tensor: + """Compute hessian of reduced output w.r.t. extended coordinates. + + Mirrors the JAX approach: compute hessian on extended coordinates, + then let communicate_extended_output map nall->nloc. + + Parameters + ---------- + model + The model (CM instance). Must have ``atomic_model.forward_common_atomic``. + kk + The output key (e.g. "energy"). + vdef + The output variable definition. + extended_coord + Extended coordinates. Shape: [nf, nall, 3]. + extended_atype + Extended atom types. Shape: [nf, nall]. + nlist + Neighbor list. Shape: [nf, nloc, nsel]. + mapping + Mapping from extended to local. Shape: [nf, nall] or None. + fparam + Frame parameters. Shape: [nf, nfp] or None. + aparam + Atomic parameters. Shape: [nf, nloc, nap] or None. + create_graph + Whether to create graph for higher-order derivatives. + + Returns + ------- + torch.Tensor + Hessian on extended coordinates. Shape: [nf, *def, nall, 3, nall, 3]. 
+ """ + nf, nall, _ = extended_coord.shape + vsize = math.prod(vdef.shape) + coord_flat = extended_coord.reshape(nf, nall * 3) + hessians = [] + for ii in range(nf): + for ci in range(vsize): + wrapper = _WrapperForwardEnergy( + model, + kk, + ci, + nall, + extended_atype[ii], + nlist[ii], + mapping[ii] if mapping is not None else None, + fparam[ii] if fparam is not None else None, + aparam[ii] if aparam is not None else None, + ) + hess = torch.autograd.functional.hessian( + wrapper, + coord_flat[ii], + create_graph=create_graph, + ) + hessians.append(hess) + # [nf * vsize, nall*3, nall*3] -> [nf, *vshape, nall, 3, nall, 3] + result = torch.stack(hessians).reshape(nf, *vdef.shape, nall, 3, nall, 3) + return result + + +class _WrapperForwardEnergy: + """Callable wrapper for torch.autograd.functional.hessian. + + Given flattened extended coordinates, recomputes the reduced energy + for one frame and one output component. + """ + + def __init__( + self, + model: Any, + kk: str, + ci: int, + nall: int, + atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None, + fparam: torch.Tensor | None, + aparam: torch.Tensor | None, + ) -> None: + self.model = model + self.kk = kk + self.ci = ci + self.nall = nall + self.atype = atype + self.nlist = nlist + self.mapping = mapping + self.fparam = fparam + self.aparam = aparam + + def __call__(self, coord_flat: torch.Tensor) -> torch.Tensor: + """Compute scalar reduced energy for one frame, one component. + + Parameters + ---------- + coord_flat + Flattened extended coordinates for one frame. Shape: [nall * 3]. + + Returns + ------- + torch.Tensor + Scalar energy component. 
+ """ + cc_3d = coord_flat.reshape(1, self.nall, 3) + atomic_ret = self.model.atomic_model.forward_common_atomic( + cc_3d, + self.atype.unsqueeze(0), + self.nlist.unsqueeze(0), + mapping=self.mapping.unsqueeze(0) if self.mapping is not None else None, + fparam=self.fparam.unsqueeze(0) if self.fparam is not None else None, + aparam=self.aparam.unsqueeze(0) if self.aparam is not None else None, + ) + # atomic_ret[kk]: [1, nloc, *def] + atom_energy = atomic_ret[self.kk][0] # [nloc, *def] + energy_redu = atom_energy.sum(dim=0).reshape(-1)[self.ci] + return energy_redu + + def make_model( T_AtomicModel: type[BaseAtomicModel], T_Bases: tuple[type, ...] = (), @@ -84,7 +221,7 @@ def forward_common_atomic( fparam=fparam, aparam=aparam, ) - return fit_output_to_model_output( + model_ret = fit_output_to_model_output( atomic_ret, self.atomic_output_def(), extended_coord, @@ -92,6 +229,27 @@ def forward_common_atomic( create_graph=self.training, mask=atomic_ret.get("mask"), ) + # Hessian computation (mirrors JAX's forward_common_atomic). + # Produces hessian on extended coords [nf, *def, nall, 3, nall, 3], + # then communicate_extended_output maps it to nloc x nloc. 
+ aod = self.atomic_output_def() + for kk in aod.keys(): + vdef = aod[kk] + if vdef.reducible and vdef.r_hessian: + kk_hess = get_hessian_name(kk) + model_ret[kk_hess] = _cal_hessian_ext( + self, + kk, + vdef, + extended_coord, + extended_atype, + nlist, + mapping, + fparam, + aparam, + create_graph=self.training, + ) + return model_ret def forward_common_lower_exportable( self, diff --git a/source/tests/consistent/model/test_ener_hessian.py b/source/tests/consistent/model/test_ener_hessian.py new file mode 100644 index 0000000000..fa1bf19a51 --- /dev/null +++ b/source/tests/consistent/model/test_ener_hessian.py @@ -0,0 +1,295 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, +) + +import numpy as np + +from deepmd.dpmodel.model.ener_model import EnergyModel as EnergyModelDP +from deepmd.dpmodel.model.model import get_model as get_model_dp +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) + +from ..common import ( + INSTALLED_JAX, + INSTALLED_PT, + INSTALLED_PT_EXPT, + CommonTest, +) +from .common import ( + ModelTest, +) + +if INSTALLED_PT: + from deepmd.pt.model.model import get_model as get_model_pt + from deepmd.pt.model.model.ener_model import EnergyModel as EnergyModelPT +else: + EnergyModelPT = None +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.model import EnergyModel as EnergyModelPTExpt +else: + EnergyModelPTExpt = None +if INSTALLED_JAX: + from deepmd.jax.model.ener_model import EnergyModel as EnergyModelJAX + from deepmd.jax.model.model import get_model as get_model_jax +else: + EnergyModelJAX = None +from deepmd.utils.argcheck import ( + model_args, +) + + +class TestEnerHessian(CommonTest, ModelTest, unittest.TestCase): + @property + def data(self) -> dict: + return { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "sel": [20, 20], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [ + 3, + 6, + ], + "resnet_dt": False, + "axis_neuron": 2, + "precision": "float64", + "type_one_side": 
True, + "seed": 1, + }, + "fitting_net": { + "neuron": [ + 5, + 5, + ], + "resnet_dt": True, + "precision": "float64", + "seed": 1, + }, + } + + tf_class = None + dp_class = EnergyModelDP + pt_class = EnergyModelPT + pt_expt_class = EnergyModelPTExpt + jax_class = EnergyModelJAX + pd_class = None + args = model_args() + + @property + def skip_tf(self) -> bool: + return True + + @property + def skip_dp(self) -> bool: + return True + + @property + def skip_pd(self) -> bool: + return True + + @property + def skip_jax(self) -> bool: + return not INSTALLED_JAX + + @property + def skip_array_api_strict(self) -> bool: + return True + + def get_reference_backend(self): + """Get the reference backend. + + We need a reference backend that can compute hessian. + """ + if not self.skip_pt: + return self.RefBackend.PT + if not self.skip_pt_expt and self.pt_expt_class is not None: + return self.RefBackend.PT_EXPT + if not self.skip_jax: + return self.RefBackend.JAX + raise ValueError("No available reference") + + def pass_data_to_cls(self, cls, data) -> Any: + """Pass data to the class and enable hessian.""" + data = data.copy() + if cls is EnergyModelDP: + model = get_model_dp(data) + elif cls is EnergyModelPT: + model = get_model_pt(data) + model.atomic_model.out_bias.uniform_() + elif cls is EnergyModelPTExpt: + dp_model = get_model_dp(data) + model = EnergyModelPTExpt.deserialize(dp_model.serialize()) + elif cls is EnergyModelJAX: + model = get_model_jax(data) + else: + model = cls(**data) + model.enable_hessian() + return model + + def setUp(self) -> None: + CommonTest.setUp(self) + + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, -1, 3) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32).reshape(1, -1) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 
0.0, 13.0], +            dtype=GLOBAL_NP_FLOAT_PRECISION, +        ).reshape(1, 9) +        self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + +        # TF requires the atype to be sorted +        idx_map = np.argsort(self.atype.ravel()) +        self.atype = self.atype[:, idx_map] +        self.coords = self.coords[:, idx_map] + +    def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: +        raise NotImplementedError("no TF in this test") + +    def eval_dp(self, dp_obj: Any) -> Any: +        return self.eval_dp_model( +            dp_obj, +            self.natoms, +            self.coords, +            self.atype, +            self.box, +        ) + +    def eval_pt(self, pt_obj: Any) -> Any: +        return self.eval_pt_model( +            pt_obj, +            self.natoms, +            self.coords, +            self.atype, +            self.box, +        ) + +    def eval_pt_expt(self, pt_expt_obj: Any) -> Any: +        return self.eval_pt_expt_model( +            pt_expt_obj, +            self.natoms, +            self.coords, +            self.atype, +            self.box, +        ) + +    def eval_jax(self, jax_obj: Any) -> Any: +        return self.eval_jax_model( +            jax_obj, +            self.natoms, +            self.coords, +            self.atype, +            self.box, +        ) + +    def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: +        if backend in { +            self.RefBackend.PT, +            self.RefBackend.PT_EXPT, +            self.RefBackend.JAX, +        }: +            return ( +                ret["energy"].ravel(), +                ret["atom_energy"].ravel(), +                ret["force"].ravel(), +                ret["virial"].ravel(), +                ret["hessian"].ravel(), +            ) +        raise ValueError(f"Unknown backend: {backend}") + +    def _enable_hessian_on(self, obj): +        """Enable hessian on a model object.""" +        obj.enable_hessian() +        return obj + +    def test_pt_consistent_with_ref(self) -> None: +        """Test PT consistent with reference, re-enabling hessian after deserialize.""" +        if self.skip_pt: +            self.skipTest("Unsupported backend") +        ref_backend = self.get_reference_backend() +        if ref_backend == self.RefBackend.PT: +            self.skipTest("Reference is self") +        ret1, data1 = self.get_reference_ret_serialization(ref_backend) +        ret1 = self.extract_ret(ret1, ref_backend) +        obj = self._enable_hessian_on(self.pt_class.deserialize(data1)) +        ret2 = self.eval_pt(obj) +        ret2 = 
self.extract_ret(ret2, self.RefBackend.PT) + for rr1, rr2 in zip(ret1, ret2, strict=True): + np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) + + def test_pt_expt_consistent_with_ref(self) -> None: + """Test pt_expt consistent with reference, re-enabling hessian after deserialize.""" + if self.skip_pt_expt or self.pt_expt_class is None: + self.skipTest("Unsupported backend") + ref_backend = self.get_reference_backend() + if ref_backend == self.RefBackend.PT_EXPT: + self.skipTest("Reference is self") + ret1, data1 = self.get_reference_ret_serialization(ref_backend) + ret1 = self.extract_ret(ret1, ref_backend) + obj = self._enable_hessian_on(self.pt_expt_class.deserialize(data1)) + ret2 = self.eval_pt_expt(obj) + ret2 = self.extract_ret(ret2, self.RefBackend.PT_EXPT) + for rr1, rr2 in zip(ret1, ret2, strict=True): + np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) + + def test_jax_consistent_with_ref(self) -> None: + """Test JAX consistent with reference, re-enabling hessian after deserialize.""" + if self.skip_jax or self.jax_class is None: + self.skipTest("Unsupported backend") + ref_backend = self.get_reference_backend() + if ref_backend == self.RefBackend.JAX: + self.skipTest("Reference is self") + ret1, data1 = self.get_reference_ret_serialization(ref_backend) + ret1 = self.extract_ret(ret1, ref_backend) + obj = self._enable_hessian_on(self.jax_class.deserialize(data1)) + ret2 = self.eval_jax(obj) + ret2 = self.extract_ret(ret2, self.RefBackend.JAX) + for rr1, rr2 in zip(ret1, ret2, strict=True): + np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) + + def test_pt_self_consistent(self) -> None: + """Skip: hessian is a runtime flag, not preserved by serialize/deserialize.""" + self.skipTest("Hessian state is not serialized") + + def test_pt_expt_self_consistent(self) -> None: + """Skip: hessian is a runtime flag, not preserved by serialize/deserialize.""" + self.skipTest("Hessian state is not serialized") 
+ + def test_jax_self_consistent(self) -> None: + """Skip: hessian is a runtime flag, not preserved by serialize/deserialize.""" + self.skipTest("Hessian state is not serialized") + + def test_dp_self_consistent(self) -> None: + """Skip: hessian is a runtime flag, not preserved by serialize/deserialize.""" + self.skipTest("Hessian state is not serialized") + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt_expt/model/test_ener_hessian_model.py b/source/tests/pt_expt/model/test_ener_hessian_model.py new file mode 100644 index 0000000000..8e0e812ef7 --- /dev/null +++ b/source/tests/pt_expt/model/test_ener_hessian_model.py @@ -0,0 +1,221 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for pt_expt hessian model. + +Verifies that the autograd-based hessian (second derivative of energy +w.r.t. coordinates) produced by ``EnergyModel.enable_hessian()`` matches +a central finite-difference reference. + +Tested via ``forward()``, the user-facing call interface. + +Parametrized over ``nv`` (number of energy components) to cover both +single-component (nv=1) and multi-component (nv=2) outputs. +""" + +import typing +import unittest + +import numpy as np +import pytest +import torch + +from deepmd.dpmodel.descriptor import DescrptSeA as DPDescrptSeA +from deepmd.dpmodel.fitting import InvarFitting as DPInvarFitting +from deepmd.dpmodel.model.ener_model import EnergyModel as DPEnergyModel +from deepmd.dpmodel.output_def import ( + OutputVariableCategory, +) +from deepmd.pt_expt.common import ( + to_torch_array, +) +from deepmd.pt_expt.model import ( + EnergyModel, +) +from deepmd.pt_expt.utils import ( + env, +) + +from ...seed import ( + GLOBAL_SEED, +) + +dtype = torch.float64 + + +def to_numpy_array(xx): + if isinstance(xx, torch.Tensor): + return xx.detach().cpu().numpy() + return np.asarray(xx) + + +def finite_hessian(f, x, delta=1e-6): + """Compute hessian by central finite difference. 
+ + Uses the 4-point stencil: + (f(+d,+d) + f(-d,-d) - f(+d,-d) - f(-d,+d)) / (4*d^2). + """ + in_shape = x.shape + assert len(in_shape) == 1 + y0 = f(x) + out_shape = y0.shape + res = np.empty(out_shape + in_shape + in_shape) + for iidx in np.ndindex(*in_shape): + for jidx in np.ndindex(*in_shape): + i0 = np.zeros(in_shape) + i1 = np.zeros(in_shape) + i2 = np.zeros(in_shape) + i3 = np.zeros(in_shape) + i0[iidx] += delta + i2[iidx] += delta + i1[iidx] -= delta + i3[iidx] -= delta + i0[jidx] += delta + i1[jidx] += delta + i2[jidx] -= delta + i3[jidx] -= delta + y0 = f(x + i0) + y1 = f(x + i1) + y2 = f(x + i2) + y3 = f(x + i3) + res[(Ellipsis, *iidx, *jidx)] = (y0 + y3 - y1 - y2) / (4 * delta**2.0) + return res + + +class TestHessianModel: + """Test ``EnergyModel.enable_hessian()`` against finite-difference hessian. + + Checks: + 1. Energy from the hessian-enabled model matches the plain model. + 2. The analytical hessian agrees with the finite-difference hessian + to 6 decimal places. + 3. The output definitions correctly reflect ``r_hessian=True`` and + contain the ``DERV_R_DERV_R`` category. 
+ """ + + nf = 2 + nloc = 3 + rcut = 4.0 + rcut_smth = 3.0 + sel: typing.ClassVar[list[int]] = [10, 10] + nt = 2 + + def _build_models(self, nv: int) -> None: + """Build hessian-enabled and plain models with the given nv.""" + torch.manual_seed(2) + type_map = ["foo", "bar"] + ds_dp = DPDescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + neuron=[2, 4, 8], + axis_neuron=2, + ) + ft_dp = DPInvarFitting( + "energy", + self.nt, + ds_dp.get_dim_out(), + nv, + mixed_types=ds_dp.mixed_types(), + numb_fparam=2, + numb_aparam=3, + neuron=[4, 4, 4], + ) + md_dp = DPEnergyModel(ds_dp, ft_dp, type_map=type_map) + serialized = md_dp.serialize() + self.model_hess = EnergyModel.deserialize(serialized).to(env.DEVICE) + self.model_hess.enable_hessian() + self.model_valu = EnergyModel.deserialize(serialized).to(env.DEVICE) + + @pytest.mark.parametrize("nv", [1, 2]) # number of energy components + def test_hessian(self, nv) -> None: + """Analytical hessian from forward() must match finite-difference hessian.""" + self._build_models(nv) + places = 6 + delta = 5e-4 + natoms = self.nloc + nf = self.nf + generator = torch.Generator(device=env.DEVICE).manual_seed(GLOBAL_SEED) + cell0 = torch.rand([3, 3], dtype=dtype, device=env.DEVICE, generator=generator) + cell0 = 1.0 * (cell0 + cell0.T) + 5.0 * torch.eye(3, device=env.DEVICE) + cell1 = torch.rand([3, 3], dtype=dtype, device=env.DEVICE, generator=generator) + cell1 = 1.0 * (cell1 + cell1.T) + 5.0 * torch.eye(3, device=env.DEVICE) + cell = torch.stack([cell0, cell1]) + coord = torch.rand( + [nf, natoms, 3], dtype=dtype, device=env.DEVICE, generator=generator + ) + coord = torch.matmul(coord, cell) + cell = cell.view([nf, 9]) + coord = coord.view([nf, natoms * 3]) + atype = ( + torch.stack( + [ + torch.IntTensor([0, 0, 1]), + torch.IntTensor([1, 0, 1]), + ] + ) + .view([nf, natoms]) + .to(env.DEVICE) + ) + nfp, nap = 2, 3 + fparam = torch.rand( + [nf, nfp], dtype=dtype, device=env.DEVICE, generator=generator + ) + aparam = 
torch.rand( + [nf, natoms, nap], dtype=dtype, device=env.DEVICE, generator=generator + ) + # forward() is the user-facing call interface of EnergyModel. + # pt_expt requires coord to have requires_grad=True for autograd-based + # force/virial/hessian computation. + coord = coord.requires_grad_(True) + ret_dict0 = self.model_hess.forward( + coord, atype, box=cell, fparam=fparam, aparam=aparam + ) + ret_dict1 = self.model_valu.forward( + coord, atype, box=cell, fparam=fparam, aparam=aparam + ) + # energy from hessian model must match the plain model + torch.testing.assert_close(ret_dict0["energy"], ret_dict1["energy"]) + ana_hess = ret_dict0["hessian"] + + # compute finite-difference hessian as reference + fnt_hess = [] + for ii in range(nf): + + def np_infer( + xx, + ii=ii, + ): + xx_t = to_torch_array(xx).unsqueeze(0).requires_grad_(True) + ret = self.model_valu.forward( + xx_t, + atype[ii].unsqueeze(0), + box=cell[ii].unsqueeze(0), + fparam=fparam[ii].unsqueeze(0), + aparam=aparam[ii].unsqueeze(0), + ) + ret = {kk: to_numpy_array(ret[kk]) for kk in ret} + return ret + + def ff(xx): + return np_infer(xx)["energy"] + + xx = to_numpy_array(coord[ii]) + fnt_hess.append(finite_hessian(ff, xx, delta=delta)) + + ana_hess_np = to_numpy_array(ana_hess) + fnt_hess = np.stack(fnt_hess).reshape(ana_hess_np.shape) + np.testing.assert_almost_equal(fnt_hess, ana_hess_np, decimal=places) + + def test_output_def(self) -> None: + """Output defs: r_hessian flag and DERV_R_DERV_R category.""" + self._build_models(nv=1) + assert self.model_hess.atomic_output_def()["energy"].r_hessian + assert not self.model_valu.atomic_output_def()["energy"].r_hessian + assert self.model_hess.model_output_def()["energy"].r_hessian + assert ( + self.model_hess.model_output_def()["energy_derv_r_derv_r"].category + == OutputVariableCategory.DERV_R_DERV_R + ) + + +if __name__ == "__main__": + unittest.main()