Skip to content

Commit 02afe57

Browse files
committed
feat(pt_expt): make model.json optional in .pt2/.pte loading
The pt_expt DeepEval's inference path runs through aoti_load_package / the exported module; `self._dpmodel` is only used to resolve metadata (rcut / type_map / atomic_output_def / dim_fparam / ...), which is already available in extra/metadata.json (the contract the C++ reader DeepPotPTExpt enforces).

Drop the requirement that extra/model.json be present:

* _load_pt2 / _load_pte: model.json is optional; metadata.json is now the minimum contract.
* _init_from_metadata: reconstructs ModelOutputDef from the serialised fitting_output_defs and hoists sel / mixed_types to plain attributes.
* get_dim_fparam / get_dim_aparam / get_sel_type / model_type / get_use_spin: fall back to metadata when _dpmodel is None.
* eval_descriptor / eval_typeebd / eval_fitting_last_layer: raise a descriptive NotImplementedError in metadata-only mode (they inspect the dpmodel instance directly).

Also fixes two metadata-completeness gaps so metadata-only load is exact:

* _collect_metadata: add the `sel_type` field so get_sel_type works without a dpmodel round-trip (relevant for dipole / polar / wfc).
* _collect_metadata: force vdef.category to plain int for deterministic JSON serialisation across Python versions.

Archives produced by existing pt_expt serialisation still contain model.json and continue to use the dpmodel path unchanged.

Regression covered by 77 existing tests in test_deep_eval.py + a dedicated new suite (test_deep_eval_metadata_only.py) that strips extra/model.json and asserts bitwise parity against the full archive.
1 parent 54f42d9 commit 02afe57

3 files changed

Lines changed: 406 additions & 24 deletions

File tree

deepmd/pt_expt/infer/deep_eval.py

Lines changed: 160 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,12 @@ def _init_from_model_json(self, model_json_str: str) -> None:
140140

141141
self.rcut = self._dpmodel.get_rcut()
142142
self.type_map = self._dpmodel.get_type_map()
143+
# Hoist sel / mixed_types to plain attributes so the inference hot
144+
# path (`_build_nlist_*`) does not need the dpmodel instance. This
145+
# matches what `_init_from_metadata` sets and keeps both code paths
146+
# numerically identical.
147+
self.sel = list(self._dpmodel.get_sel())
148+
self.mixed_types = bool(self._dpmodel.mixed_types())
143149
if self._is_spin:
144150
self._model_output_def = ModelOutputDef(
145151
FittingOutputDef(
@@ -159,23 +165,89 @@ def _init_from_model_json(self, model_json_str: str) -> None:
159165
else:
160166
self._model_output_def = ModelOutputDef(self._dpmodel.atomic_output_def())
161167

168+
def _init_from_metadata(self) -> None:
169+
"""Initialize DeepEval from ``extra/metadata.json`` alone.
170+
171+
Used when the ``.pt2`` / ``.pte`` archive ships no ``extra/model.json``
172+
(e.g. for backends that do not travel through the dpmodel round-trip).
173+
The metadata contract is the same one the C++ ``DeepPotPTExpt``
174+
reader consumes, so anything that validates against the C++ side
175+
automatically validates here.
176+
177+
``self._dpmodel`` is left as ``None`` to signal the metadata-only
178+
mode. Inference does not need it: it runs through
179+
``aoti_load_package`` / the exported module and uses plain
180+
attributes (``self.rcut``, ``self.sel``, ``self.mixed_types``,
181+
``self._model_output_def``) for all metadata-level queries.
182+
"""
183+
self._dpmodel = None
184+
self._is_spin = bool(self.metadata.get("is_spin", False))
185+
self.rcut = float(self.metadata["rcut"])
186+
self.type_map = list(self.metadata["type_map"])
187+
self.sel = [int(s) for s in self.metadata["sel"]]
188+
self.mixed_types = bool(self.metadata["mixed_types"])
189+
190+
fitting_defs = []
191+
for vdef in self.metadata["fitting_output_defs"]:
192+
fitting_defs.append(
193+
OutputVariableDef(
194+
name=vdef["name"],
195+
shape=list(vdef["shape"]),
196+
reducible=vdef.get("reducible", False),
197+
r_differentiable=vdef.get("r_differentiable", False),
198+
c_differentiable=vdef.get("c_differentiable", False),
199+
atomic=vdef.get("atomic", True),
200+
category=int(
201+
vdef.get("category", OutputVariableCategory.OUT.value)
202+
),
203+
r_hessian=vdef.get("r_hessian", False),
204+
magnetic=vdef.get("magnetic", False),
205+
intensive=vdef.get("intensive", False),
206+
)
207+
)
208+
self._model_output_def = ModelOutputDef(FittingOutputDef(fitting_defs))
209+
162210
def _load_pte(self, model_file: str) -> None:
163-
"""Load a .pte (torch.export) model file."""
211+
"""Load a .pte (torch.export) model file.
212+
213+
``model.json`` is optional: when present it is used to reconstruct
214+
the dpmodel instance (enabling dpmodel-level introspection such as
215+
``eval_descriptor``); when absent we fall back to pure metadata
216+
mode via :meth:`_init_from_metadata`. ``metadata.json`` is the
217+
only contract the inference path actually requires.
218+
"""
164219
extra_files = {
165220
"model.json": "",
166221
"model_def_script.json": "",
167222
"metadata.json": "",
168223
}
169224
exported = torch.export.load(model_file, extra_files=extra_files)
170225
self.exported_module = exported.module()
171-
self._init_from_model_json(extra_files["model.json"])
172226
mds = extra_files["model_def_script.json"]
173227
self._model_def_script = json.loads(mds) if mds else {}
174228
md = extra_files["metadata.json"]
175-
self.metadata = json.loads(md) if md else {}
229+
if not md:
230+
raise ValueError(
231+
f"Invalid .pte file '{model_file}': missing 'metadata.json'"
232+
)
233+
self.metadata = json.loads(md)
234+
235+
model_json_str = extra_files["model.json"]
236+
if model_json_str:
237+
self._init_from_model_json(model_json_str)
238+
else:
239+
self._init_from_metadata()
176240

177241
def _load_pt2(self, model_file: str) -> None:
178-
"""Load a .pt2 (AOTInductor) model file."""
242+
"""Load a .pt2 (AOTInductor) model file.
243+
244+
``extra/model.json`` is optional — it only enables the dpmodel
245+
round-trip (used by ``eval_descriptor``, ``eval_typeebd``, etc.).
246+
Pure AOTI inference (``DeepPot.eval`` / ``dp test`` / ASE
247+
calculator) only needs ``extra/metadata.json``, matching the
248+
contract the C++ ``DeepPotPTExpt`` reader enforces. Backends that
249+
cannot produce ``model.json`` can therefore still be loaded for inference.
250+
"""
179251
import zipfile
180252

181253
from torch._inductor import (
@@ -185,21 +257,24 @@ def _load_pt2(self, model_file: str) -> None:
185257
# Read metadata from the .pt2 ZIP archive
186258
with zipfile.ZipFile(model_file, "r") as zf:
187259
names = zf.namelist()
188-
if "extra/model.json" not in names:
260+
if "extra/metadata.json" not in names:
189261
raise ValueError(
190-
f"Invalid .pt2 file '{model_file}': missing 'extra/model.json'"
262+
f"Invalid .pt2 file '{model_file}': missing 'extra/metadata.json'"
191263
)
192-
model_json_str = zf.read("extra/model.json").decode("utf-8")
264+
md = zf.read("extra/metadata.json").decode("utf-8")
265+
model_json_str = ""
266+
if "extra/model.json" in names:
267+
model_json_str = zf.read("extra/model.json").decode("utf-8")
193268
mds = ""
194269
if "extra/model_def_script.json" in names:
195270
mds = zf.read("extra/model_def_script.json").decode("utf-8")
196-
md = ""
197-
if "extra/metadata.json" in names:
198-
md = zf.read("extra/metadata.json").decode("utf-8")
199271

200-
self._init_from_model_json(model_json_str)
272+
self.metadata = json.loads(md)
201273
self._model_def_script = json.loads(mds) if mds else {}
202-
self.metadata = json.loads(md) if md else {}
274+
if model_json_str:
275+
self._init_from_model_json(model_json_str)
276+
else:
277+
self._init_from_metadata()
203278

204279
# Load the AOTInductor model package (.pt2 ZIP archive).
205280
# Uses torch._inductor.aoti_load_package (private API, stable since PyTorch 2.6).
@@ -220,16 +295,29 @@ def get_type_map(self) -> list[str]:
220295

221296
def get_dim_fparam(self) -> int:
222297
"""Get the number (dimension) of frame parameters of this DP."""
223-
return self._dpmodel.get_dim_fparam()
298+
if self._dpmodel is not None:
299+
return self._dpmodel.get_dim_fparam()
300+
return int(self.metadata["dim_fparam"])
224301

225302
def get_dim_aparam(self) -> int:
226303
"""Get the number (dimension) of atomic parameters of this DP."""
227-
return self._dpmodel.get_dim_aparam()
304+
if self._dpmodel is not None:
305+
return self._dpmodel.get_dim_aparam()
306+
return int(self.metadata["dim_aparam"])
228307

229308
@property
230309
def model_type(self) -> type["DeepEvalWrapper"]:
231310
"""The the evaluator of the model type."""
232-
model_output_type = self._dpmodel.model_output_type()
311+
if self._dpmodel is not None:
312+
model_output_type = self._dpmodel.model_output_type()
313+
else:
314+
# Metadata-only mode: derive the output-type set from the
315+
# fitting_output_defs names. `model_output_type()` on a
316+
# dpmodel is the same set — just the base output names, not
317+
# their derived `*_redu` / `*_derv_*` twins.
318+
model_output_type = [
319+
d.name for d in self._model_output_def.def_outp.get_data().values()
320+
]
233321
if "energy" in model_output_type:
234322
return DeepPot
235323
elif "dos" in model_output_type:
@@ -250,7 +338,12 @@ def get_sel_type(self) -> list[int]:
250338
to the result of the model.
251339
If returning an empty list, all atom types are selected.
252340
"""
253-
return self._dpmodel.get_sel_type()
341+
if self._dpmodel is not None:
342+
return self._dpmodel.get_sel_type()
343+
# Metadata-only mode: read the `sel_type` field populated by
344+
# `_collect_metadata`. Missing field → `[]` (every type
345+
# selected), matching the dpmodel default for energy models.
346+
return [int(t) for t in self.metadata.get("sel_type", [])]
254347

255348
def get_numb_dos(self) -> int:
256349
"""Get the number of DOS."""
@@ -266,9 +359,11 @@ def get_has_spin(self) -> bool:
266359

267360
def get_use_spin(self) -> list[bool]:
268361
"""Get the per-type spin usage of this model."""
269-
if getattr(self, "_is_spin", False):
362+
if not getattr(self, "_is_spin", False):
363+
return []
364+
if self._dpmodel is not None:
270365
return self._dpmodel.spin.use_spin.tolist()
271-
return []
366+
return [bool(v) for v in self.metadata.get("use_spin", [])]
272367

273368
def get_ntypes_spin(self) -> int:
274369
"""Get the number of spin atom types of this model. Only used in old implementations."""
@@ -423,8 +518,11 @@ def _build_nlist_native(
423518
nframes = coords.shape[0]
424519
natoms = coords.shape[1]
425520
rcut = self.rcut
426-
sel = self._dpmodel.get_sel()
427-
mixed_types = self._dpmodel.mixed_types()
521+
# ``self.sel`` / ``self.mixed_types`` are populated in both
522+
# :meth:`_init_from_model_json` and :meth:`_init_from_metadata`,
523+
# so this works whether or not ``model.json`` was available.
524+
sel = self.sel
525+
mixed_types = self.mixed_types
428526

429527
if cells is not None:
430528
box_input = cells.reshape(nframes, 3, 3)
@@ -535,8 +633,8 @@ def _build_nlist_ase_single(
535633
nlist : np.ndarray, shape (nloc, nsel)
536634
mapping : np.ndarray, shape (nall,)
537635
"""
538-
sel = self._dpmodel.get_sel()
539-
mixed_types = self._dpmodel.mixed_types()
636+
sel = self.sel
637+
mixed_types = self.mixed_types
540638
nsel = sum(sel)
541639

542640
natoms = positions.shape[0]
@@ -995,13 +1093,44 @@ def get_model(self) -> torch.nn.Module:
9951093
return self.exported_module
9961094

9971095
def _is_spin_model(self) -> bool:
998-
"""Check if the underlying dpmodel is a SpinModel."""
1096+
"""Check if the underlying model is a SpinModel.
1097+
1098+
Primary path: the :attr:`_is_spin` attribute set by the loaders
1099+
— this works for both ``model.json`` and metadata-only archives
1100+
(a spin ``.pt2`` carries ``is_spin=true`` in its metadata).
1101+
1102+
Legacy path: ``isinstance(_dpmodel, SpinModel)`` — retained for
1103+
tests that construct a non-spin archive and then swap
1104+
:attr:`_dpmodel` to a :class:`SpinModel` instance after load.
1105+
"""
1106+
if bool(getattr(self, "_is_spin", False)):
1107+
return True
1108+
if self._dpmodel is None:
1109+
return False
9991110
from deepmd.dpmodel.model.spin_model import (
10001111
SpinModel,
10011112
)
10021113

10031114
return isinstance(self._dpmodel, SpinModel)
10041115

1116+
def _require_dpmodel(self, feature: str) -> None:
1117+
"""Guard for features that need a deserialised dpmodel instance.
1118+
1119+
``eval_descriptor`` / ``eval_typeebd`` / ``eval_fitting_last_layer``
1120+
all introspect the dpmodel's internal sub-modules, which requires
1121+
``extra/model.json`` to have been present at load time. Archives
1122+
shipped without ``model.json`` (metadata-only mode) can still run
1123+
the main ``eval`` inference path but cannot expose these hooks.
1124+
"""
1125+
if self._dpmodel is None:
1126+
raise NotImplementedError(
1127+
f"{feature} requires the dpmodel instance, which is only "
1128+
"available when the .pt2 / .pte archive contains "
1129+
"'extra/model.json'. The loaded archive is metadata-only; "
1130+
"re-export with the full dpmodel serialisation to enable "
1131+
"this feature."
1132+
)
1133+
10051134
def eval_typeebd(self) -> np.ndarray:
10061135
"""Evaluate type embedding.
10071136
@@ -1014,7 +1143,11 @@ def eval_typeebd(self) -> np.ndarray:
10141143
------
10151144
KeyError
10161145
If the model has no type embedding networks.
1146+
NotImplementedError
1147+
If the archive was loaded in metadata-only mode.
10171148
"""
1149+
self._require_dpmodel("eval_typeebd")
1150+
10181151
from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP
10191152

10201153
model = self._dpmodel
@@ -1058,6 +1191,8 @@ def eval_descriptor(
10581191
np.ndarray
10591192
Descriptor output, shape ``(nframes, nloc, dim_descrpt)``.
10601193
"""
1194+
self._require_dpmodel("eval_descriptor")
1195+
10611196
coords = np.array(coords)
10621197
atom_types = np.array(atom_types, dtype=np.int32)
10631198
if cells is not None:
@@ -1124,6 +1259,8 @@ def eval_fitting_last_layer(
11241259
np.ndarray
11251260
Middle-layer output, shape ``(nframes, nloc, neuron[-1])``.
11261261
"""
1262+
self._require_dpmodel("eval_fitting_last_layer")
1263+
11271264
coords = np.array(coords)
11281265
atom_types = np.array(atom_types, dtype=np.int32)
11291266
if cells is not None:

deepmd/pt_expt/utils/serialization.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,9 @@ def _collect_metadata(model: torch.nn.Module, is_spin: bool = False) -> dict:
247247
"r_differentiable": vdef.r_differentiable,
248248
"c_differentiable": vdef.c_differentiable,
249249
"atomic": vdef.atomic,
250-
"category": vdef.category,
250+
# OutputVariableCategory is an IntEnum; force plain int for
251+
# deterministic JSON serialisation across Python versions.
252+
"category": int(vdef.category),
251253
"r_hessian": vdef.r_hessian,
252254
"magnetic": vdef.magnetic,
253255
"intensive": vdef.intensive,
@@ -263,6 +265,10 @@ def _collect_metadata(model: torch.nn.Module, is_spin: bool = False) -> dict:
263265
"has_default_fparam": model.has_default_fparam(),
264266
"default_fparam": model.get_default_fparam(),
265267
"fitting_output_defs": fitting_output_defs,
268+
# sel_type enables `DeepEval.get_sel_type()` without a dpmodel
269+
# round-trip; required for dipole/polar/wfc models in metadata-only
270+
# inference (energy models return []).
271+
"sel_type": [int(t) for t in model.get_sel_type()],
266272
"is_spin": is_spin,
267273
}
268274
if is_spin:

0 commit comments

Comments
 (0)