Commit dd8e492

Authored by wanghan-iapcm (Han Wang) and a co-author
refactor(pt_expt): use model API for inference, consistent file naming (#5354)
## Summary

### Problem

Two inconsistencies in `.pt2`/`.pte` files:

1. **Python reads a flat metadata dict instead of using the model API.** Other backends (`.dp`/`.yaml`, `.pth`) deserialize the model and query it directly for `get_rcut()`, `get_sel()`, `model_output_type()`, etc. The `.pt2`/`.pte` backend was reading these from a metadata dict stored at export time, duplicating model logic.
2. **Inconsistent file naming.** `model_def_script.json` stored C++ runtime metadata in `.pt2`/`.pte` but training config in `.pth`. The training config lived separately in `model_params.json`. `output_keys.json` was a standalone file that logically belongs with the metadata.

### Solution

**Python inference**: `DeepEval` now deserializes `model.json` into a dpmodel instance (`self._dpmodel`) at load time and delegates all API calls to it. `_reconstruct_model_output_def()` is removed.

**File layout renamed** so that each filename means the same thing across `.pth` and `.pt2`/`.pte`:

| File | Before | After |
|------|--------|-------|
| `model_def_script.json` | C++ metadata | **Training config** (matches `.pth`) |
| `metadata.json` | *(did not exist)* | **C++ metadata** + output_keys |
| `model_params.json` | Training config | **Removed** |
| `output_keys.json` | Output key list | **Removed** (merged into `metadata.json`) |
| `model.json` | Full serialized model | No change |

**C++ inference** (`DeepPotPTExpt.cc`): updated to read `extra/metadata.json` instead of `extra/model_def_script.json`, and to read `output_keys` from the metadata dict instead of a separate `output_keys.json`.

**Why `metadata.json` still exists**: C++ inference cannot deserialize `model.json` to call model API methods. The alternative, compiling methods such as `get_rcut()` and `get_sel()` as additional AOTInductor entry points, was benchmarked and rejected:

- **Compilation overhead**: ~12 s per trivial constant-returning function (C++ codegen + compile + link). With ~8 methods, that adds ~1.5 min to freeze time.
- **String outputs**: `get_type_map()` returns strings. `torch.export` supports only tensor I/O, so encoding strings as int tensors adds complexity for no benefit.
- **These are constants**: `rcut`, `sel`, and `type_map` never change after export. A flat JSON file is the simplest and fastest solution.

### Other changes

- `compress` and `change_bias` entrypoints now preserve the training config through `.pte`/`.pt2` round-trips
- `.gitignore` updated to exclude `.pte`/`.pt2` model files
- `_collect_metadata()` drops `model_output_type` and `sel_type` (not used by C++; Python now gets them from the model)

## Test plan

- [x] `source/tests/pt_expt/infer/test_deep_eval.py` — 36/36 pass (`.pte` + `.pt2`)
  - New: `test_model_api_delegation`, `test_get_model_def_script_with_params`
  - Updated: `test_get_model_def_script`, `test_pt2_has_metadata`, `test_dynamic_shapes`
- [x] `source/tests/pt_expt/model/` — 50/50 pass (frozen, compression, serialization)
- [x] `source/tests/pt_expt/test_change_bias.py` — new tests for `.pte`/`.pt2` model_def_script preservation
- [x] C++ tests — 3/3 suites pass (`.pt2` models regenerated with new `metadata.json`)

## Summary by CodeRabbit

- **Bug Fixes**
  - Preserve and restore the training configuration when freezing or modifying frozen models; clearer messages when the training config is absent.
- **Refactor**
  - Consolidated the metadata layout inside exported model archives for more consistent loading across formats and runtimes.
- **Tests**
  - Added and updated tests to validate config preservation, metadata delegation, and archive contents.
- **Chores**
  - Extended ignore patterns to skip additional model file extensions.

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
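The consolidated archive layout can be illustrated with a toy `.pt2`-style ZIP. This is a hypothetical sketch using only the standard library: the file names follow the table above, but the metadata values (`rcut`, `sel`, the output keys) are invented for illustration.

```python
import io
import json
import zipfile

# Build a toy archive with the new layout: extra/model.json (serialized model)
# plus extra/metadata.json (C++ constants, with output_keys merged in).
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("extra/model.json", json.dumps({"model": {"type_map": ["O", "H"]}}))
    zf.writestr(
        "extra/metadata.json",
        json.dumps({"rcut": 6.0, "sel": [46, 92], "output_keys": ["energy", "force"]}),
    )

# Loader side: model.json is mandatory; output_keys now comes from metadata.json
# rather than a standalone output_keys.json.
with zipfile.ZipFile(buf, "r") as zf:
    if "extra/model.json" not in zf.namelist():
        raise ValueError("missing 'extra/model.json'")
    metadata = json.loads(zf.read("extra/metadata.json"))

print(metadata["output_keys"])  # merged key list, read from metadata.json
```

Because the metadata file holds only constants fixed at export time, reading it is a single JSON parse with no model deserialization, which is why it stays alongside `model.json` for the C++ runtime.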
1 parent bf3e483 commit dd8e492

File tree: 9 files changed (+242, −111 lines)

.gitignore

Lines changed: 2 additions & 0 deletions
```diff
@@ -62,6 +62,8 @@ test_dp_test_*.out
 
 # Training and model output files
 *.pth
+*.pte
+*.pt2
 *.ckpt*
 checkpoint
 lcurve.out
```

deepmd/entrypoints/show.py

Lines changed: 16 additions & 3 deletions
```diff
@@ -55,18 +55,26 @@ def show(
             for branch in model_branches:
                 type_map = model_params["model_dict"][branch]["type_map"]
                 log.info(f"The type_map of branch {branch} is {type_map}")
-        else:
+        elif "type_map" in model_params:
             type_map = model_params["type_map"]
             log.info(f"The type_map is {type_map}")
+        else:
+            type_map = model.get_type_map()
+            log.info(f"The type_map is {type_map}")
     if "descriptor" in ATTRIBUTES:
         if model_is_multi_task:
             model_branches = list(model_params["model_dict"].keys())
             for branch in model_branches:
                 descriptor = model_params["model_dict"][branch]["descriptor"]
                 log.info(f"The descriptor parameter of branch {branch} is {descriptor}")
-        else:
+        elif "descriptor" in model_params:
             descriptor = model_params["descriptor"]
             log.info(f"The descriptor parameter is {descriptor}")
+        else:
+            log.warning(
+                "Descriptor parameters not available "
+                "(model was not frozen with training config)."
+            )
     if "fitting-net" in ATTRIBUTES:
         if model_is_multi_task:
             model_branches = list(model_params["model_dict"].keys())
@@ -75,9 +83,14 @@ def show(
                 log.info(
                     f"The fitting_net parameter of branch {branch} is {fitting_net}"
                 )
-        else:
+        elif "fitting_net" in model_params:
             fitting_net = model_params["fitting_net"]
             log.info(f"The fitting_net parameter is {fitting_net}")
+        else:
+            log.warning(
+                "Fitting net parameters not available "
+                "(model was not frozen with training config)."
+            )
     if "size" in ATTRIBUTES:
         size_dict = model.get_model_size()
         log_prefix = " for a single branch model" if model_is_multi_task else ""
```
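The `show.py` change turns a two-way `if`/`else` into a three-way fallback. A toy sketch of the resolution order (`resolve_type_map` and `FrozenModel` are invented stand-ins, not deepmd code):

```python
# Fallback order the diff introduces: prefer the stored training config,
# then query the deserialized model itself when the config is absent
# (i.e. the model was frozen without a training config).
def resolve_type_map(model_params: dict, model) -> list[str]:
    if "type_map" in model_params:
        return model_params["type_map"]
    return model.get_type_map()


class FrozenModel:
    """Stand-in for a model deserialized from model.json."""

    def get_type_map(self) -> list[str]:
        return ["O", "H"]


print(resolve_type_map({"type_map": ["Cu"]}, FrozenModel()))  # from config
print(resolve_type_map({}, FrozenModel()))  # from the model API
```

For `descriptor` and `fitting_net` no model-API equivalent exists, which is why those branches warn instead of falling back.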

deepmd/pt_expt/entrypoints/main.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -259,7 +259,7 @@ def freeze(
     m.eval()
 
     model_dict = m.serialize()
-    deserialize_to_file(output, {"model": model_dict}, model_params=model_params)
+    deserialize_to_file(output, {"model": model_dict, "model_def_script": model_params})
     log.info("Saved frozen model to %s", output)
 
 
@@ -344,7 +344,7 @@ def change_bias(
         )
 
         model_to_change = BaseModel.deserialize(pte_data["model"])
-        model_params = None
+        model_params = pte_data.get("model_def_script")
     else:
         raise RuntimeError(
             "The model provided must be a checkpoint file with a .pt extension "
@@ -440,7 +440,9 @@ def change_bias(
             )
         )
     model_dict = model_to_change.serialize()
-    deserialize_to_file(output_path, {"model": model_dict})
+    deserialize_to_file(
+        output_path, {"model": model_dict, "model_def_script": model_params}
+    )
     log.info(f"Saved model to {output_path}")
```
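The point of the two `main.py` hunks is that the training config survives a freeze → change_bias round-trip. A toy sketch with invented helpers (`freeze`/`change_bias` here stand in for the real entrypoints and serialize to a JSON string instead of a `.pte` file):

```python
import json

# Before this commit, change_bias wrote model_params = None and the
# training config was lost; now it is carried through unchanged.
def freeze(model_dict: dict, model_params: dict) -> str:
    return json.dumps({"model": model_dict, "model_def_script": model_params})


def change_bias(frozen: str) -> str:
    data = json.loads(frozen)
    # ... the real entrypoint would update bias statistics on data["model"] ...
    return json.dumps(
        {"model": data["model"], "model_def_script": data.get("model_def_script")}
    )


out = json.loads(change_bias(freeze({"w": [1.0]}, {"type_map": ["O", "H"]})))
print(out["model_def_script"])  # config preserved through the round-trip
```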

deepmd/pt_expt/infer/deep_eval.py

Lines changed: 40 additions & 49 deletions
```diff
@@ -16,7 +16,6 @@
     communicate_extended_output,
 )
 from deepmd.dpmodel.output_def import (
-    FittingOutputDef,
     ModelOutputDef,
     OutputVariableCategory,
     OutputVariableDef,
@@ -59,28 +58,6 @@
     import ase.neighborlist
 
 
-def _reconstruct_model_output_def(metadata: dict) -> ModelOutputDef:
-    """Reconstruct ModelOutputDef from stored fitting_output_defs metadata."""
-    var_defs = []
-    for vd in metadata["fitting_output_defs"]:
-        var_defs.append(
-            OutputVariableDef(
-                name=vd["name"],
-                shape=vd["shape"],
-                reducible=vd["reducible"],
-                r_differentiable=vd["r_differentiable"],
-                c_differentiable=vd["c_differentiable"],
-                atomic=vd["atomic"],
-                category=vd["category"],
-                r_hessian=vd["r_hessian"],
-                magnetic=vd["magnetic"],
-                intensive=vd["intensive"],
-            )
-        )
-    fitting_output_def = FittingOutputDef(var_defs)
-    return ModelOutputDef(fitting_output_def)
-
-
 class DeepEval(DeepEvalBackend):
     """PyTorch Exportable backend implementation of DeepEval.
 
@@ -124,9 +101,6 @@ def __init__(
         else:
             self._load_pte(model_file)
 
-        # Reconstruct the model output def from stored fitting output defs
-        self._model_output_def = _reconstruct_model_output_def(self.metadata)
-
         if isinstance(auto_batch_size, bool):
             if auto_batch_size:
                 self.auto_batch_size = AutoBatchSize()
@@ -139,14 +113,30 @@
         else:
             raise TypeError("auto_batch_size should be bool, int, or AutoBatchSize")
 
+    def _init_from_model_json(self, model_json_str: str) -> None:
+        """Deserialize model.json and derive model API from the dpmodel instance."""
+        from deepmd.pt_expt.model.model import (
+            BaseModel,
+        )
+        from deepmd.pt_expt.utils.serialization import (
+            _json_to_numpy,
+        )
+
+        model_dict = json.loads(model_json_str)
+        model_dict = _json_to_numpy(model_dict)
+        self._dpmodel = BaseModel.deserialize(model_dict["model"])
+        self.rcut = self._dpmodel.get_rcut()
+        self.type_map = self._dpmodel.get_type_map()
+        self._model_output_def = ModelOutputDef(self._dpmodel.atomic_output_def())
+
     def _load_pte(self, model_file: str) -> None:
         """Load a .pte (torch.export) model file."""
-        extra_files = {"model_def_script.json": ""}
+        extra_files = {"model.json": "", "model_def_script.json": ""}
         exported = torch.export.load(model_file, extra_files=extra_files)
         self.exported_module = exported.module()
-        self.metadata = json.loads(extra_files["model_def_script.json"])
-        self.rcut = self.metadata["rcut"]
-        self.type_map = self.metadata["type_map"]
+        self._init_from_model_json(extra_files["model.json"])
+        mds = extra_files["model_def_script.json"]
+        self._model_def_script = json.loads(mds) if mds else {}
 
     def _load_pt2(self, model_file: str) -> None:
         """Load a .pt2 (AOTInductor) model file."""
@@ -159,16 +149,17 @@ def _load_pt2(self, model_file: str) -> None:
         # Read metadata from the .pt2 ZIP archive
         with zipfile.ZipFile(model_file, "r") as zf:
             names = zf.namelist()
-            for required in ("extra/model_def_script.json", "extra/output_keys.json"):
-                if required not in names:
-                    raise ValueError(
-                        f"Invalid .pt2 file '{model_file}': missing '{required}'"
-                    )
-            self.metadata = json.loads(zf.read("extra/model_def_script.json"))
-            self._output_keys = json.loads(zf.read("extra/output_keys.json"))
+            if "extra/model.json" not in names:
+                raise ValueError(
+                    f"Invalid .pt2 file '{model_file}': missing 'extra/model.json'"
+                )
+            model_json_str = zf.read("extra/model.json").decode("utf-8")
+            mds = ""
+            if "extra/model_def_script.json" in names:
+                mds = zf.read("extra/model_def_script.json").decode("utf-8")
 
-        self.rcut = self.metadata["rcut"]
-        self.type_map = self.metadata["type_map"]
+        self._init_from_model_json(model_json_str)
+        self._model_def_script = json.loads(mds) if mds else {}
 
         # Load the AOTInductor model package (.pt2 ZIP archive).
         # Uses torch._inductor.aoti_load_package (private API, stable since PyTorch 2.6).
@@ -189,16 +180,16 @@ def get_type_map(self) -> list[str]:
 
     def get_dim_fparam(self) -> int:
         """Get the number (dimension) of frame parameters of this DP."""
-        return self.metadata["dim_fparam"]
+        return self._dpmodel.get_dim_fparam()
 
     def get_dim_aparam(self) -> int:
         """Get the number (dimension) of atomic parameters of this DP."""
-        return self.metadata["dim_aparam"]
+        return self._dpmodel.get_dim_aparam()
 
     @property
     def model_type(self) -> type["DeepEvalWrapper"]:
         """The the evaluator of the model type."""
-        model_output_type = self.metadata["model_output_type"]
+        model_output_type = self._dpmodel.model_output_type()
         if "energy" in model_output_type:
             return DeepPot
         elif "dos" in model_output_type:
@@ -219,7 +210,7 @@ def get_sel_type(self) -> list[int]:
         to the result of the model.
         If returning an empty list, all atom types are selected.
         """
-        return self.metadata["sel_type"]
+        return self._dpmodel.get_sel_type()
 
     def get_numb_dos(self) -> int:
         """Get the number of DOS."""
@@ -364,8 +355,8 @@ def _build_nlist_native(
         nframes = coords.shape[0]
         natoms = coords.shape[1]
         rcut = self.rcut
-        sel = self.metadata["sel"]
-        mixed_types = self.metadata["mixed_types"]
+        sel = self._dpmodel.get_sel()
+        mixed_types = self._dpmodel.mixed_types()
 
         if cells is not None:
             box_input = cells.reshape(nframes, 3, 3)
@@ -476,8 +467,8 @@ def _build_nlist_ase_single(
         nlist : np.ndarray, shape (nloc, nsel)
         mapping : np.ndarray, shape (nall,)
         """
-        sel = self.metadata["sel"]
-        mixed_types = self.metadata["mixed_types"]
+        sel = self._dpmodel.get_sel()
+        mixed_types = self._dpmodel.mixed_types()
         nsel = sum(sel)
 
         natoms = positions.shape[0]
@@ -703,8 +694,8 @@ def _get_output_shape(
         raise RuntimeError("unknown category")
 
     def get_model_def_script(self) -> dict:
-        """Get model definition script."""
-        return self.metadata
+        """Get model definition script (training config)."""
+        return self._model_def_script
 
     def get_model(self) -> torch.nn.Module:
         """Get the exported model module.
```
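The core of the `deep_eval.py` refactor is the delegation pattern in `_init_from_model_json`: deserialize once, then forward every API query to the model object. A minimal self-contained sketch with toy stand-ins (`ToyModel` and `ToyDeepEval` are not the real deepmd classes, and the real code additionally converts JSON lists back to numpy arrays):

```python
import json


class ToyModel:
    """Stand-in for a dpmodel deserialized from model.json."""

    def __init__(self, rcut: float, type_map: list[str]):
        self._rcut, self._type_map = rcut, type_map

    @classmethod
    def deserialize(cls, data: dict) -> "ToyModel":
        return cls(data["rcut"], data["type_map"])

    def get_rcut(self) -> float:
        return self._rcut

    def get_type_map(self) -> list[str]:
        return self._type_map


class ToyDeepEval:
    """Delegates API queries to the deserialized model instead of a flat dict."""

    def _init_from_model_json(self, model_json_str: str) -> None:
        model_dict = json.loads(model_json_str)
        self._dpmodel = ToyModel.deserialize(model_dict["model"])
        self.rcut = self._dpmodel.get_rcut()
        self.type_map = self._dpmodel.get_type_map()


ev = ToyDeepEval()
ev._init_from_model_json(json.dumps({"model": {"rcut": 6.0, "type_map": ["O", "H"]}}))
print(ev.rcut, ev.type_map)
```

The payoff is a single source of truth: quantities like `rcut` and `type_map` are computed by the model's own methods, so the export path no longer has to duplicate that logic into a metadata dict for Python consumers.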

deepmd/pt_expt/utils/finetune.py

Lines changed: 2 additions & 6 deletions
```diff
@@ -33,11 +33,7 @@ def _load_model_params(finetune_model: str) -> dict[str, Any]:
         )
 
         data = serialize_from_file(finetune_model)
-        # Prefer embedded model_params (full config); fall back to
-        # a minimal dict with just type_map for older .pte files.
-        if "model_params" in data:
-            return data["model_params"]
-        return {"type_map": data["model"]["type_map"]}
+        return data["model_def_script"]
     else:
         state_dict = torch.load(finetune_model, map_location=DEVICE, weights_only=True)
         if "model" in state_dict:
@@ -82,7 +78,7 @@ def get_finetune_rules(
         raise ValueError(
             "Cannot use --use-pretrain-script: the pretrained model does not "
             "contain full model params. If finetuning from a .pte file, "
-            "re-freeze it with the latest code so that model_params is embedded."
+            "re-freeze it with the latest code so that model_def_script is embedded."
         )
 
     finetune_from_multi_task = "model_dict" in last_model_params
```
