Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions deepmd/entrypoints/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from deepmd.utils.econf_embd import (
sort_element_type,
)
from deepmd.utils.model_branch_dict import (
OrderedDictTableWrapper,
get_model_dict,
)

log = logging.getLogger(__name__)

Expand All @@ -33,10 +37,15 @@ def show(
)
model_branches = list(model_params["model_dict"].keys())
model_branches += ["RANDOM"]
_, model_branch_dict = get_model_dict(model_params["model_dict"])
log.info(
f"Available model branches are {model_branches}, "
f"where 'RANDOM' means using a randomly initialized fitting net."
)
log.info(
"Detailed information: \n"
+ OrderedDictTableWrapper(model_branch_dict).as_table()
)
if "type-map" in ATTRIBUTES:
if model_is_multi_task:
model_branches = list(model_params["model_dict"].keys())
Expand Down Expand Up @@ -79,8 +88,22 @@ def show(
total_observed_types_list = []
model_branches = list(model_params["model_dict"].keys())
for branch in model_branches:
tmp_model = DeepEval(INPUT, head=branch, no_jit=True)
observed_types = tmp_model.get_observed_types()
if (
model_params["model_dict"][branch]
.get("info", {})
.get("observed_type", None)
is not None
):
observed_type_list = model_params["model_dict"][branch]["info"][
"observed_type"
]
observed_types = {
"type_num": len(observed_type_list),
"observed_type": observed_type_list,
}
else:
tmp_model = DeepEval(INPUT, head=branch, no_jit=True)
observed_types = tmp_model.get_observed_types()
log.info(
f"{branch}: Number of observed types: {observed_types['type_num']} "
)
Expand Down
44 changes: 41 additions & 3 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
import logging
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -67,10 +68,15 @@
from deepmd.utils.econf_embd import (
sort_element_type,
)
from deepmd.utils.model_branch_dict import (
get_model_dict,
)

if TYPE_CHECKING:
import ase.neighborlist

log = logging.getLogger(__name__)


class DeepEval(DeepEvalBackend):
"""PyTorch backend implementation of DeepEval.
Expand Down Expand Up @@ -116,15 +122,36 @@ def __init__(
self.model_def_script = self.input_param
self.multi_task = "model_dict" in self.input_param
if self.multi_task:
model_alias_dict, model_branch_dict = get_model_dict(
self.input_param["model_dict"]
)
model_keys = list(self.input_param["model_dict"].keys())
if head is None and "Default" in model_alias_dict:
head = "Default"
Comment thread
iProzd marked this conversation as resolved.
log.info(
f"Using default head {model_alias_dict[head]} for multitask model."
)
if isinstance(head, int):
head = model_keys[0]
assert head is not None, (
Comment thread
iProzd marked this conversation as resolved.
f"Head must be set for multitask model! Available heads are: {model_keys}"
f"Head must be set for multitask model! Available heads are: {model_keys}, "
f"use `dp --pt show your_model.pt model-branch` to show detail information."
)
Comment thread
iProzd marked this conversation as resolved.
assert head in model_keys, (
f"No head named {head} in model! Available heads are: {model_keys}"
if head not in model_alias_dict:
# preprocess with potentially case-insensitive input
head_lower = head.lower()
for mk in model_alias_dict:
if mk.lower() == head_lower:
# mapped the first matched head
head = mk
break
# replace with alias
assert head in model_alias_dict, (
f"No head or alias named {head} in model! Available heads are: {model_keys},"
f"use `dp --pt show your_model.pt model-branch` to show detail information."
)
head = model_alias_dict[head]
Comment thread
iProzd marked this conversation as resolved.

self.input_param = self.input_param["model_dict"][head]
state_dict_head = {"_extra_state": state_dict["_extra_state"]}
for item in state_dict:
Expand Down Expand Up @@ -253,6 +280,17 @@ def get_has_hessian(self):
"""Check if the model has hessian."""
return self._has_hessian

def get_model_branch(self):
Comment thread
wanghan-iapcm marked this conversation as resolved.
"""Get the model branch information."""
if "model_dict" in self.model_def_script:
model_alias_dict, model_branch_dict = get_model_dict(
self.model_def_script["model_dict"]
)
return model_alias_dict, model_branch_dict
else:
# single-task model
return {"Default": "Default"}, {"Default": {"alias": [], "info": {}}}

def eval(
self,
coords: np.ndarray,
Expand Down
10 changes: 8 additions & 2 deletions deepmd/pt/utils/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from deepmd.utils.finetune import (
FinetuneRuleItem,
)
from deepmd.utils.model_branch_dict import (
get_model_dict,
)

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -44,10 +47,13 @@ def get_finetune_rule_single(
)
else:
model_branch_chosen = model_branch_from
assert model_branch_chosen in model_dict_params, (
f"No model branch named '{model_branch_chosen}'! "
model_alias_dict, model_branch_dict = get_model_dict(model_dict_params)
assert model_branch_chosen in model_alias_dict, (
f"No model branch or alias named '{model_branch_chosen}'! "
f"Available ones are {list(model_dict_params.keys())}."
f"Use `dp --pt show your_model.pt model-branch` to show detail information."
)
model_branch_chosen = model_alias_dict[model_branch_chosen]
single_config_chosen = deepcopy(model_dict_params[model_branch_chosen])
Comment thread
iProzd marked this conversation as resolved.
old_type_map, new_type_map = (
single_config_chosen["type_map"],
Expand Down
24 changes: 24 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2305,6 +2305,16 @@ def model_args(exclude_hybrid=False):
def standard_model_args() -> Argument:
doc_descrpt = "The descriptor of atomic environment."
doc_fitting = "The fitting of physical properties."
doc_model_branch_alias = (
"List of aliases for this model branch. "
"Multiple aliases can be defined, and any alias can reference this branch throughout the model usage. "
"Used only in multitask models."
)
doc_info = (
"Dictionary of metadata for this model branch. "
"Store arbitrary key-value pairs with branch-specific information. "
"Used only in multitask models."
)

ca = Argument(
"standard",
Expand All @@ -2320,6 +2330,20 @@ def standard_model_args() -> Argument:
[fitting_variant_type_args()],
doc=doc_fitting,
),
Argument(
"model_branch_alias",
list[str],
optional=True,
default=[],
doc=doc_only_pt_supported + doc_model_branch_alias,
),
Argument(
"info",
dict,
optional=True,
default={},
doc=doc_only_pt_supported + doc_info,
),
],
doc="Standard model, which contains a descriptor and a fitting.",
)
Expand Down
Loading