Skip to content

Commit e542798

Browse files
committed
Optimize observed type retrieval in show.py
Improve efficiency by directly using observed_type from model parameters if available, avoiding unnecessary DeepEval instantiation. Update test to filter out table lines in output parsing.
1 parent d1ba115 commit e542798

2 files changed

Lines changed: 23 additions & 4 deletions

File tree

deepmd/entrypoints/show.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,22 @@ def show(
8888
total_observed_types_list = []
8989
model_branches = list(model_params["model_dict"].keys())
9090
for branch in model_branches:
91-
tmp_model = DeepEval(INPUT, head=branch, no_jit=True)
92-
observed_types = tmp_model.get_observed_types()
91+
if (
92+
model_params["model_dict"][branch]
93+
.get("info", {})
94+
.get("observed_type", None)
95+
is not None
96+
):
97+
observed_type_list = model_params["model_dict"][branch]["info"][
98+
"observed_type"
99+
]
100+
observed_types = {
101+
"type_num": len(observed_type_list),
102+
"observed_type": sort_element_type(observed_type_list),
103+
}
104+
else:
105+
tmp_model = DeepEval(INPUT, head=branch, no_jit=True)
106+
observed_types = tmp_model.get_observed_types()
93107
log.info(
94108
f"{branch}: Number of observed types: {observed_types['type_num']} "
95109
)

source/tests/pt/test_dp_show.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,13 @@ def test_checkpoint(self) -> None:
166166
with redirect_stderr(io.StringIO()) as f:
167167
run_dp(f"dp --pt show {INPUT} {ATTRIBUTES}")
168168
results = [
169-
res for res in f.getvalue().split("\n")[:-1] if "DEEPMD WARNING" not in res
170-
] # filter out warnings
169+
res
170+
for res in f.getvalue().split("\n")[:-1]
171+
if "DEEPMD WARNING" not in res
172+
and "|" not in res
173+
and "+-" not in res
174+
and "Detailed information" not in res
175+
] # filter out warnings and tables
171176
assert "This is a multitask model" in results[0]
172177
assert (
173178
"Available model branches are ['model_1', 'model_2', 'RANDOM'], "

0 commit comments

Comments
 (0)