Skip to content

Commit 7675388

Browse files
committed
fix(show): align backend serialize output contracts
Route JAX DeepEval serialization through the existing file-based serializer so .hlo and .savedmodel models follow the supported path instead of calling unimplemented model-level serialize() methods. Also add the missing pt_version field to the PyTorch backend serializer and wrap pt_expt serialization in the backend-unified payload expected by dp show serialization-tree. Add a targeted pt_expt serialization contract test. Authored by OpenClaw (model: gpt-5.4)
1 parent 1694360 commit 7675388

4 files changed

Lines changed: 23 additions & 12 deletions

File tree

deepmd/jax/infer/deep_eval.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -191,17 +191,11 @@ def get_ntypes_spin(self) -> int:
191191
return 0
192192

193193
def serialize(self) -> dict[str, Any]:
194-
model = self.dp
195-
data: dict[str, Any] = {
196-
"backend": "JAX",
197-
"jax_version": jax.__version__,
198-
"model": model.serialize(),
199-
"model_def_script": json.loads(model.get_model_def_script()),
200-
"@variables": {},
201-
}
202-
if model.get_min_nbor_dist() is not None:
203-
data["@variables"]["min_nbor_dist"] = model.get_min_nbor_dist()
204-
return data
194+
from deepmd.jax.utils.serialization import (
195+
serialize_from_file,
196+
)
197+
198+
return serialize_from_file(self.model_path)
205199

206200
def eval(
207201
self,

deepmd/pt/infer/deep_eval.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,7 @@ def serialize(self) -> dict[str, Any]:
706706
model = self.dp.model["Default"]
707707
data: dict[str, Any] = {
708708
"backend": "PyTorch",
709+
"pt_version": str(torch.__version__),
709710
"model": model.serialize(),
710711
"model_def_script": self.get_model_def_script(),
711712
"@variables": {},

deepmd/pt_expt/infer/deep_eval.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,13 @@ def serialize(self) -> dict[str, Any]:
670670
serialize_from_file,
671671
)
672672

673-
return serialize_from_file(self.model_path)
673+
model_dict = serialize_from_file(self.model_path)
674+
return {
675+
"backend": "PyTorch Exportable",
676+
"model": model_dict,
677+
"model_def_script": self.get_model_def_script(),
678+
"@variables": {},
679+
}
674680

675681
def get_model(self) -> torch.nn.Module:
676682
"""Get the exported model module.

source/tests/pt_expt/infer/test_deep_eval.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,16 @@ def test_get_model_def_script(self) -> None:
109109
self.assertAlmostEqual(mds["rcut"], self.rcut)
110110
self.assertEqual(mds["sel"], list(self.sel))
111111

112+
def test_serialize_contract(self) -> None:
113+
data = self.dp.deep_eval.serialize()
114+
self.assertEqual(data["backend"], "PyTorch Exportable")
115+
self.assertIn("model", data)
116+
self.assertIn("model_def_script", data)
117+
self.assertIn("@variables", data)
118+
self.assertIsInstance(data["@variables"], dict)
119+
self.assertEqual(data["model_def_script"]["type_map"], self.type_map)
120+
self.assertEqual(data["model"], serialize_from_file(self.tmpfile.name))
121+
112122
def test_eval_consistency(self) -> None:
113123
"""Test that DeepPot.eval gives same results as direct model forward."""
114124
rng = np.random.default_rng(GLOBAL_SEED)

0 commit comments

Comments
 (0)