Skip to content

Commit 1694360

Browse files
committed
refactor(deepeval): implement backend serialize via model.serialize
Switch DeepEval.serialize() to delegate to DeepEvalBackend.serialize(), and implement serialize() in each backend by calling the underlying model's serialize(). Also move Node import in dp show to module top-level. Authored by OpenClaw (model: gpt-5.2)
1 parent b695aaa commit 1694360

9 files changed

Lines changed: 111 additions & 41 deletions

File tree

deepmd/dpmodel/infer/deep_eval.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,18 @@ def get_model_def_script(self) -> dict:
406406
"""Get model definition script."""
407407
return json.loads(self.dp.get_model_def_script())
408408

409+
def serialize(self) -> dict[str, Any]:
410+
model = self.dp
411+
data: dict[str, Any] = {
412+
"backend": "DPModel",
413+
"model": model.serialize(),
414+
"model_def_script": self.get_model_def_script(),
415+
"@variables": {},
416+
}
417+
if model.get_min_nbor_dist() is not None:
418+
data["@variables"]["min_nbor_dist"] = model.get_min_nbor_dist()
419+
return data
420+
409421
def get_observed_types(self) -> dict:
410422
"""Get observed types (elements) of the model during data statistics.
411423

deepmd/entrypoints/show.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
Any,
55
)
66

7+
from deepmd.dpmodel.utils.serialization import (
8+
Node,
9+
)
710
from deepmd.infer.deep_eval import (
811
DeepEval,
912
)
@@ -138,10 +141,6 @@ def show(
138141
log.info(f"Observed types: {observed_types['observed_type']} ")
139142

140143
if "serialization-tree" in ATTRIBUTES:
141-
from deepmd.dpmodel.utils.serialization import (
142-
Node,
143-
)
144-
145144
data = model.serialize()
146145
if "model" not in data:
147146
raise RuntimeError("Serialized model data does not contain key 'model'.")

deepmd/infer/deep_eval.py

Lines changed: 12 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,16 @@ def get_model(self) -> Any:
361361
The model module implemented by the deep learning framework.
362362
"""
363363

364+
@abstractmethod
365+
def serialize(self) -> dict[str, Any]:
366+
"""Serialize the loaded model to a backend-unified dictionary.
367+
368+
Returns
369+
-------
370+
dict
371+
Serialized model data. Must include key ``"model"``.
372+
"""
373+
364374

365375
class DeepEval(ABC):
366376
"""High-level Deep Evaluator interface.
@@ -422,43 +432,8 @@ def output_def(self) -> ModelOutputDef:
422432
"""Returns the output variable definitions."""
423433

424434
def serialize(self) -> dict[str, Any]:
425-
"""Serialize the model file to a dictionary (backend-unified).
426-
427-
This is a convenience wrapper around backend-specific serialization
428-
hooks, intended for unified model inspection / display.
429-
430-
Returns
431-
-------
432-
dict
433-
Serialized model data (must include key ``"model"``).
434-
435-
Raises
436-
------
437-
NotImplementedError
438-
If the detected backend does not support IO serialization.
439-
"""
440-
backend_cls = Backend.detect_backend_by_model(self.model_file)
441-
442-
# internal alias backend: resolve to a verified local file first
443-
if getattr(backend_cls, "name", "").lower() == "pretrained":
444-
from deepmd.pretrained.deep_eval import (
445-
parse_pretrained_alias,
446-
)
447-
from deepmd.pretrained.download import (
448-
resolve_model_path,
449-
)
450-
451-
model_name = parse_pretrained_alias(self.model_file)
452-
resolved = str(resolve_model_path(model_name))
453-
backend_cls = Backend.detect_backend_by_model(resolved)
454-
return backend_cls().serialize_hook(resolved)
455-
456-
if not (backend_cls.features & Backend.Feature.IO):
457-
raise NotImplementedError(
458-
f"Backend '{backend_cls.name}' does not support serialization."
459-
)
460-
461-
return backend_cls().serialize_hook(self.model_file)
435+
"""Serialize the loaded model to a backend-unified dictionary."""
436+
return self.deep_eval.serialize()
462437

463438
def get_rcut(self) -> float:
464439
"""Get the cutoff radius of this model."""

deepmd/jax/infer/deep_eval.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@
4747
from deepmd.jax.common import (
4848
to_jax_array,
4949
)
50+
from deepmd.jax.env import (
51+
jax,
52+
)
5053
from deepmd.jax.model.hlo import (
5154
HLO,
5255
)
@@ -187,6 +190,19 @@ def get_ntypes_spin(self) -> int:
187190
"""Get the number of spin atom types of this model."""
188191
return 0
189192

193+
def serialize(self) -> dict[str, Any]:
194+
model = self.dp
195+
data: dict[str, Any] = {
196+
"backend": "JAX",
197+
"jax_version": jax.__version__,
198+
"model": model.serialize(),
199+
"model_def_script": json.loads(model.get_model_def_script()),
200+
"@variables": {},
201+
}
202+
if model.get_min_nbor_dist() is not None:
203+
data["@variables"]["min_nbor_dist"] = model.get_min_nbor_dist()
204+
return data
205+
190206
def eval(
191207
self,
192208
coords: np.ndarray,

deepmd/pd/infer/deep_eval.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,18 @@ def get_model_def_script(self) -> dict:
730730
"""Get model definition script."""
731731
return self.model_def_script
732732

733+
def serialize(self) -> dict[str, Any]:
734+
model = self.dp.model["Default"]
735+
data: dict[str, Any] = {
736+
"backend": "Paddle",
737+
"model": model.serialize(),
738+
"model_def_script": self.get_model_def_script(),
739+
"@variables": {},
740+
}
741+
if model.get_min_nbor_dist() is not None:
742+
data["@variables"]["min_nbor_dist"] = model.get_min_nbor_dist()
743+
return data
744+
733745
def get_model_size(self) -> dict:
734746
"""Get model parameter count.
735747

deepmd/pretrained/deep_eval.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,6 @@ def get_ntypes_spin(self) -> int:
184184

185185
def get_model(self) -> Any:
186186
return self._backend.get_model()
187+
188+
def serialize(self) -> dict[str, Any]:
189+
return self._backend.serialize()

deepmd/pt/infer/deep_eval.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,18 @@ def get_model_def_script(self) -> dict:
702702
"""Get model definition script."""
703703
return self.model_def_script
704704

705+
def serialize(self) -> dict[str, Any]:
706+
model = self.dp.model["Default"]
707+
data: dict[str, Any] = {
708+
"backend": "PyTorch",
709+
"model": model.serialize(),
710+
"model_def_script": self.get_model_def_script(),
711+
"@variables": {},
712+
}
713+
if model.get_min_nbor_dist() is not None:
714+
data["@variables"]["min_nbor_dist"] = model.get_min_nbor_dist()
715+
return data
716+
705717
def get_model_size(self) -> dict:
706718
"""Get model parameter count.
707719

deepmd/pt_expt/infer/deep_eval.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,13 @@ def get_model_def_script(self) -> dict:
665665
"""Get model definition script."""
666666
return self.metadata
667667

668+
def serialize(self) -> dict[str, Any]:
669+
from deepmd.pt_expt.utils.serialization import (
670+
serialize_from_file,
671+
)
672+
673+
return serialize_from_file(self.model_path)
674+
668675
def get_model(self) -> torch.nn.Module:
669676
"""Get the exported model module.
670677

deepmd/tf/infer/deep_eval.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def __init__(
110110
input_map=input_map,
111111
)
112112
self.load_prefix = load_prefix
113+
self.model_file = model_file
113114

114115
# graph_compatable should be called after graph and prefix are set
115116
if not self._graph_compatable():
@@ -1121,6 +1122,38 @@ def get_model_def_script(self) -> dict:
11211122
model_def_script = script.decode("utf-8")
11221123
return json.loads(model_def_script)["model"]
11231124

1125+
def serialize(self) -> dict[str, Any]:
1126+
from deepmd.tf.model.model import (
1127+
Model,
1128+
)
1129+
from deepmd.tf.utils.graph import (
1130+
load_graph_def,
1131+
)
1132+
1133+
graph, graph_def = load_graph_def(str(self.model_file))
1134+
1135+
model_def_script = self.get_model_def_script()
1136+
model = Model(**model_def_script)
1137+
# important! must be called before serialize
1138+
model.init_variables(graph=graph, graph_def=graph_def)
1139+
model_dict = model.serialize()
1140+
1141+
data: dict[str, Any] = {
1142+
"backend": "TensorFlow",
1143+
"tf_version": tf.__version__,
1144+
"model": model_dict,
1145+
"model_def_script": model_def_script,
1146+
}
1147+
try:
1148+
t_min_nbor_dist = self._get_tensor("train_attr/min_nbor_dist:0")
1149+
except KeyError:
1150+
pass
1151+
else:
1152+
[min_nbor_dist] = run_sess(self.sess, [t_min_nbor_dist], feed_dict={})
1153+
data.setdefault("@variables", {})
1154+
data["@variables"]["min_nbor_dist"] = float(min_nbor_dist)
1155+
return data
1156+
11241157
def get_model(self) -> "tf.Graph":
11251158
"""Get the TensorFlow graph.
11261159
@@ -1172,6 +1205,7 @@ def __init__(
11721205
input_map=input_map,
11731206
)
11741207
self.load_prefix = load_prefix
1208+
self.model_file = model_file
11751209

11761210
# graph_compatable should be called after graph and prefix are set
11771211
if not self._graph_compatable():

0 commit comments

Comments
 (0)