Skip to content

Commit d1ba115

Browse files
committed
Add model branch alias and info support
Introduces model branch alias and info fields to model configuration, adds utility functions for handling model branch dictionaries, and updates related modules to use alias-based lookup and provide detailed branch information. Enhances multi-task model usability and improves logging of available model branches.
1 parent 88b71e8 commit d1ba115

5 files changed

Lines changed: 349 additions & 5 deletions

File tree

deepmd/entrypoints/show.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
from deepmd.utils.econf_embd import (
88
sort_element_type,
99
)
10+
from deepmd.utils.model_branch_dict import (
11+
OrderedDictTableWrapper,
12+
get_model_dict,
13+
)
1014

1115
log = logging.getLogger(__name__)
1216

@@ -33,10 +37,15 @@ def show(
3337
)
3438
model_branches = list(model_params["model_dict"].keys())
3539
model_branches += ["RANDOM"]
40+
model_alias_dict, model_branch_dict = get_model_dict(model_params["model_dict"])
3641
log.info(
3742
f"Available model branches are {model_branches}, "
3843
f"where 'RANDOM' means using a randomly initialized fitting net."
3944
)
45+
log.info(
46+
"Detailed information: \n"
47+
+ OrderedDictTableWrapper(model_branch_dict).as_table()
48+
)
4049
if "type-map" in ATTRIBUTES:
4150
if model_is_multi_task:
4251
model_branches = list(model_params["model_dict"].keys())

deepmd/pt/infer/deep_eval.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@
6767
from deepmd.utils.econf_embd import (
6868
sort_element_type,
6969
)
70+
from deepmd.utils.model_branch_dict import (
71+
get_model_dict,
72+
)
7073

7174
if TYPE_CHECKING:
7275
import ase.neighborlist
@@ -116,15 +119,25 @@ def __init__(
116119
self.model_def_script = self.input_param
117120
self.multi_task = "model_dict" in self.input_param
118121
if self.multi_task:
122+
model_alias_dict, model_branch_dict = get_model_dict(
123+
self.input_param["model_dict"]
124+
)
119125
model_keys = list(self.input_param["model_dict"].keys())
126+
if head is None and "Default" in model_alias_dict:
127+
head = "Default"
120128
if isinstance(head, int):
121129
head = model_keys[0]
122130
assert head is not None, (
123-
f"Head must be set for multitask model! Available heads are: {model_keys}"
131+
f"Head must be set for multitask model! Available heads are: {model_keys}, "
132+
f"use `dp --pt show your_model.pt model-branch` to show detail information."
124133
)
125-
assert head in model_keys, (
126-
f"No head named {head} in model! Available heads are: {model_keys}"
134+
# replace with alias
135+
assert head in model_alias_dict, (
136+
f"No head or alias named {head} in model! Available heads are: {model_keys},"
137+
f"use `dp --pt show your_model.pt model-branch` to show detail information."
127138
)
139+
head = model_alias_dict[head]
140+
128141
self.input_param = self.input_param["model_dict"][head]
129142
state_dict_head = {"_extra_state": state_dict["_extra_state"]}
130143
for item in state_dict:
@@ -253,6 +266,17 @@ def get_has_hessian(self):
253266
"""Check if the model has hessian."""
254267
return self._has_hessian
255268

269+
def get_model_branch(self):
270+
"""Get the model branch information."""
271+
if "model_dict" in self.model_def_script:
272+
model_alias_dict, model_branch_dict = get_model_dict(
273+
self.model_def_script["model_dict"]
274+
)
275+
return model_alias_dict, model_branch_dict
276+
else:
277+
# single-task model
278+
return {"Default": "Default"}, {"Default": {"alias": [], "info": {}}}
279+
256280
def eval(
257281
self,
258282
coords: np.ndarray,

deepmd/pt/utils/finetune.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
from deepmd.utils.finetune import (
1313
FinetuneRuleItem,
1414
)
15+
from deepmd.utils.model_branch_dict import (
16+
get_model_dict,
17+
)
1518

1619
log = logging.getLogger(__name__)
1720

@@ -44,10 +47,13 @@ def get_finetune_rule_single(
4447
)
4548
else:
4649
model_branch_chosen = model_branch_from
47-
assert model_branch_chosen in model_dict_params, (
48-
f"No model branch named '{model_branch_chosen}'! "
50+
model_alias_dict, model_branch_dict = get_model_dict(model_dict_params)
51+
assert model_branch_chosen in model_alias_dict, (
52+
f"No model branch or alias named '{model_branch_chosen}'! "
4953
f"Available ones are {list(model_dict_params.keys())}."
54+
f"Use `dp --pt show your_model.pt model-branch` to show detail information."
5055
)
56+
model_branch_chosen = model_alias_dict[model_branch_chosen]
5157
single_config_chosen = deepcopy(model_dict_params[model_branch_chosen])
5258
old_type_map, new_type_map = (
5359
single_config_chosen["type_map"],

deepmd/utils/argcheck.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2305,6 +2305,10 @@ def model_args(exclude_hybrid=False):
23052305
def standard_model_args() -> Argument:
23062306
doc_descrpt = "The descriptor of atomic environment."
23072307
doc_fitting = "The fitting of physical properties."
2308+
doc_model_branch_alias = (
2309+
"The alias of this model branch. This is only used in the multi-task model."
2310+
)
2311+
doc_info = "The information of this model branch. This is only used in the multi-task model."
23082312

23092313
ca = Argument(
23102314
"standard",
@@ -2320,6 +2324,20 @@ def standard_model_args() -> Argument:
23202324
[fitting_variant_type_args()],
23212325
doc=doc_fitting,
23222326
),
2327+
Argument(
2328+
"model_branch_alias",
2329+
list,
2330+
optional=True,
2331+
default=[],
2332+
doc=doc_only_pt_supported + doc_model_branch_alias,
2333+
),
2334+
Argument(
2335+
"info",
2336+
dict,
2337+
optional=True,
2338+
default={},
2339+
doc=doc_only_pt_supported + doc_info,
2340+
),
23232341
],
23242342
doc="Standard model, which contains a descriptor and a fitting.",
23252343
)

0 commit comments

Comments
 (0)