Skip to content

Commit 0ad36f2

Browse files
update code
1 parent e79782d commit 0ad36f2

2 files changed

Lines changed: 56 additions & 11 deletions

File tree

deepmd/pd/model/model/ener_model.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,37 @@ def __init__(
3333
) -> None:
3434
DPModelCommon.__init__(self)
3535
DPEnergyModel_.__init__(self, *args, **kwargs)
36+
self._hessian_enabled = False
37+
38+
def enable_hessian(self):
39+
raise NotImplementedError(
40+
"Hessian calculation is not implemented yet on PaddlePaddle platform."
41+
)
42+
43+
def get_observed_type_list(self) -> list[str]:
44+
"""Get observed types (elements) of the model during data statistics.
45+
46+
Returns
47+
-------
48+
observed_type_list: a list of the observed types in this model.
49+
"""
50+
type_map = self.get_type_map()
51+
out_bias = self.atomic_model.get_out_bias()[0]
52+
53+
assert out_bias is not None, "No out_bias found in the model."
54+
assert out_bias.dim() == 2, "The supported out_bias should be a 2D tensor."
55+
assert out_bias.size(0) == len(type_map), (
56+
"The out_bias shape does not match the type_map length."
57+
)
58+
bias_mask = (
59+
paddle.greater_than(paddle.abs(out_bias), 1e-6).any(axis=-1).detach().cpu()
60+
) # 1e-6 for stability
61+
62+
observed_type_list: list[str] = []
63+
for i in range(len(type_map)):
64+
if bias_mask[i]:
65+
observed_type_list.append(type_map[i])
66+
return observed_type_list
3667

3768
def translated_output_def(self):
3869
out_def_data = self.model_output_def().get_data()
@@ -50,6 +81,8 @@ def translated_output_def(self):
5081
output_def["atom_virial"].squeeze(-3)
5182
if "mask" in out_def_data:
5283
output_def["mask"] = out_def_data["mask"]
84+
if self._hessian_enabled:
85+
output_def["hessian"] = out_def_data["energy_derv_r_derv_r"]
5386
return output_def
5487

5588
def forward(
@@ -81,14 +114,12 @@ def forward(
81114
model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(
82115
-3
83116
)
84-
else:
85-
model_predict["atom_virial"] = paddle.zeros(
86-
[model_predict["energy"].shape[0], 1, 9], dtype=paddle.float64
87-
)
88117
else:
89118
model_predict["force"] = model_ret["dforce"]
90119
if "mask" in model_ret:
91120
model_predict["mask"] = model_ret["mask"]
121+
if self._hessian_enabled:
122+
model_predict["hessian"] = model_ret["energy_derv_r_derv_r"].squeeze(-2)
92123
else:
93124
model_predict = model_ret
94125
model_predict["updated_coord"] += coord
@@ -128,10 +159,6 @@ def forward_lower(
128159
model_predict["extended_virial"] = model_ret[
129160
"energy_derv_c"
130161
].squeeze(-3)
131-
else:
132-
model_predict["extended_virial"] = paddle.zeros(
133-
[model_predict["energy"].shape[0], 1, 9], dtype=paddle.float64
134-
)
135162
else:
136163
assert model_ret["dforce"] is not None
137164
model_predict["dforce"] = model_ret["dforce"]

deepmd/pd/model/model/model.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
from typing import (
3+
NoReturn,
34
Optional,
45
)
56

@@ -8,6 +9,9 @@
89
from deepmd.dpmodel.model.base_model import (
910
make_base_model,
1011
)
12+
from deepmd.pd.utils import (
13+
env,
14+
)
1115
from deepmd.utils.path import (
1216
DPPath,
1317
)
@@ -18,13 +22,16 @@ def __init__(self, *args, **kwargs):
1822
"""Construct a basic model for different tasks."""
1923
paddle.nn.Layer.__init__(self)
2024
self.model_def_script = ""
21-
self.min_nbor_dist = None
25+
self.register_buffer(
26+
"min_nbor_dist",
27+
paddle.to_tensor(-1.0, dtype=paddle.float64, place=env.DEVICE),
28+
)
2229

2330
def compute_or_load_stat(
2431
self,
2532
sampled_func,
2633
stat_file_path: Optional[DPPath] = None,
27-
):
34+
) -> NoReturn:
2835
"""
2936
Compute or load the statistics parameters of the model,
3037
such as mean and standard deviation of descriptors or the energy bias of the fitting net.
@@ -42,13 +49,24 @@ def compute_or_load_stat(
4249
"""
4350
raise NotImplementedError
4451

52+
def get_observed_type_list(self) -> list[str]:
53+
"""Get observed types (elements) of the model during data statistics.
54+
55+
Returns
56+
-------
57+
observed_type_list: a list of the observed types in this model.
58+
"""
59+
raise NotImplementedError
60+
4561
def get_model_def_script(self) -> str:
4662
"""Get the model definition script."""
4763
return self.model_def_script
4864

4965
def get_min_nbor_dist(self) -> Optional[float]:
5066
"""Get the minimum distance between two atoms."""
51-
return self.min_nbor_dist
67+
if self.min_nbor_dist.item() == -1.0:
68+
return None
69+
return self.min_nbor_dist.item()
5270

5371
def get_ntypes(self):
5472
"""Returns the number of element types."""

0 commit comments

Comments
 (0)