Skip to content

Commit fe456cc

Browse files
authored
fix: fix the return type of DeepEval.get_model_def_script() for all backends (#5094)
For dpmodel and JAX, previously it returns str, but the type hint is dict. For PyTorch and Paddle, previously it returns dict, but the type hint is str. Now, all backends return dict, and the type hints are also dict. Add a test to test it. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Corrected method return type annotations to improve code accuracy and consistency across inference modules. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 26013cb commit fe456cc

5 files changed

Lines changed: 5 additions & 3 deletions

File tree

deepmd/dpmodel/infer/deep_eval.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def __init__(
9090

9191
model_data = load_dp_model(model_file)
9292
self.dp = BaseModel.deserialize(model_data["model"])
93+
self.dp.model_def_script = json.dumps(model_data.get("model_def_script", {}))
9394
self.rcut = self.dp.get_rcut()
9495
self.type_map = self.dp.get_type_map()
9596
if isinstance(auto_batch_size, bool):

deepmd/jax/infer/deep_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def __init__(
104104
stablehlo_atomic_virial_no_ghost=model_data["@variables"][
105105
"stablehlo_atomic_virial_no_ghost"
106106
].tobytes(),
107-
model_def_script=model_data["model_def_script"],
107+
model_def_script=json.dumps(model_data["model_def_script"]),
108108
**model_data["constants"],
109109
)
110110
elif model_file.endswith(".savedmodel"):

deepmd/pd/infer/deep_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -726,7 +726,7 @@ def eval_typeebd(self) -> np.ndarray:
726726
typeebd = paddle.concat(out, axis=1)
727727
return to_numpy_array(typeebd)
728728

729-
def get_model_def_script(self) -> str:
729+
def get_model_def_script(self) -> dict:
730730
"""Get model definition script."""
731731
return self.model_def_script
732732

deepmd/pt/infer/deep_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -684,7 +684,7 @@ def eval_typeebd(self) -> np.ndarray:
684684
typeebd = torch.cat(out, dim=1)
685685
return to_numpy_array(typeebd)
686686

687-
def get_model_def_script(self) -> str:
687+
def get_model_def_script(self) -> dict:
688688
"""Get model definition script."""
689689
return self.model_def_script
690690

source/tests/consistent/io/test_io.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def test_deep_eval(self) -> None:
158158
prefix + backend.suffixes[suffix_idx], reference_data
159159
)
160160
deep_eval = DeepEval(prefix + backend.suffixes[suffix_idx])
161+
self.assertIsInstance(deep_eval.get_model_def_script(), dict)
161162
if deep_eval.get_dim_fparam() > 0:
162163
fparam = np.ones((nframes, deep_eval.get_dim_fparam()))
163164
else:

0 commit comments

Comments
 (0)