Skip to content

Commit add16fa

Browse files
committed
fixup
1 parent 1e017bb commit add16fa

4 files changed

Lines changed: 140 additions & 34 deletions

File tree

deepmd/pt_expt/infer/deep_eval.py

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,10 @@ def _init_from_model_json(self, model_json_str: str) -> None:
138138
self._dpmodel = BaseModel.deserialize(model_data)
139139
self._is_spin = False
140140

141-
self.rcut = self._dpmodel.get_rcut()
142-
self.type_map = self._dpmodel.get_type_map()
143-
# Hoist sel / mixed_types to plain attributes so the inference hot
144-
# path (`_build_nlist_*`) does not need the dpmodel instance. This
145-
# matches what `_init_from_metadata` sets and keeps both code paths
146-
# numerically identical.
147-
self.sel = list(self._dpmodel.get_sel())
148-
self.mixed_types = bool(self._dpmodel.mixed_types())
141+
self._rcut = self._dpmodel.get_rcut()
142+
self._type_map = self._dpmodel.get_type_map()
143+
self._sel = list(self._dpmodel.get_sel())
144+
self._mixed_types = bool(self._dpmodel.mixed_types())
149145
if self._is_spin:
150146
self._model_output_def = ModelOutputDef(
151147
FittingOutputDef(
@@ -177,15 +173,15 @@ def _init_from_metadata(self) -> None:
177173
``self._dpmodel`` is left as ``None`` to signal the metadata-only
178174
mode. Inference does not need it: it runs through
179175
``aoti_load_package`` / the exported module and uses plain
180-
attributes (``self.rcut``, ``self.sel``, ``self.mixed_types``,
176+
attributes (``self._rcut``, ``self._sel``, ``self._mixed_types``,
181177
``self._model_output_def``) for all metadata-level queries.
182178
"""
183179
self._dpmodel = None
184180
self._is_spin = bool(self.metadata.get("is_spin", False))
185-
self.rcut = float(self.metadata["rcut"])
186-
self.type_map = list(self.metadata["type_map"])
187-
self.sel = [int(s) for s in self.metadata["sel"]]
188-
self.mixed_types = bool(self.metadata["mixed_types"])
181+
self._rcut = float(self.metadata["rcut"])
182+
self._type_map = list(self.metadata["type_map"])
183+
self._sel = [int(s) for s in self.metadata["sel"]]
184+
self._mixed_types = bool(self.metadata["mixed_types"])
189185

190186
fitting_defs = []
191187
for vdef in self.metadata["fitting_output_defs"]:
@@ -294,15 +290,15 @@ def _load_pt2(self, model_file: str) -> None:
294290

295291
def get_rcut(self) -> float:
296292
"""Get the cutoff radius of this model."""
297-
return self.rcut
293+
return self._rcut
298294

299295
def get_ntypes(self) -> int:
300296
"""Get the number of atom types of this model."""
301-
return len(self.type_map)
297+
return len(self._type_map)
302298

303299
def get_type_map(self) -> list[str]:
304300
"""Get the type map (element name of the atom types) of this model."""
305-
return self.type_map
301+
return self._type_map
306302

307303
def get_dim_fparam(self) -> int:
308304
"""Get the number (dimension) of frame parameters of this DP."""
@@ -318,7 +314,7 @@ def get_dim_aparam(self) -> int:
318314

319315
@property
320316
def model_type(self) -> type["DeepEvalWrapper"]:
321-
"""The the evaluator of the model type."""
317+
"""The evaluator of the model type."""
322318
if self._dpmodel is not None:
323319
model_output_type = self._dpmodel.model_output_type()
324320
else:
@@ -366,11 +362,11 @@ def get_has_efield(self) -> bool:
366362

367363
def get_has_spin(self) -> bool:
368364
"""Check if the model has spin atom types."""
369-
return getattr(self, "_is_spin", False)
365+
return self._is_spin
370366

371367
def get_use_spin(self) -> list[bool]:
372368
"""Get the per-type spin usage of this model."""
373-
if not getattr(self, "_is_spin", False):
369+
if not self._is_spin:
374370
return []
375371
if self._dpmodel is not None:
376372
return self._dpmodel.spin.use_spin.tolist()
@@ -528,12 +524,9 @@ def _build_nlist_native(
528524
"""
529525
nframes = coords.shape[0]
530526
natoms = coords.shape[1]
531-
rcut = self.rcut
532-
# ``self.sel`` / ``self.mixed_types`` are populated in both
533-
# :meth:`_init_from_model_json` and :meth:`_init_from_metadata`,
534-
# so this works whether or not ``model.json`` was available.
535-
sel = self.sel
536-
mixed_types = self.mixed_types
527+
rcut = self._rcut
528+
sel = self._sel
529+
mixed_types = self._mixed_types
537530

538531
if cells is not None:
539532
box_input = cells.reshape(nframes, 3, 3)
@@ -644,8 +637,8 @@ def _build_nlist_ase_single(
644637
nlist : np.ndarray, shape (nloc, nsel)
645638
mapping : np.ndarray, shape (nall,)
646639
"""
647-
sel = self.sel
648-
mixed_types = self.mixed_types
640+
sel = self._sel
641+
mixed_types = self._mixed_types
649642
nsel = sum(sel)
650643

651644
natoms = positions.shape[0]
@@ -688,7 +681,7 @@ def _build_nlist_ase_single(
688681
ghost_remap[out_mask] = np.arange(nloc, nloc + nghost, dtype=np.int64)
689682

690683
# Build nlist: vectorized CSR-to-dense conversion
691-
rcut = self.rcut
684+
rcut = self._rcut
692685
counts = np.diff(first_neigh)
693686
max_nn = int(counts.max()) if counts.size > 0 else 0
694687

source/tests/pt_expt/infer/test_deep_eval.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,13 +551,26 @@ def setUpClass(cls) -> None:
551551
finally:
552552
torch.set_default_device("cuda:9999999")
553553

554+
cls.meta_tmpfile = tempfile.NamedTemporaryFile(suffix=".pt2", delete=False)
555+
cls.meta_tmpfile.close()
556+
with (
557+
zipfile.ZipFile(cls.tmpfile.name, "r") as zin,
558+
zipfile.ZipFile(cls.meta_tmpfile.name, "w") as zout,
559+
):
560+
for info in zin.infolist():
561+
if info.filename == "model/extra/model.json":
562+
continue
563+
zout.writestr(info, zin.read(info.filename))
564+
554565
# Also save to .pte for cross-format comparison
555566
cls.pte_tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False)
556567
cls.pte_tmpfile.close()
557568
deserialize_to_file(cls.pte_tmpfile.name, cls.model_data)
558569

559570
# Create DeepPot for .pt2
560571
cls.dp = DeepPot(cls.tmpfile.name)
572+
# Create DeepPot for metadata-only .pt2
573+
cls.dp_meta = DeepPot(cls.meta_tmpfile.name)
561574
# Create DeepPot for .pte reference
562575
cls.dp_pte = DeepPot(cls.pte_tmpfile.name)
563576

@@ -566,6 +579,7 @@ def tearDownClass(cls) -> None:
566579
import os
567580

568581
os.unlink(cls.tmpfile.name)
582+
os.unlink(cls.meta_tmpfile.name)
569583
os.unlink(cls.pte_tmpfile.name)
570584

571585
def test_get_rcut(self) -> None:
@@ -647,6 +661,44 @@ def test_pt2_has_metadata(self) -> None:
647661
self.assertNotIn("model/extra/output_keys.json", names)
648662
self.assertNotIn("model/extra/model_params.json", names)
649663

664+
def test_metadata_only_pt2_has_no_model_json(self) -> None:
665+
"""The metadata-only .pt2 keeps metadata but drops model.json."""
666+
with zipfile.ZipFile(self.meta_tmpfile.name, "r") as zf:
667+
names = zf.namelist()
668+
self.assertIn("model/extra/metadata.json", names)
669+
self.assertNotIn("model/extra/model.json", names)
670+
671+
def test_metadata_only_pt2_accessors_match(self) -> None:
672+
"""Metadata-only .pt2 archives expose the same metadata API."""
673+
full = self.dp.deep_eval
674+
meta = self.dp_meta.deep_eval
675+
self.assertIsNotNone(full._dpmodel)
676+
self.assertIsNone(meta._dpmodel)
677+
self.assertEqual(full.get_rcut(), meta.get_rcut())
678+
self.assertEqual(full.get_ntypes(), meta.get_ntypes())
679+
self.assertEqual(full.get_type_map(), meta.get_type_map())
680+
self.assertEqual(full.get_dim_fparam(), meta.get_dim_fparam())
681+
self.assertEqual(full.get_dim_aparam(), meta.get_dim_aparam())
682+
self.assertEqual(full.get_sel_type(), meta.get_sel_type())
683+
self.assertEqual(full.get_has_spin(), meta.get_has_spin())
684+
self.assertEqual(full.get_use_spin(), meta.get_use_spin())
685+
self.assertIs(full.model_type, meta.model_type)
686+
687+
def test_metadata_only_pt2_eval_parity(self) -> None:
688+
"""Metadata-only .pt2 inference matches the full archive exactly."""
689+
rng = np.random.default_rng(GLOBAL_SEED + 29)
690+
natoms = 5
691+
coords = rng.random((1, natoms, 3)) * 8.0
692+
cells = np.eye(3).reshape(1, 9) * 10.0
693+
atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32)
694+
695+
full_out = self.dp.eval(coords, cells, atom_types, atomic=True)
696+
meta_out = self.dp_meta.eval(coords, cells, atom_types, atomic=True)
697+
698+
self.assertEqual(len(full_out), len(meta_out))
699+
for ref, test in zip(full_out, meta_out, strict=True):
700+
np.testing.assert_array_equal(test, ref)
701+
650702
def test_eval_consistency(self) -> None:
651703
"""Test that DeepPot.eval gives same results as direct model forward."""
652704
rng = np.random.default_rng(GLOBAL_SEED)

source/tests/pt_expt/infer/test_deep_eval_metadata_only.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
``.pte`` (the fast path; ``.pt2`` AOTInductor compilation is too
1313
heavy for a routine unit test).
1414
2. Read back that ``.pte`` and record the reference outputs.
15-
3. Rewrite the archive byte-for-byte except for the ``extra/model.json``
16-
entry, producing a metadata-only variant.
15+
3. Copy all archive entries except ``extra/model.json`` into a
16+
metadata-only variant.
1717
4. Load the metadata-only archive via ``DeepPot`` and assert that the
1818
metadata-level accessors and the numeric ``eval`` result are
1919
**bitwise identical** to the reference.
@@ -151,10 +151,10 @@ def test_internal_attributes_match(self) -> None:
151151
"""The hot-path attributes hoisted in both init paths must agree."""
152152
full = self.dp_full.deep_eval
153153
meta = self.dp_meta.deep_eval
154-
self.assertEqual(list(full.sel), list(meta.sel))
155-
self.assertEqual(bool(full.mixed_types), bool(meta.mixed_types))
156-
self.assertEqual(full.rcut, meta.rcut)
157-
self.assertEqual(list(full.type_map), list(meta.type_map))
154+
self.assertEqual(list(full._sel), list(meta._sel))
155+
self.assertEqual(bool(full._mixed_types), bool(meta._mixed_types))
156+
self.assertEqual(full._rcut, meta._rcut)
157+
self.assertEqual(list(full._type_map), list(meta._type_map))
158158

159159
def test_dpmodel_presence(self) -> None:
160160
"""``_dpmodel`` is the single signal that separates the two modes."""

source/tests/pt_expt/infer/test_deep_eval_spin.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import copy
99
import os
1010
import tempfile
11+
import zipfile
1112

1213
import numpy as np
1314
import pytest
@@ -105,6 +106,37 @@
105106
BOX = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], dtype=np.float64)
106107

107108

109+
def _strip_extra_model_json(src: str, dst: str) -> None:
110+
"""Copy ``src`` to ``dst`` dropping any ``extra/model.json`` entry."""
111+
with zipfile.ZipFile(src, "r") as zin, zipfile.ZipFile(dst, "w") as zout:
112+
for info in zin.infolist():
113+
if info.filename.endswith("extra/model.json"):
114+
continue
115+
zout.writestr(info, zin.read(info.filename))
116+
117+
118+
def _assert_fitting_output_defs_match(full_eval, meta_eval) -> None:
119+
"""Assert that metadata rebuilds the same fitting output definitions."""
120+
full_defs = full_eval._model_output_def.def_outp.get_data()
121+
meta_defs = meta_eval._model_output_def.def_outp.get_data()
122+
assert full_defs.keys() == meta_defs.keys()
123+
attrs = (
124+
"shape",
125+
"reducible",
126+
"r_differentiable",
127+
"c_differentiable",
128+
"atomic",
129+
"category",
130+
"r_hessian",
131+
"magnetic",
132+
"intensive",
133+
)
134+
for name, full_def in full_defs.items():
135+
meta_def = meta_defs[name]
136+
for attr in attrs:
137+
assert getattr(meta_def, attr) == getattr(full_def, attr)
138+
139+
108140
def _build_reference():
109141
"""Build pt_expt model and run eager reference inference.
110142
@@ -163,6 +195,9 @@ def spin_model_files():
163195
finally:
164196
torch.set_default_device(prev)
165197
files[ext] = path
198+
meta_path = os.path.join(tmpdir, "spin_test_metadata_only.pte")
199+
_strip_extra_model_json(files[".pte"], meta_path)
200+
files[".pte.meta"] = meta_path
166201
yield files, ref_pbc, ref_nopbc
167202
for path in files.values():
168203
if os.path.exists(path):
@@ -341,6 +376,32 @@ def test_eval_nopbc_nonatomic(self, spin_model_files, ext) -> None:
341376
)
342377

343378

379+
class TestSpinMetadataOnly:
380+
"""Test metadata-only spin model inference through DeepPot."""
381+
382+
def test_metadata_only_spin_pte_parity(self, spin_model_files) -> None:
383+
"""Metadata-only spin .pte matches full archive metadata and outputs."""
384+
from deepmd.infer import (
385+
DeepPot,
386+
)
387+
388+
files, _, _ = spin_model_files
389+
full_dp = DeepPot(files[".pte"])
390+
meta_dp = DeepPot(files[".pte.meta"])
391+
392+
assert meta_dp.has_spin == full_dp.has_spin
393+
assert meta_dp.use_spin == full_dp.use_spin
394+
assert meta_dp.get_ntypes_spin() == full_dp.get_ntypes_spin()
395+
_assert_fitting_output_defs_match(full_dp.deep_eval, meta_dp.deep_eval)
396+
397+
full_out = full_dp.eval(COORD, BOX, ATYPE, atomic=True, spin=SPIN)
398+
meta_out = meta_dp.eval(COORD, BOX, ATYPE, atomic=True, spin=SPIN)
399+
400+
assert len(full_out) == len(meta_out)
401+
for ref, test in zip(full_out, meta_out, strict=True):
402+
np.testing.assert_array_equal(test, ref)
403+
404+
344405
SPIN_FPARAM_CONFIG = copy.deepcopy(SPIN_CONFIG)
345406
SPIN_FPARAM_CONFIG["fitting_net"]["numb_fparam"] = 1
346407
SPIN_FPARAM_CONFIG["fitting_net"]["default_fparam"] = [0.5]

0 commit comments

Comments
 (0)