Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
c8de6a6
feat(show): add serialization-tree via DeepEval.serialize
njzjz-bot Mar 16, 2026
b695aaa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 16, 2026
1694360
refactor(deepeval): implement backend serialize via model.serialize
njzjz-bot Mar 16, 2026
7675388
fix(show): align backend serialize output contracts
njzjz-bot Apr 7, 2026
a8a80f4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2026
c18814c
refactor(show): decouple serialization tree from deep eval wrapper
njzjz-bot Apr 8, 2026
5678d78
refactor(deepeval): serialize model tree only
njzjz-bot Apr 8, 2026
8b48f46
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2026
5e97073
test(io): cover deep eval serialization in consistent io
njzjz-bot Apr 8, 2026
8bf3176
fix(review): handle paddle static serialize and io assert
njzjz-bot Apr 8, 2026
6cf9dbf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2026
53f6f12
fix(ci): handle savedmodel and torchscript serialize
njzjz-bot Apr 8, 2026
a2431e8
fix(ci): guard pt serialize fallback shape
njzjz-bot Apr 8, 2026
39aab42
fix(test): pt_expt serialize fallback returns model tree with @class key
njzjz-bot Apr 8, 2026
67f4d1e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2026
024a058
fix(ci): revert pt deep eval mixed pte serialization path
njzjz-bot Apr 8, 2026
b951333
fix(ci): restore backend-specific serialize fallbacks
njzjz-bot Apr 8, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deepmd/dpmodel/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,10 @@ def get_model_def_script(self) -> dict:
"""Get model definition script."""
return json.loads(self.dp.get_model_def_script())

def serialize(self) -> dict[str, Any]:
model = self.dp
return model.serialize()

def get_observed_types(self) -> dict:
"""Get observed types (elements) of the model during data statistics.

Expand Down
7 changes: 7 additions & 0 deletions deepmd/entrypoints/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
Any,
)

from deepmd.dpmodel.utils.serialization import (
Node,
)
from deepmd.infer.deep_eval import (
DeepEval,
)
Expand Down Expand Up @@ -136,3 +139,7 @@ def show(
observed_types = model.get_observed_types()
log.info(f"Number of observed types: {observed_types['type_num']} ")
log.info(f"Observed types: {observed_types['observed_type']} ")

if "serialization-tree" in ATTRIBUTES:
root = Node.deserialize(model.serialize())
log.info("Model serialization tree:\n" + str(root))
15 changes: 15 additions & 0 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,16 @@ def get_model(self) -> Any:
The model module implemented by the deep learning framework.
"""

@abstractmethod
def serialize(self) -> dict[str, Any]:
"""Serialize the loaded model structure only.

Returns
-------
dict
Serialized model tree that can be consumed by ``Node.deserialize``.
"""


class DeepEval(ABC):
"""High-level Deep Evaluator interface.
Expand Down Expand Up @@ -404,6 +414,7 @@ def __init__(
neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None,
**kwargs: Any,
) -> None:
self.model_file = model_file
self.deep_eval = DeepEvalBackend(
model_file,
self.output_def,
Expand All @@ -420,6 +431,10 @@ def __init__(
def output_def(self) -> ModelOutputDef:
"""Returns the output variable definitions."""

def serialize(self) -> dict[str, Any]:
"""Serialize the loaded model structure only."""
return self.deep_eval.serialize()

def get_rcut(self) -> float:
"""Get the cutoff radius of this model."""
return self.deep_eval.get_rcut()
Expand Down
10 changes: 10 additions & 0 deletions deepmd/jax/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,16 @@ def get_ntypes_spin(self) -> int:
"""Get the number of spin atom types of this model."""
return 0

def serialize(self) -> dict[str, Any]:
from deepmd.jax.utils.serialization import (
serialize_from_file,
)

data = serialize_from_file(self.model_path)
if "model" not in data:
raise RuntimeError("Serialized model data does not contain key 'model'.")
return data["model"]

def eval(
self,
coords: np.ndarray,
Expand Down
10 changes: 9 additions & 1 deletion deepmd/jax/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,5 +202,13 @@ def convert_str_to_int_key(item: dict) -> None:
data.pop("constants")
data["@variables"].pop("stablehlo")
return data
elif model_file.endswith(".savedmodel"):
from deepmd.tf.utils.serialization import (
serialize_from_file as serialize_savedmodel,
)

return serialize_savedmodel(model_file)
else:
raise ValueError("JAX backend only supports converting .jax directory")
raise ValueError(
"JAX backend only supports converting .jax directory, .hlo, and .savedmodel"
)
1 change: 1 addition & 0 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,7 @@ def main_parser() -> argparse.ArgumentParser:
"fitting-net",
"size",
"observed-type",
"serialization-tree",
],
nargs="+",
)
Expand Down
6 changes: 6 additions & 0 deletions deepmd/pd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,12 @@ def get_model_def_script(self) -> dict:
"""Get model definition script."""
return self.model_def_script

def serialize(self) -> dict[str, Any]:
model = (
self.dp.model["Default"] if isinstance(self.dp, ModelWrapper) else self.dp
)
return model.serialize()
Comment thread
njzjz-bot marked this conversation as resolved.

def get_model_size(self) -> dict:
"""Get model parameter count.

Expand Down
3 changes: 3 additions & 0 deletions deepmd/pretrained/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,6 @@ def get_ntypes_spin(self) -> int:

def get_model(self) -> Any:
return self._backend.get_model()

def serialize(self) -> dict[str, Any]:
return self._backend.serialize()
11 changes: 11 additions & 0 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,17 @@ def get_model_def_script(self) -> dict:
"""Get model definition script."""
return self.model_def_script

def serialize(self) -> dict[str, Any]:
model = self.dp.model["Default"]
if hasattr(model, "serialize"):
return model.serialize()

from deepmd.pt.utils.serialization import (
serialize_from_file,
)

return serialize_from_file(self.model_path)["model"]

def get_model_size(self) -> dict:
"""Get model parameter count.

Expand Down
8 changes: 8 additions & 0 deletions deepmd/pt_expt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,14 @@ def get_model_def_script(self) -> dict:
"""Get model definition script."""
return self.metadata

def serialize(self) -> dict[str, Any]:
from deepmd.pt_expt.utils.serialization import (
serialize_from_file,
)

data = serialize_from_file(self.model_path)
return data["model"] if isinstance(data, dict) and "model" in data else data

def get_model(self) -> torch.nn.Module:
"""Get the exported model module.

Expand Down
18 changes: 18 additions & 0 deletions deepmd/tf/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def __init__(
input_map=input_map,
)
self.load_prefix = load_prefix
self.model_file = model_file

# graph_compatable should be called after graph and prefix are set
if not self._graph_compatable():
Expand Down Expand Up @@ -1121,6 +1122,22 @@ def get_model_def_script(self) -> dict:
model_def_script = script.decode("utf-8")
return json.loads(model_def_script)["model"]

def serialize(self) -> dict[str, Any]:
from deepmd.tf.model.model import (
Model,
)
from deepmd.tf.utils.graph import (
load_graph_def,
)

graph, graph_def = load_graph_def(str(self.model_file))

model_def_script = self.get_model_def_script()
model = Model(**model_def_script)
# important! must be called before serialize
model.init_variables(graph=graph, graph_def=graph_def)
return model.serialize()

def get_model(self) -> "tf.Graph":
"""Get the TensorFlow graph.

Expand Down Expand Up @@ -1172,6 +1189,7 @@ def __init__(
input_map=input_map,
)
self.load_prefix = load_prefix
self.model_file = model_file

# graph_compatable should be called after graph and prefix are set
if not self._graph_compatable():
Expand Down
9 changes: 5 additions & 4 deletions source/tests/consistent/io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,12 @@ def test_deep_eval(self) -> None:
if not backend.is_available():
continue
reference_data = copy.deepcopy(self.data)
self.save_data_to_model(
prefix + backend.suffixes[suffix_idx], reference_data
)
deep_eval = DeepEval(prefix + backend.suffixes[suffix_idx])
model_file = prefix + backend.suffixes[suffix_idx]
self.save_data_to_model(model_file, reference_data)
deep_eval = DeepEval(model_file)
self.assertIsInstance(deep_eval.get_model_def_script(), dict)
serialized_data = self.get_data_from_model(model_file)
np.testing.assert_equal(deep_eval.serialize(), serialized_data["model"])
if deep_eval.get_dim_fparam() > 0:
fparam = np.ones((nframes, deep_eval.get_dim_fparam()))
else:
Expand Down
6 changes: 6 additions & 0 deletions source/tests/pt_expt/infer/test_deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ def test_get_model_def_script(self) -> None:
self.assertAlmostEqual(mds["rcut"], self.rcut)
self.assertEqual(mds["sel"], list(self.sel))

def test_serialize_returns_model_tree(self) -> None:
data = self.dp.deep_eval.serialize()
self.assertEqual(data["@class"], self.model.serialize()["@class"])
self.assertEqual(data["type"], self.model.serialize()["type"])
self.assertEqual(data, serialize_from_file(self.tmpfile.name))

def test_eval_consistency(self) -> None:
"""Test that DeepPot.eval gives same results as direct model forward."""
rng = np.random.default_rng(GLOBAL_SEED)
Expand Down
29 changes: 29 additions & 0 deletions source/tests/test_entrypoint_show_serialization_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest
from unittest.mock import (
patch,
)

from deepmd.entrypoints.show import (
show,
)


class TestShowSerializationTree(unittest.TestCase):
def test_serialization_tree_uses_deep_eval_model_payload(self) -> None:
with (
patch("deepmd.entrypoints.show.DeepEval") as mock_deep_eval,
patch("deepmd.entrypoints.show.Node.deserialize") as mock_deserialize,
patch("deepmd.entrypoints.show.log.info") as mock_log_info,
):
model = mock_deep_eval.return_value
model.get_model_def_script.return_value = {"type_map": ["H", "O"]}
model.get_model_size.return_value = {}
model.serialize.return_value = {"@class": "MockModel"}
mock_deserialize.return_value = "ROOT"

show(INPUT="mock.pte", ATTRIBUTES=["serialization-tree"])

model.serialize.assert_called_once_with()
mock_deserialize.assert_called_once_with({"@class": "MockModel"})
mock_log_info.assert_any_call("Model serialization tree:\nROOT")
Loading