Skip to content

Commit 73bb1b7

Browse files
wanghan-iapcm (Han Wang)
and co-authors authored
feat(pt_expt): implementing energy hessian model (#5287)
### Summary Add hessian (second derivative of energy w.r.t. coordinates) support to the pt_expt backend, mirroring the JAX backend's approach. - Compute hessian on extended coordinates in forward_common_atomic (via torch.autograd.functional.hessian), then let dpmodel's communicate_extended_output naturally map from nall×nall to nloc×nloc - make_hessian_model only overrides atomic_output_def() to set r_hessian=True — no forward_common override or r_hessian toggle hack needed - Hessian is enabled at runtime via EnergyModel.enable_hessian() and returned through the user-facing forward() interface ### Changed files - deepmd/pt_expt/model/make_model.py — add _cal_hessian_ext and _WrapperForwardEnergy for hessian computation in forward_common_atomic - deepmd/pt_expt/model/make_hessian_model.py — minimal wrapper: __init__, requires_hessian, atomic_output_def - deepmd/pt_expt/model/ener_model.py — enable_hessian(), hessian output in forward() and translated_output_def() - deepmd/pt_expt/model/__init__.py — export make_hessian_model - source/tests/pt_expt/model/test_ener_hessian_model.py — unit test: autograd hessian vs finite-difference, parametrized over nv={1,2} - source/tests/consistent/model/test_ener_hessian.py — cross-backend consistency test covering PT, pt_expt, and JAX <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Adds Hessian (second-order derivative) support for energy models. * Provides a mechanism to enable Hessian calculations per model instance. * When enabled, Hessian tensors are included in model prediction outputs. * **Tests** * Added comprehensive multi-backend validation of Hessian functionality. * Added unit tests comparing analytical Hessians to finite-difference references. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent 8c5cb86 commit 73bb1b7

File tree

6 files changed

+758
-1
lines changed

6 files changed

+758
-1
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import copy
3+
from typing import (
4+
Any,
5+
)
6+
7+
from deepmd.dpmodel.output_def import (
8+
FittingOutputDef,
9+
)
10+
11+
12+
def make_hessian_model(T_Model: type) -> type:
    """Derive a hessian-capable model class from ``T_Model``.

    With the JAX-mirrored approach, the hessian itself is computed in
    ``forward_common_atomic`` (in make_model.py) on extended coordinates,
    and ``communicate_extended_output`` in dpmodel maps it from nall to
    nloc. This wrapper therefore only has to expose an
    ``atomic_output_def()`` whose selected variables carry
    ``r_hessian=True``.

    Parameters
    ----------
    T_Model
        The model class. Should provide the ``atomic_output_def`` method.

    Returns
    -------
    The model class that computes hessian.
    """

    class CM(T_Model):
        def __init__(
            self,
            *args: Any,
            **kwargs: Any,
        ) -> None:
            super().__init__(*args, **kwargs)
            # Deep-copied so that toggling r_hessian flags never mutates
            # the wrapped model's own output definition.
            self.hess_fitting_def = copy.deepcopy(super().atomic_output_def())

        def requires_hessian(
            self,
            keys: str | list[str],
        ) -> None:
            """Mark the given output variable(s) as requiring a hessian."""
            wanted = {keys} if isinstance(keys, str) else set(keys)
            # Keys not present in the output def are silently ignored.
            for name in self.hess_fitting_def.keys():
                if name in wanted:
                    self.hess_fitting_def[name].r_hessian = True

        def atomic_output_def(self) -> FittingOutputDef:
            """Get the fitting output def (with any hessian flags applied)."""
            return self.hess_fitting_def

    return CM

deepmd/pt_expt/model/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from deepmd.dpmodel.model.make_hessian_model import (
3+
make_hessian_model,
4+
)
5+
26
from .dipole_model import (
37
DipoleModel,
48
)
@@ -33,4 +37,5 @@
3337
"PolarModel",
3438
"PropertyModel",
3539
"get_model",
40+
"make_hessian_model",
3641
]

deepmd/pt_expt/model/ener_model.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import copy
23
from typing import (
34
Any,
45
)
@@ -14,6 +15,9 @@
1415
from deepmd.dpmodel.model.dp_model import (
1516
DPModelCommon,
1617
)
18+
from deepmd.dpmodel.model.make_hessian_model import (
19+
make_hessian_model,
20+
)
1721

1822
from .make_model import (
1923
make_model,
@@ -34,6 +38,17 @@ def __init__(
3438
) -> None:
3539
DPModelCommon.__init__(self)
3640
DPEnergyModel_.__init__(self, *args, **kwargs)
41+
self._hessian_enabled = False
42+
43+
def enable_hessian(self) -> None:
    """Enable hessian (second derivative of energy w.r.t. coordinates) output.

    Idempotent: a second call is a no-op. Works by swapping the instance's
    class for the hessian-enabled wrapper produced by ``make_hessian_model``
    at runtime, then flagging the "energy" output as requiring a hessian.
    """
    if self._hessian_enabled:
        return
    # Rebind the instance to the hessian-aware subclass. Note that the
    # wrapper's __init__ is NOT executed by this assignment, so the state
    # it would normally set up (hess_fitting_def) is initialized by hand.
    self.__class__ = make_hessian_model(type(self))
    # After the reassignment above, type(self) is the wrapper class, so
    # super(type(self), self) resolves atomic_output_def on the original
    # (pre-wrap) class — the same deep-copy the wrapper's __init__ would do.
    self.hess_fitting_def = copy.deepcopy(
        super(type(self), self).atomic_output_def()
    )
    self.requires_hessian("energy")
    self._hessian_enabled = True
3752

3853
def forward(
3954
self,
@@ -63,6 +78,8 @@ def forward(
6378
model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-2)
6479
if "mask" in model_ret:
6580
model_predict["mask"] = model_ret["mask"]
81+
if self.atomic_output_def()["energy"].r_hessian:
82+
model_predict["hessian"] = model_ret["energy_derv_r_derv_r"].squeeze(-3)
6683
return model_predict
6784

6885
def forward_lower(
@@ -115,6 +132,8 @@ def translated_output_def(self) -> dict[str, Any]:
115132
output_def["atom_virial"].squeeze(-2)
116133
if "mask" in out_def_data:
117134
output_def["mask"] = out_def_data["mask"]
135+
if self.atomic_output_def()["energy"].r_hessian:
136+
output_def["hessian"] = out_def_data["energy_derv_r_derv_r"]
118137
return output_def
119138

120139
def forward_lower_exportable(

deepmd/pt_expt/model/make_model.py

Lines changed: 159 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import math
23
from typing import (
34
Any,
45
)
@@ -8,10 +9,16 @@
89
make_fx,
910
)
1011

12+
from deepmd.dpmodel import (
13+
get_hessian_name,
14+
)
1115
from deepmd.dpmodel.atomic_model.base_atomic_model import (
1216
BaseAtomicModel,
1317
)
1418
from deepmd.dpmodel.model.make_model import make_model as make_model_dp
19+
from deepmd.dpmodel.output_def import (
20+
OutputVariableDef,
21+
)
1522
from deepmd.pt_expt.common import (
1623
torch_module,
1724
)
@@ -21,6 +28,136 @@
2128
)
2229

2330

31+
def _cal_hessian_ext(
    model: Any,
    kk: str,
    vdef: OutputVariableDef,
    extended_coord: torch.Tensor,
    extended_atype: torch.Tensor,
    nlist: torch.Tensor,
    mapping: torch.Tensor | None,
    fparam: torch.Tensor | None,
    aparam: torch.Tensor | None,
    create_graph: bool = False,
) -> torch.Tensor:
    """Compute the hessian of the reduced output w.r.t. extended coordinates.

    Mirrors the JAX approach: the hessian is evaluated on extended
    coordinates, and communicate_extended_output later maps nall -> nloc.

    Parameters
    ----------
    model
        The model (CM instance). Must have ``atomic_model.forward_common_atomic``.
    kk
        The output key (e.g. "energy").
    vdef
        The output variable definition.
    extended_coord
        Extended coordinates. Shape: [nf, nall, 3].
    extended_atype
        Extended atom types. Shape: [nf, nall].
    nlist
        Neighbor list. Shape: [nf, nloc, nsel].
    mapping
        Mapping from extended to local. Shape: [nf, nall] or None.
    fparam
        Frame parameters. Shape: [nf, nfp] or None.
    aparam
        Atomic parameters. Shape: [nf, nloc, nap] or None.
    create_graph
        Whether to create graph for higher-order derivatives.

    Returns
    -------
    torch.Tensor
        Hessian on extended coordinates. Shape: [nf, *def, nall, 3, nall, 3].
    """
    nf, nall, _ = extended_coord.shape
    ncomp = math.prod(vdef.shape)
    flat_coord = extended_coord.reshape(nf, nall * 3)

    def _frame_slice(t: torch.Tensor | None, ii: int) -> torch.Tensor | None:
        # Select frame ii from an optional batched tensor.
        return t[ii] if t is not None else None

    blocks = []
    for frame in range(nf):
        for comp in range(ncomp):
            scalar_fn = _WrapperForwardEnergy(
                model,
                kk,
                comp,
                nall,
                extended_atype[frame],
                nlist[frame],
                _frame_slice(mapping, frame),
                _frame_slice(fparam, frame),
                _frame_slice(aparam, frame),
            )
            blocks.append(
                torch.autograd.functional.hessian(
                    scalar_fn,
                    flat_coord[frame],
                    create_graph=create_graph,
                )
            )
    # blocks is frame-major, component-minor, matching the reshape:
    # [nf * ncomp, nall*3, nall*3] -> [nf, *vdef.shape, nall, 3, nall, 3]
    return torch.stack(blocks).reshape(nf, *vdef.shape, nall, 3, nall, 3)
102+
103+
104+
class _WrapperForwardEnergy:
105+
"""Callable wrapper for torch.autograd.functional.hessian.
106+
107+
Given flattened extended coordinates, recomputes the reduced energy
108+
for one frame and one output component.
109+
"""
110+
111+
def __init__(
112+
self,
113+
model: Any,
114+
kk: str,
115+
ci: int,
116+
nall: int,
117+
atype: torch.Tensor,
118+
nlist: torch.Tensor,
119+
mapping: torch.Tensor | None,
120+
fparam: torch.Tensor | None,
121+
aparam: torch.Tensor | None,
122+
) -> None:
123+
self.model = model
124+
self.kk = kk
125+
self.ci = ci
126+
self.nall = nall
127+
self.atype = atype
128+
self.nlist = nlist
129+
self.mapping = mapping
130+
self.fparam = fparam
131+
self.aparam = aparam
132+
133+
def __call__(self, coord_flat: torch.Tensor) -> torch.Tensor:
134+
"""Compute scalar reduced energy for one frame, one component.
135+
136+
Parameters
137+
----------
138+
coord_flat
139+
Flattened extended coordinates for one frame. Shape: [nall * 3].
140+
141+
Returns
142+
-------
143+
torch.Tensor
144+
Scalar energy component.
145+
"""
146+
cc_3d = coord_flat.reshape(1, self.nall, 3)
147+
atomic_ret = self.model.atomic_model.forward_common_atomic(
148+
cc_3d,
149+
self.atype.unsqueeze(0),
150+
self.nlist.unsqueeze(0),
151+
mapping=self.mapping.unsqueeze(0) if self.mapping is not None else None,
152+
fparam=self.fparam.unsqueeze(0) if self.fparam is not None else None,
153+
aparam=self.aparam.unsqueeze(0) if self.aparam is not None else None,
154+
)
155+
# atomic_ret[kk]: [1, nloc, *def]
156+
atom_energy = atomic_ret[self.kk][0] # [nloc, *def]
157+
energy_redu = atom_energy.sum(dim=0).reshape(-1)[self.ci]
158+
return energy_redu
159+
160+
24161
def make_model(
25162
T_AtomicModel: type[BaseAtomicModel],
26163
T_Bases: tuple[type, ...] = (),
@@ -84,14 +221,35 @@ def forward_common_atomic(
84221
fparam=fparam,
85222
aparam=aparam,
86223
)
87-
return fit_output_to_model_output(
224+
model_ret = fit_output_to_model_output(
88225
atomic_ret,
89226
self.atomic_output_def(),
90227
extended_coord,
91228
do_atomic_virial=do_atomic_virial,
92229
create_graph=self.training,
93230
mask=atomic_ret.get("mask"),
94231
)
232+
# Hessian computation (mirrors JAX's forward_common_atomic).
233+
# Produces hessian on extended coords [nf, *def, nall, 3, nall, 3],
234+
# then communicate_extended_output maps it to nloc x nloc.
235+
aod = self.atomic_output_def()
236+
for kk in aod.keys():
237+
vdef = aod[kk]
238+
if vdef.reducible and vdef.r_hessian:
239+
kk_hess = get_hessian_name(kk)
240+
model_ret[kk_hess] = _cal_hessian_ext(
241+
self,
242+
kk,
243+
vdef,
244+
extended_coord,
245+
extended_atype,
246+
nlist,
247+
mapping,
248+
fparam,
249+
aparam,
250+
create_graph=self.training,
251+
)
252+
return model_ret
95253

96254
def forward_common_lower_exportable(
97255
self,

0 commit comments

Comments
 (0)