11# SPDX-License-Identifier: LGPL-3.0-or-later
2+ import copy
23from typing import (
34 Any ,
45)
1415from deepmd .dpmodel .model .dp_model import (
1516 DPModelCommon ,
1617)
18+ from deepmd .dpmodel .model .make_hessian_model import (
19+ make_hessian_model ,
20+ )
1721
1822from .make_model import (
1923 make_model ,
@@ -34,6 +38,17 @@ def __init__(
3438 ) -> None :
3539 DPModelCommon .__init__ (self )
3640 DPEnergyModel_ .__init__ (self , * args , ** kwargs )
41+ self ._hessian_enabled = False
42+
43+ def enable_hessian (self ) -> None :
44+ if self ._hessian_enabled :
45+ return
46+ self .__class__ = make_hessian_model (type (self ))
47+ self .hess_fitting_def = copy .deepcopy (
48+ super (type (self ), self ).atomic_output_def ()
49+ )
50+ self .requires_hessian ("energy" )
51+ self ._hessian_enabled = True
3752
3853 def forward (
3954 self ,
@@ -63,6 +78,8 @@ def forward(
6378 model_predict ["atom_virial" ] = model_ret ["energy_derv_c" ].squeeze (- 2 )
6479 if "mask" in model_ret :
6580 model_predict ["mask" ] = model_ret ["mask" ]
81+ if self .atomic_output_def ()["energy" ].r_hessian :
82+ model_predict ["hessian" ] = model_ret ["energy_derv_r_derv_r" ].squeeze (- 3 )
6683 return model_predict
6784
6885 def forward_lower (
@@ -115,6 +132,8 @@ def translated_output_def(self) -> dict[str, Any]:
115132 output_def ["atom_virial" ].squeeze (- 2 )
116133 if "mask" in out_def_data :
117134 output_def ["mask" ] = out_def_data ["mask" ]
135+ if self .atomic_output_def ()["energy" ].r_hessian :
136+ output_def ["hessian" ] = out_def_data ["energy_derv_r_derv_r" ]
118137 return output_def
119138
120139 def forward_lower_exportable (
0 commit comments