Skip to content

Commit 1e017bb

Browse files
committed
fix(pt_expt): store .pt2 metadata under model/extra/ to match PyTorch 2.11 layout
``aoti_compile_and_package`` writes every entry of the compiled ``.pt2`` archive under a top-level ``model/`` directory; deepmd-kit then appended its own metadata JSON blobs (``metadata.json``, ``model.json``, ``model_def_script.json``) at the root-level ``extra/`` path via ``zipfile``. Starting with PyTorch 2.11, the strict single-model loader ``torch.export.pt2_archive._package.load_pt2`` refuses archives that carry files outside ``model/``: RuntimeError: [enforce fail at inline_container.cc:340] . file in archive is not in a subdirectory model/: extra/metadata.json ``torch._inductor.package.package.load_package`` catches this error and falls back to the legacy C++ loader, but prints the misleading warning ``Loading outdated pt2 file. Please regenerate your package.`` every time the archive is opened -- even though the archive version itself (``archive_version == '0'``) is already current. Move the deepmd-kit metadata blobs into ``model/extra/`` so that the fast path through ``load_pt2`` accepts the archive cleanly and the misleading warning disappears. A module-level constant ``PT2_EXTRA_PREFIX`` in ``deepmd.pt_expt.utils.serialization`` is the single source of truth for the prefix; both the writer (``_deserialize_to_file_pt2``) and the readers (``_serialize_from_file_pt2``, ``DeepEval._load_pt2``) derive their entry names from it. The C++ reader in ``source/api_cc/src/commonPTExpt.h::read_zip_entry`` needs no changes: it already matches ``entry_name`` as a ``/``-delimited suffix, so ``"extra/metadata.json"`` resolves against both the old root-level and the new ``model/extra/`` location transparently. The ``test_pt2_has_metadata`` assertion in ``source/tests/pt_expt/infer/test_deep_eval.py`` is updated to expect the new paths.
1 parent 1635a0a commit 1e017bb

3 files changed

Lines changed: 72 additions & 33 deletions

File tree

deepmd/pt_expt/infer/deep_eval.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,9 @@ def _init_from_model_json(self, model_json_str: str) -> None:
166166
self._model_output_def = ModelOutputDef(self._dpmodel.atomic_output_def())
167167

168168
def _init_from_metadata(self) -> None:
169-
"""Initialize DeepEval from ``extra/metadata.json`` alone.
169+
"""Initialize DeepEval from ``metadata.json`` alone.
170170
171-
Used when the ``.pt2`` / ``.pte`` archive ships no ``extra/model.json``
171+
Used when the ``.pt2`` / ``.pte`` archive ships no ``model.json``
172172
(e.g. for backends that do not travel through the dpmodel round-trip).
173173
The metadata contract is the same one the C++ ``DeepPotPTExpt``
174174
reader consumes, so anything that validates against the C++ side
@@ -241,33 +241,44 @@ def _load_pte(self, model_file: str) -> None:
241241
def _load_pt2(self, model_file: str) -> None:
242242
"""Load a .pt2 (AOTInductor) model file.
243243
244-
``extra/model.json`` is optional — it only enables the dpmodel
244+
``model.json`` is optional — it only enables the dpmodel
245245
round-trip (used by ``eval_descriptor``, ``eval_typeebd``, etc.).
246246
Pure AOTI inference (``DeepPot.eval`` / ``dp test`` / ASE
247-
calculator) only needs ``extra/metadata.json``, matching the
248-
contract the C++ ``DeepPotPTExpt`` reader enforces. Backends that
249-
cannot produce ``model.json``.
247+
calculator) only needs ``metadata.json``, matching the contract
248+
the C++ ``DeepPotPTExpt`` reader enforces.
249+
250+
Archive entries are located under ``model/extra/`` so that the
251+
PyTorch 2.11 ``load_pt2`` loader accepts the archive without the
252+
"outdated pt2 file" fallback warning.
250253
"""
251254
import zipfile
252255

253256
from torch._inductor import (
254257
aoti_load_package,
255258
)
256259

260+
from deepmd.pt_expt.utils.serialization import (
261+
PT2_EXTRA_PREFIX,
262+
)
263+
264+
md_entry = PT2_EXTRA_PREFIX + "metadata.json"
265+
model_json_entry = PT2_EXTRA_PREFIX + "model.json"
266+
mds_entry = PT2_EXTRA_PREFIX + "model_def_script.json"
267+
257268
# Read metadata from the .pt2 ZIP archive
258269
with zipfile.ZipFile(model_file, "r") as zf:
259270
names = zf.namelist()
260-
if "extra/metadata.json" not in names:
271+
if md_entry not in names:
261272
raise ValueError(
262-
f"Invalid .pt2 file '{model_file}': missing 'extra/metadata.json'"
273+
f"Invalid .pt2 file '{model_file}': missing '{md_entry}'"
263274
)
264-
md = zf.read("extra/metadata.json").decode("utf-8")
275+
md = zf.read(md_entry).decode("utf-8")
265276
model_json_str = ""
266-
if "extra/model.json" in names:
267-
model_json_str = zf.read("extra/model.json").decode("utf-8")
277+
if model_json_entry in names:
278+
model_json_str = zf.read(model_json_entry).decode("utf-8")
268279
mds = ""
269-
if "extra/model_def_script.json" in names:
270-
mds = zf.read("extra/model_def_script.json").decode("utf-8")
280+
if mds_entry in names:
281+
mds = zf.read(mds_entry).decode("utf-8")
271282

272283
self.metadata = json.loads(md)
273284
self._model_def_script = json.loads(mds) if mds else {}
@@ -1118,15 +1129,15 @@ def _require_dpmodel(self, feature: str) -> None:
11181129
11191130
``eval_descriptor`` / ``eval_typeebd`` / ``eval_fitting_last_layer``
11201131
all introspect the dpmodel's internal sub-modules, which requires
1121-
``extra/model.json`` to have been present at load time. Archives
1132+
``model.json`` to have been present at load time. Archives
11221133
shipped without ``model.json`` (metadata-only mode) can still run
11231134
the main ``eval`` inference path but cannot expose these hooks.
11241135
"""
11251136
if self._dpmodel is None:
11261137
raise NotImplementedError(
11271138
f"{feature} requires the dpmodel instance, which is only "
11281139
"available when the .pt2 / .pte archive contains "
1129-
"'extra/model.json'. The loaded archive is metadata-only; "
1140+
"'model.json'. The loaded archive is metadata-only; "
11301141
"re-export with the full dpmodel serialisation to enable "
11311142
"this feature."
11321143
)

deepmd/pt_expt/utils/serialization.py

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,25 @@
1515
traverse_model_dict,
1616
)
1717

18+
# ---------------------------------------------------------------------------
19+
# AOTInductor ``.pt2`` archive layout.
20+
#
21+
# PyTorch 2.11 tightened the single-model ``.pt2`` convention so that every
22+
# entry in the ZIP archive must live under the top-level ``model/`` directory.
23+
# Any stray root-level file makes
24+
# ``torch.export.pt2_archive._package.load_pt2`` raise ``RuntimeError`` at
25+
# load time; the upper-level ``torch._inductor.package.package.load_package``
26+
# then emits a misleading ``Loading outdated pt2 file. Please regenerate
27+
# your package.`` warning and falls back to the legacy C++ loader.
28+
#
29+
# deepmd-kit therefore stores its metadata JSON blobs under ``model/extra/``
30+
# so that the strict ``load_pt2`` loader accepts the archive without
31+
# complaint. The C++ reader (``commonPTExpt.h::read_zip_entry``) resolves
32+
# this layout transparently because it matches ``entry_name`` as a
33+
# ``/``-delimited suffix.
34+
# ---------------------------------------------------------------------------
35+
PT2_EXTRA_PREFIX = "model/extra/"
36+
1837

1938
def _strip_shape_assertions(graph_module: torch.nn.Module) -> None:
2039
"""Remove shape-guard assertion nodes from a spin model's exported graph.
@@ -316,21 +335,23 @@ def _serialize_from_file_pte(model_file: str) -> dict:
316335
def _serialize_from_file_pt2(model_file: str) -> dict:
317336
"""Serialize a .pt2 model file to a dictionary.
318337
319-
Reads the model dict stored in the extra/ directory of the .pt2 ZIP archive.
338+
Reads the model dict stored in the ``model/extra/`` directory of the
339+
``.pt2`` ZIP archive.
320340
"""
321341
import zipfile
322342

343+
model_json_entry = PT2_EXTRA_PREFIX + "model.json"
344+
model_def_script_entry = PT2_EXTRA_PREFIX + "model_def_script.json"
323345
with zipfile.ZipFile(model_file, "r") as zf:
324-
if "extra/model.json" not in zf.namelist():
346+
names = zf.namelist()
347+
if model_json_entry not in names:
325348
raise ValueError(
326-
f"Invalid .pt2 file '{model_file}': missing 'extra/model.json'"
349+
f"Invalid .pt2 file '{model_file}': missing '{model_json_entry}'"
327350
)
328-
model_json = zf.read("extra/model.json").decode("utf-8")
351+
model_json = zf.read(model_json_entry).decode("utf-8")
329352
model_def_script_json = ""
330-
if "extra/model_def_script.json" in zf.namelist():
331-
model_def_script_json = zf.read("extra/model_def_script.json").decode(
332-
"utf-8"
333-
)
353+
if model_def_script_entry in names:
354+
model_def_script_json = zf.read(model_def_script_entry).decode("utf-8")
334355
model_dict = json.loads(model_json)
335356
model_dict = _json_to_numpy(model_dict)
336357
if model_def_script_json:
@@ -563,13 +584,20 @@ def _deserialize_to_file_pt2(
563584
# Compile via AOTInductor into a .pt2 package
564585
aoti_compile_and_package(exported, package_path=model_file)
565586

566-
# Embed metadata into the .pt2 ZIP archive
587+
# Embed metadata into the .pt2 ZIP archive. Entries are placed under
588+
# ``model/extra/`` so the strict PyTorch 2.11 ``load_pt2`` loader
589+
# accepts the archive without emitting the "outdated pt2 file"
590+
# fallback warning. See the module-level comment on
591+
# ``PT2_EXTRA_PREFIX`` for the rationale.
567592
model_def_script = data.get("model_def_script") or {}
568593
metadata["output_keys"] = output_keys
569594
with zipfile.ZipFile(model_file, "a") as zf:
570-
zf.writestr("extra/metadata.json", json.dumps(metadata))
571-
zf.writestr("extra/model_def_script.json", json.dumps(model_def_script))
595+
zf.writestr(PT2_EXTRA_PREFIX + "metadata.json", json.dumps(metadata))
596+
zf.writestr(
597+
PT2_EXTRA_PREFIX + "model_def_script.json",
598+
json.dumps(model_def_script),
599+
)
572600
zf.writestr(
573-
"extra/model.json",
601+
PT2_EXTRA_PREFIX + "model.json",
574602
json.dumps(data_for_json, separators=(",", ":")),
575603
)

source/tests/pt_expt/infer/test_deep_eval.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -638,14 +638,14 @@ def test_pt2_file_is_zip(self) -> None:
638638
self.assertTrue(zipfile.is_zipfile(self.tmpfile.name))
639639

640640
def test_pt2_has_metadata(self) -> None:
641-
"""The .pt2 ZIP should contain metadata entries."""
641+
"""The .pt2 ZIP should contain metadata entries under ``model/extra/``."""
642642
with zipfile.ZipFile(self.tmpfile.name, "r") as zf:
643643
names = zf.namelist()
644-
self.assertIn("extra/metadata.json", names)
645-
self.assertIn("extra/model_def_script.json", names)
646-
self.assertIn("extra/model.json", names)
647-
self.assertNotIn("extra/output_keys.json", names)
648-
self.assertNotIn("extra/model_params.json", names)
644+
self.assertIn("model/extra/metadata.json", names)
645+
self.assertIn("model/extra/model_def_script.json", names)
646+
self.assertIn("model/extra/model.json", names)
647+
self.assertNotIn("model/extra/output_keys.json", names)
648+
self.assertNotIn("model/extra/model_params.json", names)
649649

650650
def test_eval_consistency(self) -> None:
651651
"""Test that DeepPot.eval gives same results as direct model forward."""

0 commit comments

Comments
 (0)