Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ test_dp_test_*.out

# Training and model output files
*.pth
*.pte
*.pt2
*.ckpt*
checkpoint
lcurve.out
Expand Down
19 changes: 16 additions & 3 deletions deepmd/entrypoints/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,26 @@ def show(
for branch in model_branches:
type_map = model_params["model_dict"][branch]["type_map"]
log.info(f"The type_map of branch {branch} is {type_map}")
else:
elif "type_map" in model_params:
type_map = model_params["type_map"]
log.info(f"The type_map is {type_map}")
else:
type_map = model.get_type_map()
log.info(f"The type_map is {type_map}")
if "descriptor" in ATTRIBUTES:
if model_is_multi_task:
model_branches = list(model_params["model_dict"].keys())
for branch in model_branches:
descriptor = model_params["model_dict"][branch]["descriptor"]
log.info(f"The descriptor parameter of branch {branch} is {descriptor}")
else:
elif "descriptor" in model_params:
descriptor = model_params["descriptor"]
log.info(f"The descriptor parameter is {descriptor}")
else:
log.warning(
"Descriptor parameters not available "
"(model was not frozen with training config)."
)
if "fitting-net" in ATTRIBUTES:
if model_is_multi_task:
model_branches = list(model_params["model_dict"].keys())
Expand All @@ -75,9 +83,14 @@ def show(
log.info(
f"The fitting_net parameter of branch {branch} is {fitting_net}"
)
else:
elif "fitting_net" in model_params:
fitting_net = model_params["fitting_net"]
log.info(f"The fitting_net parameter is {fitting_net}")
else:
log.warning(
"Fitting net parameters not available "
"(model was not frozen with training config)."
)
if "size" in ATTRIBUTES:
size_dict = model.get_model_size()
log_prefix = " for a single branch model" if model_is_multi_task else ""
Expand Down
8 changes: 5 additions & 3 deletions deepmd/pt_expt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def freeze(
m.eval()

model_dict = m.serialize()
deserialize_to_file(output, {"model": model_dict}, model_params=model_params)
deserialize_to_file(output, {"model": model_dict, "model_def_script": model_params})
log.info("Saved frozen model to %s", output)


Expand Down Expand Up @@ -344,7 +344,7 @@ def change_bias(
)

model_to_change = BaseModel.deserialize(pte_data["model"])
model_params = None
model_params = pte_data.get("model_def_script")
else:
raise RuntimeError(
"The model provided must be a checkpoint file with a .pt extension "
Expand Down Expand Up @@ -440,7 +440,9 @@ def change_bias(
)
)
model_dict = model_to_change.serialize()
deserialize_to_file(output_path, {"model": model_dict})
deserialize_to_file(
output_path, {"model": model_dict, "model_def_script": model_params}
)
log.info(f"Saved model to {output_path}")


Expand Down
89 changes: 40 additions & 49 deletions deepmd/pt_expt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
communicate_extended_output,
)
from deepmd.dpmodel.output_def import (
FittingOutputDef,
ModelOutputDef,
OutputVariableCategory,
OutputVariableDef,
Expand Down Expand Up @@ -59,28 +58,6 @@
import ase.neighborlist


def _reconstruct_model_output_def(metadata: dict) -> ModelOutputDef:
    """Rebuild a ``ModelOutputDef`` from stored ``fitting_output_defs`` metadata.

    Parameters
    ----------
    metadata : dict
        Mapping with a ``"fitting_output_defs"`` entry. Each item of that
        entry is a dict providing every field listed in ``field_names``.

    Returns
    -------
    ModelOutputDef
        Output definition wrapping a ``FittingOutputDef`` built from one
        ``OutputVariableDef`` per stored entry, in the stored order.
    """
    # Fields copied verbatim from each stored entry into OutputVariableDef.
    field_names = (
        "name",
        "shape",
        "reducible",
        "r_differentiable",
        "c_differentiable",
        "atomic",
        "category",
        "r_hessian",
        "magnetic",
        "intensive",
    )
    variables = [
        OutputVariableDef(**{key: entry[key] for key in field_names})
        for entry in metadata["fitting_output_defs"]
    ]
    return ModelOutputDef(FittingOutputDef(variables))


class DeepEval(DeepEvalBackend):
"""PyTorch Exportable backend implementation of DeepEval.

Expand Down Expand Up @@ -124,9 +101,6 @@ def __init__(
else:
self._load_pte(model_file)

# Reconstruct the model output def from stored fitting output defs
self._model_output_def = _reconstruct_model_output_def(self.metadata)

if isinstance(auto_batch_size, bool):
if auto_batch_size:
self.auto_batch_size = AutoBatchSize()
Expand All @@ -139,14 +113,30 @@ def __init__(
else:
raise TypeError("auto_batch_size should be bool, int, or AutoBatchSize")

def _init_from_model_json(self, model_json_str: str) -> None:
    """Populate model-derived attributes from a serialized ``model.json`` string.

    The JSON text is parsed, passed through ``_json_to_numpy``, and its
    ``"model"`` entry is deserialized with ``BaseModel.deserialize``. The
    resulting instance is kept as ``self._dpmodel`` and queried to fill
    ``self.rcut``, ``self.type_map`` and ``self._model_output_def``.

    Parameters
    ----------
    model_json_str : str
        JSON text containing a top-level ``"model"`` entry.
    """
    from deepmd.pt_expt.model.model import (
        BaseModel,
    )
    from deepmd.pt_expt.utils.serialization import (
        _json_to_numpy,
    )

    payload = _json_to_numpy(json.loads(model_json_str))
    dpmodel = BaseModel.deserialize(payload["model"])
    self._dpmodel = dpmodel
    self.rcut = dpmodel.get_rcut()
    self.type_map = dpmodel.get_type_map()
    self._model_output_def = ModelOutputDef(dpmodel.atomic_output_def())

def _load_pte(self, model_file: str) -> None:
"""Load a .pte (torch.export) model file."""
extra_files = {"model_def_script.json": ""}
extra_files = {"model.json": "", "model_def_script.json": ""}
exported = torch.export.load(model_file, extra_files=extra_files)
self.exported_module = exported.module()
self.metadata = json.loads(extra_files["model_def_script.json"])
self.rcut = self.metadata["rcut"]
self.type_map = self.metadata["type_map"]
self._init_from_model_json(extra_files["model.json"])
mds = extra_files["model_def_script.json"]
self._model_def_script = json.loads(mds) if mds else {}

def _load_pt2(self, model_file: str) -> None:
"""Load a .pt2 (AOTInductor) model file."""
Expand All @@ -159,16 +149,17 @@ def _load_pt2(self, model_file: str) -> None:
# Read metadata from the .pt2 ZIP archive
with zipfile.ZipFile(model_file, "r") as zf:
names = zf.namelist()
for required in ("extra/model_def_script.json", "extra/output_keys.json"):
if required not in names:
raise ValueError(
f"Invalid .pt2 file '{model_file}': missing '{required}'"
)
self.metadata = json.loads(zf.read("extra/model_def_script.json"))
self._output_keys = json.loads(zf.read("extra/output_keys.json"))
if "extra/model.json" not in names:
raise ValueError(
f"Invalid .pt2 file '{model_file}': missing 'extra/model.json'"
)
model_json_str = zf.read("extra/model.json").decode("utf-8")
mds = ""
if "extra/model_def_script.json" in names:
mds = zf.read("extra/model_def_script.json").decode("utf-8")

self.rcut = self.metadata["rcut"]
self.type_map = self.metadata["type_map"]
self._init_from_model_json(model_json_str)
self._model_def_script = json.loads(mds) if mds else {}

# Load the AOTInductor model package (.pt2 ZIP archive).
# Uses torch._inductor.aoti_load_package (private API, stable since PyTorch 2.6).
Expand All @@ -189,16 +180,16 @@ def get_type_map(self) -> list[str]:

def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this DP."""
return self.metadata["dim_fparam"]
return self._dpmodel.get_dim_fparam()

def get_dim_aparam(self) -> int:
"""Get the number (dimension) of atomic parameters of this DP."""
return self.metadata["dim_aparam"]
return self._dpmodel.get_dim_aparam()

@property
def model_type(self) -> type["DeepEvalWrapper"]:
"""The evaluator of the model type."""
model_output_type = self.metadata["model_output_type"]
model_output_type = self._dpmodel.model_output_type()
if "energy" in model_output_type:
return DeepPot
elif "dos" in model_output_type:
Expand All @@ -219,7 +210,7 @@ def get_sel_type(self) -> list[int]:
to the result of the model.
If returning an empty list, all atom types are selected.
"""
return self.metadata["sel_type"]
return self._dpmodel.get_sel_type()

def get_numb_dos(self) -> int:
"""Get the number of DOS."""
Expand Down Expand Up @@ -364,8 +355,8 @@ def _build_nlist_native(
nframes = coords.shape[0]
natoms = coords.shape[1]
rcut = self.rcut
sel = self.metadata["sel"]
mixed_types = self.metadata["mixed_types"]
sel = self._dpmodel.get_sel()
mixed_types = self._dpmodel.mixed_types()

if cells is not None:
box_input = cells.reshape(nframes, 3, 3)
Expand Down Expand Up @@ -476,8 +467,8 @@ def _build_nlist_ase_single(
nlist : np.ndarray, shape (nloc, nsel)
mapping : np.ndarray, shape (nall,)
"""
sel = self.metadata["sel"]
mixed_types = self.metadata["mixed_types"]
sel = self._dpmodel.get_sel()
mixed_types = self._dpmodel.mixed_types()
nsel = sum(sel)

natoms = positions.shape[0]
Expand Down Expand Up @@ -703,8 +694,8 @@ def _get_output_shape(
raise RuntimeError("unknown category")

def get_model_def_script(self) -> dict:
"""Get model definition script."""
return self.metadata
"""Get model definition script (training config)."""
return self._model_def_script
Comment thread
wanghan-iapcm marked this conversation as resolved.

def get_model(self) -> torch.nn.Module:
"""Get the exported model module.
Expand Down
8 changes: 2 additions & 6 deletions deepmd/pt_expt/utils/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,7 @@ def _load_model_params(finetune_model: str) -> dict[str, Any]:
)

data = serialize_from_file(finetune_model)
# Prefer embedded model_params (full config); fall back to
# a minimal dict with just type_map for older .pte files.
if "model_params" in data:
return data["model_params"]
return {"type_map": data["model"]["type_map"]}
return data["model_def_script"]
else:
state_dict = torch.load(finetune_model, map_location=DEVICE, weights_only=True)
if "model" in state_dict:
Expand Down Expand Up @@ -82,7 +78,7 @@ def get_finetune_rules(
raise ValueError(
"Cannot use --use-pretrain-script: the pretrained model does not "
"contain full model params. If finetuning from a .pte file, "
"re-freeze it with the latest code so that model_params is embedded."
"re-freeze it with the latest code so that model_def_script is embedded."
)

finetune_from_multi_task = "model_dict" in last_model_params
Expand Down
Loading
Loading