@@ -33,6 +33,37 @@ def __init__(
3333 ) -> None :
3434 DPModelCommon .__init__ (self )
3535 DPEnergyModel_ .__init__ (self , * args , ** kwargs )
36+ self ._hessian_enabled = False
37+
38+ def enable_hessian (self ):
39+ raise NotImplementedError (
40+ "Hessian calculation is not implemented yet on PaddlePaddle platform."
41+ )
42+
43+ def get_observed_type_list (self ) -> list [str ]:
44+ """Get observed types (elements) of the model during data statistics.
45+
46+ Returns
47+ -------
48+ observed_type_list: a list of the observed types in this model.
49+ """
50+ type_map = self .get_type_map ()
51+ out_bias = self .atomic_model .get_out_bias ()[0 ]
52+
53+ assert out_bias is not None , "No out_bias found in the model."
54+ assert out_bias .dim () == 2 , "The supported out_bias should be a 2D tensor."
55+ assert out_bias .size (0 ) == len (type_map ), (
56+ "The out_bias shape does not match the type_map length."
57+ )
58+ bias_mask = (
59+ paddle .greater_than (paddle .abs (out_bias ), 1e-6 ).any (axis = - 1 ).detach ().cpu ()
60+ ) # 1e-6 for stability
61+
62+ observed_type_list : list [str ] = []
63+ for i in range (len (type_map )):
64+ if bias_mask [i ]:
65+ observed_type_list .append (type_map [i ])
66+ return observed_type_list
3667
3768 def translated_output_def (self ):
3869 out_def_data = self .model_output_def ().get_data ()
@@ -50,6 +81,8 @@ def translated_output_def(self):
5081 output_def ["atom_virial" ].squeeze (- 3 )
5182 if "mask" in out_def_data :
5283 output_def ["mask" ] = out_def_data ["mask" ]
84+ if self ._hessian_enabled :
85+ output_def ["hessian" ] = out_def_data ["energy_derv_r_derv_r" ]
5386 return output_def
5487
5588 def forward (
@@ -81,14 +114,12 @@ def forward(
81114 model_predict ["atom_virial" ] = model_ret ["energy_derv_c" ].squeeze (
82115 - 3
83116 )
84- else :
85- model_predict ["atom_virial" ] = paddle .zeros (
86- [model_predict ["energy" ].shape [0 ], 1 , 9 ], dtype = paddle .float64
87- )
88117 else :
89118 model_predict ["force" ] = model_ret ["dforce" ]
90119 if "mask" in model_ret :
91120 model_predict ["mask" ] = model_ret ["mask" ]
121+ if self ._hessian_enabled :
122+ model_predict ["hessian" ] = model_ret ["energy_derv_r_derv_r" ].squeeze (- 2 )
92123 else :
93124 model_predict = model_ret
94125 model_predict ["updated_coord" ] += coord
@@ -128,10 +159,6 @@ def forward_lower(
128159 model_predict ["extended_virial" ] = model_ret [
129160 "energy_derv_c"
130161 ].squeeze (- 3 )
131- else :
132- model_predict ["extended_virial" ] = paddle .zeros (
133- [model_predict ["energy" ].shape [0 ], 1 , 9 ], dtype = paddle .float64
134- )
135162 else :
136163 assert model_ret ["dforce" ] is not None
137164 model_predict ["dforce" ] = model_ret ["dforce" ]
0 commit comments