Skip to content

Commit 6077e1d

Browse files
committed
only support energy models
1 parent a227b40 commit 6077e1d

4 files changed

Lines changed: 40 additions & 17 deletions

File tree

deepmd/pt/infer/deep_eval.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -662,23 +662,10 @@ def get_observed_types(self) -> dict:
662662
- 'type_num': the total number of observed types in this model.
663663
- 'observed_type': a list of the observed types in this model.
664664
"""
665-
buffers_dict = dict(self.dp.named_buffers())
666-
type_map = np.array(self.type_map)
667-
out_bias = None
668-
for k in buffers_dict:
669-
if ".out_bias" in k:
670-
# only use out_bias in the first fitting out_def
671-
out_bias = buffers_dict[k].detach().cpu().numpy()[0]
672-
break
673-
assert out_bias is not None, "No out_bias found in the model buffers."
674-
assert len(out_bias.shape) == 2, "The supported out_bias should be a 2D array."
675-
assert out_bias.shape[0] == len(type_map), (
676-
"The out_bias shape does not match the type map length."
677-
)
678-
bias_mask = (np.abs(out_bias) > 1e-6).any(-1) # 1e-6 for stability
665+
observed_type_list = self.dp.model["Default"].get_observed_type_list()
679666
return {
680-
"type_num": bias_mask.sum(),
681-
"observed_type": sort_element_type(type_map[bias_mask].tolist()),
667+
"type_num": len(observed_type_list),
668+
"observed_type": sort_element_type(observed_type_list),
682669
}
683670

684671
def eval_descriptor(

deepmd/pt/model/model/ener_model.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,32 @@ def enable_hessian(self):
4444
self.requires_hessian("energy")
4545
self._hessian_enabled = True
4646

47+
@torch.jit.export
48+
def get_observed_type_list(self) -> list[str]:
49+
"""Get observed types (elements) of the model during data statistics.
50+
51+
Returns
52+
-------
53+
observed_type_list: a list of the observed types in this model.
54+
"""
55+
type_map = self.get_type_map()
56+
out_bias = self.atomic_model.get_out_bias()[0]
57+
58+
assert out_bias is not None, "No out_bias found in the model."
59+
assert out_bias.dim() == 2, "The supported out_bias should be a 2D tensor."
60+
assert out_bias.size(0) == len(type_map), (
61+
"The out_bias shape does not match the type_map length."
62+
)
63+
bias_mask = torch.gt(torch.abs(out_bias), 1e-6).any(
64+
dim=-1
65+
) # 1e-6 for stability
66+
67+
observed_type_list: list[str] = []
68+
for i in range(len(type_map)):
69+
if bias_mask[i]:
70+
observed_type_list.append(type_map[i])
71+
return observed_type_list
72+
4773
def translated_output_def(self):
4874
out_def_data = self.model_output_def().get_data()
4975
output_def = {

deepmd/pt/model/model/model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,16 @@ def compute_or_load_stat(
4848
"""
4949
raise NotImplementedError
5050

51+
@torch.jit.export
52+
def get_observed_type_list(self) -> list[str]:
53+
"""Get observed types (elements) of the model during data statistics.
54+
55+
Returns
56+
-------
57+
observed_type_list: a list of the observed types in this model.
58+
"""
59+
raise NotImplementedError
60+
5161
@torch.jit.export
5262
def get_model_def_script(self) -> str:
5363
"""Get the model definition script."""

doc/model/show-model-info.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ dp --pt show <INPUT> <ATTRIBUTES...>
1717
- `descriptor`: Displays the model descriptor parameters.
1818
- `fitting-net`: Displays parameters of the fitting network.
1919
- `size`: (Supported Backends: PyTorch and PaddlePaddle) Shows the parameter counts for various components.
20-
- `observed-type`: (Supported Backends: PyTorch) Shows the observed types (elements) of the model during data statistics.
20+
- `observed-type`: (Supported Backends: PyTorch) Shows the observed types (elements) of the model during data statistics. Only energy models are supported now.
2121

2222
## Example Usage
2323

0 commit comments

Comments
 (0)