Skip to content

Commit 4ad67d5

Browse files
iProzd, njzjz, pre-commit-ci[bot]
authored
feat(pt): add observed-type option for dp show (#4820)
<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

- **New Features**
  - Added support for displaying "observed-type" information in the model reporting tool, showing element types observed during training for single-task and multi-task models.
- **Tests**
  - Updated tests to verify correct reporting of observed types for both single-task and multi-task models.
- **Documentation**
  - Updated documentation to include the new "observed-type" attribute in model information display, with examples and explanations.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com>
Co-authored-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1eefc8e commit 4ad67d5

9 files changed

Lines changed: 199 additions & 46 deletions

File tree

deepmd/entrypoints/show.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
from deepmd.infer.deep_eval import (
55
DeepEval,
66
)
7+
from deepmd.utils.econf_embd import (
8+
sort_element_type,
9+
)
710

811
log = logging.getLogger(__name__)
912

@@ -69,3 +72,34 @@ def show(
6972
log.info(f"Parameter counts{log_prefix}:")
7073
for k in sorted(size_dict):
7174
log.info(f"Parameters in {k}: {size_dict[k]:,}")
75+
76+
if "observed-type" in ATTRIBUTES:
77+
if model_is_multi_task:
78+
log.info("The observed types for each branch: ")
79+
total_observed_types_list = []
80+
model_branches = list(model_params["model_dict"].keys())
81+
for branch in model_branches:
82+
tmp_model = DeepEval(INPUT, head=branch, no_jit=True)
83+
observed_types = tmp_model.get_observed_types()
84+
log.info(
85+
f"{branch}: Number of observed types: {observed_types['type_num']} "
86+
)
87+
log.info(
88+
f"{branch}: Observed types: {observed_types['observed_type']} "
89+
)
90+
total_observed_types_list += [
91+
tt
92+
for tt in observed_types["observed_type"]
93+
if tt not in total_observed_types_list
94+
]
95+
log.info(
96+
f"TOTAL number of observed types in the model: {len(total_observed_types_list)} "
97+
)
98+
log.info(
99+
f"TOTAL observed types in the model: {sort_element_type(total_observed_types_list)} "
100+
)
101+
else:
102+
log.info("The observed types for this model: ")
103+
observed_types = model.get_observed_types()
104+
log.info(f"Number of observed types: {observed_types['type_num']} ")
105+
log.info(f"Observed types: {observed_types['observed_type']} ")

deepmd/infer/deep_eval.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,10 @@ def get_model_size(self) -> dict:
295295
"""Get model parameter count."""
296296
raise NotImplementedError("Not implemented in this backend.")
297297

298+
def get_observed_types(self) -> dict:
    """Return the types (elements) the model observed during data statistics.

    This implementation provides no such information and raises
    unconditionally.

    Raises
    ------
    NotImplementedError
        Always.
    """
    message = "Not implemented in this backend."
    raise NotImplementedError(message)
301+
298302

299303
class DeepEval(ABC):
300304
"""High-level Deep Evaluator interface.
@@ -568,3 +572,7 @@ def get_model_def_script(self) -> dict:
568572
def get_model_size(self) -> dict:
569573
"""Get model parameter count."""
570574
return self.deep_eval.get_model_size()
575+
576+
def get_observed_types(self) -> dict:
    """Return the types (elements) the model observed during data statistics.

    The query is forwarded unchanged to the wrapped evaluator held in
    ``self.deep_eval``.

    Returns
    -------
    dict
        Whatever ``self.deep_eval.get_observed_types()`` returns.
    """
    backend = self.deep_eval
    return backend.get_observed_types()

deepmd/main.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -851,7 +851,14 @@ def main_parser() -> argparse.ArgumentParser:
851851
)
852852
parser_show.add_argument(
853853
"ATTRIBUTES",
854-
choices=["model-branch", "type-map", "descriptor", "fitting-net", "size"],
854+
choices=[
855+
"model-branch",
856+
"type-map",
857+
"descriptor",
858+
"fitting-net",
859+
"size",
860+
"observed-type",
861+
],
855862
nargs="+",
856863
)
857864
return parser

deepmd/pt/infer/deep_eval.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@
6464
to_numpy_array,
6565
to_torch_tensor,
6666
)
67+
from deepmd.utils.econf_embd import (
68+
sort_element_type,
69+
)
6770

6871
if TYPE_CHECKING:
6972
import ase.neighborlist
@@ -98,6 +101,7 @@ def __init__(
98101
auto_batch_size: Union[bool, int, AutoBatchSize] = True,
99102
neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None,
100103
head: Optional[Union[str, int]] = None,
104+
no_jit: bool = False,
101105
**kwargs: Any,
102106
) -> None:
103107
self.output_def = output_def
@@ -130,7 +134,7 @@ def __init__(
130134
] = state_dict[item].clone()
131135
state_dict = state_dict_head
132136
model = get_model(self.input_param).to(DEVICE)
133-
if not self.input_param.get("hessian_mode"):
137+
if not self.input_param.get("hessian_mode") and not no_jit:
134138
model = torch.jit.script(model)
135139
self.dp = ModelWrapper(model)
136140
self.dp.load_state_dict(state_dict)
@@ -648,6 +652,22 @@ def get_model_size(self) -> dict:
648652
"total": sum_param_des + sum_param_fit,
649653
}
650654

655+
def get_observed_types(self) -> dict:
    """Report the types (elements) the model observed during data statistics.

    The list is read from the ``"Default"`` entry of ``self.dp.model`` and
    ordered with ``sort_element_type`` before being returned.

    Returns
    -------
    dict
        A dictionary with two entries:
        - 'type_num': the number of observed types in this model.
        - 'observed_type': the observed types, ordered by ``sort_element_type``.
    """
    default_model = self.dp.model["Default"]
    observed = default_model.get_observed_type_list()
    type_count = len(observed)
    ordered = sort_element_type(observed)
    return {"type_num": type_count, "observed_type": ordered}
670+
651671
def eval_descriptor(
652672
self,
653673
coords: np.ndarray,

deepmd/pt/model/model/ener_model.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,32 @@ def enable_hessian(self):
4444
self.requires_hessian("energy")
4545
self._hessian_enabled = True
4646

47+
@torch.jit.export
def get_observed_type_list(self) -> list[str]:
    """Get observed types (elements) of the model during data statistics.

    A type is reported as observed when at least one component of its
    output-bias row has an absolute value greater than 1e-6.

    Returns
    -------
    observed_type_list: a list of the observed types in this model.
        Entries are taken from the type map and keep the type-map order.
    """
    type_map = self.get_type_map()
    # Only the first entry of the atomic model's output bias is inspected.
    out_bias = self.atomic_model.get_out_bias()[0]

    # Sanity checks: the bias must be 2D with one row per type-map entry.
    # NOTE(review): the ``is not None`` check runs after the ``[0]`` indexing
    # above, so a missing bias would presumably fail earlier -- confirm.
    assert out_bias is not None, "No out_bias found in the model."
    assert out_bias.dim() == 2, "The supported out_bias should be a 2D tensor."
    assert out_bias.size(0) == len(type_map), (
        "The out_bias shape does not match the type_map length."
    )
    # Per-type boolean mask: True when any bias component exceeds the
    # threshold in magnitude.
    # NOTE(review): presumably types never seen in data statistics keep a
    # (near-)zero bias -- confirm against the statistics code.
    bias_mask = (
        torch.gt(torch.abs(out_bias), 1e-6).any(dim=-1).detach().cpu()
    )  # 1e-6 for stability

    # Explicit index loop; this method is exported to TorchScript.
    observed_type_list: list[str] = []
    for i in range(len(type_map)):
        if bias_mask[i]:
            observed_type_list.append(type_map[i])
    return observed_type_list
72+
4773
def translated_output_def(self):
4874
out_def_data = self.model_output_def().get_data()
4975
output_def = {

deepmd/pt/model/model/model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,16 @@ def compute_or_load_stat(
4848
"""
4949
raise NotImplementedError
5050

51+
@torch.jit.export
def get_observed_type_list(self) -> list[str]:
    """Get observed types (elements) of the model during data statistics.

    This implementation is a placeholder and raises unconditionally.

    Returns
    -------
    observed_type_list: a list of the observed types in this model.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError
60+
5161
@torch.jit.export
5262
def get_model_def_script(self) -> str:
5363
"""Get the model definition script."""

deepmd/utils/econf_embd.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"electronic_configuration_embedding",
1010
"make_econf_embedding",
1111
"normalized_electronic_configuration_embedding",
12+
"sort_element_type",
1213
"transform_to_spin_rep",
1314
]
1415

@@ -263,3 +264,16 @@ def print_econf_embedding(res: dict[str, np.ndarray]) -> None:
263264
vvstr = ",".join([str(ii) for ii in vv])
264265
space = " " * (2 - len(kk))
265266
print(f'"{kk}"{space} : [{vvstr}],') # noqa: T201
267+
268+
269+
def sort_element_type(elements: list[str]) -> list[str]:
    """Return the given element symbols ordered by atomic number.

    Symbols for which the ``element`` lookup raises ``ValueError`` are given
    an infinite sort key, so they are placed after every recognised element.
    The input list is not modified; a new list is returned.
    """

    def _atomic_number_or_inf(symbol):
        # Fall back to +inf so symbols that fail the lookup sort last.
        try:
            number = element(symbol).atomic_number
        except ValueError:
            number = float("inf")
        return number

    return sorted(elements, key=_atomic_number_or_inf)

doc/model/show-model-info.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ dp --pt show <INPUT> <ATTRIBUTES...>
1717
- `descriptor`: Displays the model descriptor parameters.
1818
- `fitting-net`: Displays parameters of the fitting network.
1919
- `size`: (Supported Backends: PyTorch and PaddlePaddle) Shows the parameter counts for various components.
20+
- `observed-type`: (Supported Backends: PyTorch) Shows the element types that the model observed during data statistics. Currently, only energy models are supported.
2021

2122
## Example Usage
2223

@@ -60,6 +61,12 @@ Depending on the provided attributes and the model type, the output includes:
6061

6162
- Prints the number of parameters for each component (`descriptor`, `fitting-net`, etc.), as well as the total parameter count.
6263

64+
- **observed-type**
65+
66+
- Displays the count and list of observed element types of the model during data statistics.
67+
- For multitask models, it shows the observed types for each branch.
68+
- Note: These are the types actually observed during training data statistics; they are drawn from the type map but may cover only a subset of it.
69+
6370
## Example Output
6471

6572
For a singletask model, the output might look like:
@@ -73,6 +80,9 @@ Parameter counts:
7380
Parameters in descriptor: 19,350
7481
Parameters in fitting-net: 119,091
7582
Parameters in total: 138,441
83+
The observed types for this model:
84+
Number of observed types: 2
85+
Observed types: ['H', 'O']
7686
```
7787

7888
For a multitask model, if `model-branch` is selected, it will additionally display available branches:

0 commit comments

Comments
 (0)