Commit 58ce1aa

iProzd, njzjz, and wanghan-iapcm authored
feat(pt): add model branch alias (#4883)
Introduces model branch alias and info fields to model configuration, adds utility functions for handling model branch dictionaries, and updates related modules to use alias-based lookup and provide detailed branch information. Enhances multi-task model usability and improves logging of available model branches.

example:
```
dp --pt show 0415_compat_new.pt model-branch
[2025-08-14 10:05:54,246] DEEPMD WARNING To get the best performance, it is recommended to adjust the number of threads by setting the environment variables OMP_NUM_THREADS, DP_INTRA_OP_PARALLELISM_THREADS, and DP_INTER_OP_PARALLELISM_THREADS. See https://deepmd.rtfd.io/parallelism/ for more information.
[2025-08-14 10:05:59,122] DEEPMD INFO This is a multitask model
[2025-08-14 10:05:59,122] DEEPMD INFO Available model branches are ['Dai2023Alloy', 'Zhang2023Cathode', 'Gong2023Cluster', 'Yang2023ab', 'UniPero', 'Huang2021Deep-PBE', 'Liu2024Machine', 'Zhang2021Phase', 'Jinag2021Accurate', 'Chen2023Modeling', 'Wen2021Specialising', 'Wang2022Classical', 'Wang2022Tungsten', 'Wu2021Deep', 'Huang2021Deep-PBEsol', 'Transition1x', 'Wang2021Generalizable', 'Wu2021Accurate', 'MPTraj', 'Li2025APEX', 'Shi2024SSE', 'Tuo2023Hybrid', 'Unke2019PhysNet', 'Shi2024Electrolyte', 'ODAC23', 'Alex2D', 'OMAT24', 'SPICE2', 'OC20M', 'OC22', 'Li2025General', 'RANDOM'], where 'RANDOM' means using a randomly initialized fitting net.
[2025-08-14 10:05:59,125] DEEPMD INFO Detailed information:
+-----------------------+------------------------------+--------------------------------+--------------------------------+
| Model Branch          | Alias                        | description                    | observed_type                  |
+-----------------------+------------------------------+--------------------------------+--------------------------------+
| Dai2023Alloy          | Alloys, Domains_Alloy        | The dataset contains           | ['La', 'Fe', 'Ho', 'Cu', 'Sn', |
|                       |                              | structure-energy-force-virial  | 'Cd', 'Y', 'Be', 'V', 'Sm',    |
|                       |                              | data for 53 typical metallic   | 'In', 'Pr', 'Mo', 'Mn', 'Gd',  |
|                       |                              | elements in alloy systems,     | 'Ru', 'Nd', 'Li', 'Tm', 'K',   |
|                       |                              | including ~9000 intermetallic  | 'Pt', 'Ir', 'Na', 'Hf', 'Dy',  |
|                       |                              | compounds and FCC, BCC, HCP    | 'Ca', 'Nb', 'Au', 'Sr', 'Si',  |
|                       |                              | structures. It consists of two | 'Ge', 'Co', 'W', 'Cr', 'Zn',   |
|                       |                              | parts: DFT-generated relaxed   | 'Ag', 'Ti', 'Ni', 'Zr', 'Pd',  |
|                       |                              | and deformed structures, and   | 'Os', 'Ta', 'Rh', 'Sc', 'Tb',  |
|                       |                              | randomly distorted structures  | 'Al', 'Ga', 'Re', 'Lu', 'Er',  |
|                       |                              | produced covering pure metals, | 'Mg', 'Ce', 'Pb']              |
|                       |                              | solid solutions, and           |                                |
|                       |                              | intermetallics with vacancies. |                                |
+-----------------------+------------------------------+--------------------------------+--------------------------------+
| OMAT24                | Default, Materials, Omat24   | OMat24 is a large-scale open   | ['La', 'Fe', 'Cu', 'Cd', 'Be', |
|                       |                              | dataset containing over 110    | 'Ar', 'V', 'Sm', 'In', 'Pm',   |
|                       |                              | million DFT calculations       | 'Pr', 'Mn', 'Ru', 'He', 'Nd',  |
|                       |                              | spanning diverse structures    | 'Th', 'Pa', 'K', 'Pt', 'Yb',   |
|                       |                              | and compositions. It is        | 'Dy', 'Sr', 'Co', 'Np', 'Cr',  |
|                       |                              | designed to support AI-driven  | 'Tl', 'Br', 'Se', 'Ni', 'Zr',  |
|                       |                              | materials discovery by         | 'Pu', 'O', 'Xe', 'Tb', 'Ga',   |
|                       |                              | providing broad and deep       | 'Lu', 'H', 'Ne', 'Er', 'Ce',   |
|                       |                              | coverage of chemical space.    | 'I', 'Kr', 'Ho', 'Cs', 'Sn',   |
|                       |                              |                                | 'Rb', 'Y', 'N', 'F', 'Mo',     |
|                       |                              |                                | 'Gd', 'B', 'Li', 'Tm', 'Sb',   |
|                       |                              |                                | 'Ir', 'Hf', 'Na', 'Ca', 'Nb',  |
|                       |                              |                                | 'Au', 'As', 'Si', 'Ge', 'W',   |
|                       |                              |                                | 'Zn', 'Hg', 'Ag', 'Bi', 'Ti',  |
|                       |                              |                                | 'Os', 'Cl', 'Pd', 'P', 'U',    |
|                       |                              |                                | 'Tc', 'Ta', 'Ba', 'Rh', 'Sc',  |
|                       |                              |                                | 'C', 'S', 'Te', 'Al', 'Re',    |
|                       |                              |                                | 'Eu', 'Mg', 'Pb', 'Ac']        |
+-----------------------+------------------------------+--------------------------------+--------------------------------+
```

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

* **New Features**
  * Alias-based multi-task branch selection for evaluation and fine-tuning; new API to query model alias/branch info; show now prints a detailed model-branch table.
* **Documentation**
  * Model config gains optional fields to declare branch aliases and per-branch info (PyTorch-only).
* **Examples**
  * Added a two-task PyTorch example demonstrating aliases, shared components, and per-branch info.
* **Tests**
  * Tests include the new example and now filter out table-like show output.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com>
Co-authored-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com>
1 parent c796862 commit 58ce1aa
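
The diffs below call `get_model_dict` but do not include its implementation. As a rough illustration of the two mappings it appears to return, here is a minimal sketch. The function name `build_alias_maps` and the duplicate-alias check are assumptions made for this example; only the return shape (an alias-to-branch map plus a per-branch `{"alias": [...], "info": {...}}` record) is inferred from how the results are used in this commit.

```python
from collections import OrderedDict


def build_alias_maps(model_dict):
    """Hypothetical stand-in for the commit's get_model_dict helper.

    Returns (alias_dict, branch_dict):
      alias_dict  maps every branch name and every alias to its branch name
      branch_dict maps each branch name to {"alias": [...], "info": {...}}
    """
    alias_dict = {}
    branch_dict = OrderedDict()
    for branch, cfg in model_dict.items():
        alias_dict[branch] = branch  # a branch name always resolves to itself
        branch_dict[branch] = {
            "alias": list(cfg.get("model_branch_alias", [])),
            "info": dict(cfg.get("info", {})),
        }
    for branch, cfg in model_dict.items():
        for alias in cfg.get("model_branch_alias", []):
            # Assumption: an alias may not shadow a branch name or another alias.
            if alias in alias_dict:
                raise ValueError(f"Alias {alias!r} is already in use")
            alias_dict[alias] = branch
    return alias_dict, branch_dict


# Branch names and aliases taken from the example output in the commit message.
model_dict = {
    "Dai2023Alloy": {"model_branch_alias": ["Alloys", "Domains_Alloy"]},
    "OMAT24": {"model_branch_alias": ["Default", "Materials", "Omat24"]},
}
alias_dict, branch_dict = build_alias_maps(model_dict)
print(alias_dict["Materials"])
print(branch_dict["Dai2023Alloy"]["alias"])
```

With a map like this, any later lookup by branch name or alias reduces to a single dictionary access, which is the pattern the changed modules follow.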

8 files changed, +573 -9 lines changed

deepmd/entrypoints/show.py

Lines changed: 25 additions & 2 deletions
@@ -7,6 +7,10 @@
 from deepmd.utils.econf_embd import (
     sort_element_type,
 )
+from deepmd.utils.model_branch_dict import (
+    OrderedDictTableWrapper,
+    get_model_dict,
+)
 
 log = logging.getLogger(__name__)
 
@@ -33,10 +37,15 @@ def show(
             )
         model_branches = list(model_params["model_dict"].keys())
         model_branches += ["RANDOM"]
+        _, model_branch_dict = get_model_dict(model_params["model_dict"])
         log.info(
             f"Available model branches are {model_branches}, "
             f"where 'RANDOM' means using a randomly initialized fitting net."
         )
+        log.info(
+            "Detailed information: \n"
+            + OrderedDictTableWrapper(model_branch_dict).as_table()
+        )
     if "type-map" in ATTRIBUTES:
         if model_is_multi_task:
             model_branches = list(model_params["model_dict"].keys())
@@ -79,8 +88,22 @@ def show(
             total_observed_types_list = []
             model_branches = list(model_params["model_dict"].keys())
             for branch in model_branches:
-                tmp_model = DeepEval(INPUT, head=branch, no_jit=True)
-                observed_types = tmp_model.get_observed_types()
+                if (
+                    model_params["model_dict"][branch]
+                    .get("info", {})
+                    .get("observed_type", None)
+                    is not None
+                ):
+                    observed_type_list = model_params["model_dict"][branch]["info"][
+                        "observed_type"
+                    ]
+                    observed_types = {
+                        "type_num": len(observed_type_list),
+                        "observed_type": observed_type_list,
+                    }
+                else:
+                    tmp_model = DeepEval(INPUT, head=branch, no_jit=True)
+                    observed_types = tmp_model.get_observed_types()
                 log.info(
                     f"{branch}: Number of observed types: {observed_types['type_num']} "
                 )
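
The last hunk above reads observed types from the branch's `info` when they are recorded there and only builds a `DeepEval` instance otherwise. That pattern can be sketched in isolation as follows. The function `observed_types_for` and the `fake_loader` callback are hypothetical names used only for this illustration; the callback stands in for the `DeepEval(...).get_observed_types()` call in the diff.

```python
def observed_types_for(branch_cfg, load_from_model):
    """Prefer observed types recorded in the branch's `info`; otherwise load them."""
    recorded = branch_cfg.get("info", {}).get("observed_type")
    if recorded is not None:
        # Cheap path: no model needs to be instantiated.
        return {"type_num": len(recorded), "observed_type": recorded}
    # Fallback path: delegate to the (potentially slow) model-loading callback.
    return load_from_model()


calls = []


def fake_loader():
    # Hypothetical stand-in for loading the branch and querying the model.
    calls.append("loaded")
    return {"type_num": 2, "observed_type": ["H", "O"]}


with_info = {"info": {"observed_type": ["La", "Fe", "Cu"]}}
without_info = {}

print(observed_types_for(with_info, fake_loader))
print(observed_types_for(without_info, fake_loader))
print(calls)  # the loader ran only for the branch without recorded types
```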

deepmd/pt/infer/deep_eval.py

Lines changed: 41 additions & 3 deletions
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 import json
+import logging
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -67,10 +68,15 @@
 from deepmd.utils.econf_embd import (
     sort_element_type,
 )
+from deepmd.utils.model_branch_dict import (
+    get_model_dict,
+)
 
 if TYPE_CHECKING:
     import ase.neighborlist
 
+log = logging.getLogger(__name__)
+
 
 class DeepEval(DeepEvalBackend):
     """PyTorch backend implementation of DeepEval.
@@ -116,15 +122,36 @@ def __init__(
             self.model_def_script = self.input_param
             self.multi_task = "model_dict" in self.input_param
             if self.multi_task:
+                model_alias_dict, model_branch_dict = get_model_dict(
+                    self.input_param["model_dict"]
+                )
                 model_keys = list(self.input_param["model_dict"].keys())
+                if head is None and "Default" in model_alias_dict:
+                    head = "Default"
+                    log.info(
+                        f"Using default head {model_alias_dict[head]} for multitask model."
+                    )
                 if isinstance(head, int):
                     head = model_keys[0]
                 assert head is not None, (
-                    f"Head must be set for multitask model! Available heads are: {model_keys}"
+                    f"Head must be set for multitask model! Available heads are: {model_keys}, "
+                    f"use `dp --pt show your_model.pt model-branch` to show detail information."
                 )
-                assert head in model_keys, (
-                    f"No head named {head} in model! Available heads are: {model_keys}"
+                if head not in model_alias_dict:
+                    # preprocess with potentially case-insensitive input
+                    head_lower = head.lower()
+                    for mk in model_alias_dict:
+                        if mk.lower() == head_lower:
+                            # mapped the first matched head
+                            head = mk
+                            break
+                # replace with alias
+                assert head in model_alias_dict, (
+                    f"No head or alias named {head} in model! Available heads are: {model_keys},"
+                    f"use `dp --pt show your_model.pt model-branch` to show detail information."
                 )
+                head = model_alias_dict[head]
+
                 self.input_param = self.input_param["model_dict"][head]
                 state_dict_head = {"_extra_state": state_dict["_extra_state"]}
                 for item in state_dict:
@@ -253,6 +280,17 @@ def get_has_hessian(self):
         """Check if the model has hessian."""
         return self._has_hessian
 
+    def get_model_branch(self):
+        """Get the model branch information."""
+        if "model_dict" in self.model_def_script:
+            model_alias_dict, model_branch_dict = get_model_dict(
+                self.model_def_script["model_dict"]
+            )
+            return model_alias_dict, model_branch_dict
+        else:
+            # single-task model
+            return {"Default": "Default"}, {"Default": {"alias": [], "info": {}}}
+
     def eval(
         self,
         coords: np.ndarray,
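
The head-selection logic added to `__init__` applies its checks in a fixed order: fall back to the `Default` alias, treat any integer as the first branch, try an exact match, then a case-insensitive match, and finally map the alias to its branch. The standalone sketch below follows that order so the behaviour can be tried without a model file. The function name `resolve_head` is hypothetical, and it raises exceptions where the diff uses `assert`, which is a deliberate substitution for this example.

```python
def resolve_head(head, alias_dict, model_keys):
    """Resolve a user-supplied head to a branch name, in the order used by the diff."""
    if head is None and "Default" in alias_dict:
        head = "Default"
    if isinstance(head, int):
        head = model_keys[0]  # any integer selects the first branch, as in the diff
    if head is None:
        raise ValueError(f"Head must be set; available heads are {model_keys}")
    if head not in alias_dict:
        lowered = head.lower()
        # Fall back to the first key that matches case-insensitively, if any.
        head = next((key for key in alias_dict if key.lower() == lowered), head)
    if head not in alias_dict:
        raise KeyError(f"No head or alias named {head}; available heads are {model_keys}")
    return alias_dict[head]


model_keys = ["Dai2023Alloy", "OMAT24"]
alias_dict = {
    "Dai2023Alloy": "Dai2023Alloy",
    "OMAT24": "OMAT24",
    "Alloys": "Dai2023Alloy",
    "Default": "OMAT24",
}

print(resolve_head("alloys", alias_dict, model_keys))  # case-insensitive alias match
print(resolve_head(None, alias_dict, model_keys))      # falls back to the Default alias
print(resolve_head(0, alias_dict, model_keys))         # integer selects the first branch
```

Because the case-insensitive pass stops at the first match, two aliases that differ only in case would resolve to whichever appears first in the map; exact matches are always tried before that pass.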

deepmd/pt/utils/finetune.py

Lines changed: 8 additions & 2 deletions
@@ -12,6 +12,9 @@
 from deepmd.utils.finetune import (
     FinetuneRuleItem,
 )
+from deepmd.utils.model_branch_dict import (
+    get_model_dict,
+)
 
 log = logging.getLogger(__name__)
 
@@ -44,10 +47,13 @@ def get_finetune_rule_single(
             )
         else:
             model_branch_chosen = model_branch_from
-        assert model_branch_chosen in model_dict_params, (
-            f"No model branch named '{model_branch_chosen}'! "
+        model_alias_dict, model_branch_dict = get_model_dict(model_dict_params)
+        assert model_branch_chosen in model_alias_dict, (
+            f"No model branch or alias named '{model_branch_chosen}'! "
             f"Available ones are {list(model_dict_params.keys())}."
+            f"Use `dp --pt show your_model.pt model-branch` to show detail information."
         )
+        model_branch_chosen = model_alias_dict[model_branch_chosen]
         single_config_chosen = deepcopy(model_dict_params[model_branch_chosen])
     old_type_map, new_type_map = (
         single_config_chosen["type_map"],

deepmd/utils/argcheck.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2306,6 +2306,16 @@ def model_args(exclude_hybrid=False):
23062306
def standard_model_args() -> Argument:
23072307
doc_descrpt = "The descriptor of atomic environment."
23082308
doc_fitting = "The fitting of physical properties."
2309+
doc_model_branch_alias = (
2310+
"List of aliases for this model branch. "
2311+
"Multiple aliases can be defined, and any alias can reference this branch throughout the model usage. "
2312+
"Used only in multitask models."
2313+
)
2314+
doc_info = (
2315+
"Dictionary of metadata for this model branch. "
2316+
"Store arbitrary key-value pairs with branch-specific information. "
2317+
"Used only in multitask models."
2318+
)
23092319

23102320
ca = Argument(
23112321
"standard",
@@ -2321,6 +2331,20 @@ def standard_model_args() -> Argument:
23212331
[fitting_variant_type_args()],
23222332
doc=doc_fitting,
23232333
),
2334+
Argument(
2335+
"model_branch_alias",
2336+
list[str],
2337+
optional=True,
2338+
default=[],
2339+
doc=doc_only_pt_supported + doc_model_branch_alias,
2340+
),
2341+
Argument(
2342+
"info",
2343+
dict,
2344+
optional=True,
2345+
default={},
2346+
doc=doc_only_pt_supported + doc_info,
2347+
),
23242348
],
23252349
doc="Standard model, which contains a descriptor and a fitting.",
23262350
)
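
To show where the two new optional fields would sit in a multi-task configuration, here is a hedged sketch of a `model_dict` fragment written as a Python dictionary and dumped as JSON. The branch names `branch_a` and `branch_b`, the aliases, and the descriptions are invented for illustration, and the `descriptor` and `fitting_net` sections are left empty on purpose; only the field names `model_branch_alias` and `info`, and the `description` and `observed_type` keys shown in the commit's example table, come from this commit.

```python
import json

# Hypothetical two-branch fragment; descriptor and fitting_net are omitted here.
model_dict = {
    "branch_a": {
        "type_map": ["O", "H"],
        "descriptor": {},
        "fitting_net": {},
        "model_branch_alias": ["Default", "Water"],
        "info": {
            "description": "Illustrative description of the first dataset.",
            "observed_type": ["O", "H"],
        },
    },
    "branch_b": {
        "type_map": ["O", "H"],
        "descriptor": {},
        "fitting_net": {},
        "model_branch_alias": ["Ice"],
        "info": {"description": "Illustrative description of the second dataset."},
    },
}

config_fragment = {"model": {"model_dict": model_dict}}
print(json.dumps(config_fragment, indent=2))
```

Both fields are optional with empty defaults, so existing configurations that omit them should remain valid.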
