diff --git a/deepmd/dpmodel/infer/deep_eval.py b/deepmd/dpmodel/infer/deep_eval.py index 605fe72d62..3bd0f435e8 100644 --- a/deepmd/dpmodel/infer/deep_eval.py +++ b/deepmd/dpmodel/infer/deep_eval.py @@ -90,6 +90,7 @@ def __init__( model_data = load_dp_model(model_file) self.dp = BaseModel.deserialize(model_data["model"]) + self.dp.model_def_script = json.dumps(model_data.get("model_def_script", {})) self.rcut = self.dp.get_rcut() self.type_map = self.dp.get_type_map() if isinstance(auto_batch_size, bool): diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py index ee605e4c07..1e29ee1c78 100644 --- a/deepmd/jax/infer/deep_eval.py +++ b/deepmd/jax/infer/deep_eval.py @@ -104,7 +104,7 @@ def __init__( stablehlo_atomic_virial_no_ghost=model_data["@variables"][ "stablehlo_atomic_virial_no_ghost" ].tobytes(), - model_def_script=model_data["model_def_script"], + model_def_script=json.dumps(model_data["model_def_script"]), **model_data["constants"], ) elif model_file.endswith(".savedmodel"): diff --git a/deepmd/pd/infer/deep_eval.py b/deepmd/pd/infer/deep_eval.py index 6715d6d0a0..67c435ab3f 100644 --- a/deepmd/pd/infer/deep_eval.py +++ b/deepmd/pd/infer/deep_eval.py @@ -726,7 +726,7 @@ def eval_typeebd(self) -> np.ndarray: typeebd = paddle.concat(out, axis=1) return to_numpy_array(typeebd) - def get_model_def_script(self) -> str: + def get_model_def_script(self) -> dict: """Get model definition script.""" return self.model_def_script diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index 02cdb702ef..2726b61152 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -684,7 +684,7 @@ def eval_typeebd(self) -> np.ndarray: typeebd = torch.cat(out, dim=1) return to_numpy_array(typeebd) - def get_model_def_script(self) -> str: + def get_model_def_script(self) -> dict: """Get model definition script.""" return self.model_def_script diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index d6b7348a91..7ae9af6891 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -158,6 +158,7 @@ def test_deep_eval(self) -> None: prefix + backend.suffixes[suffix_idx], reference_data ) deep_eval = DeepEval(prefix + backend.suffixes[suffix_idx]) + self.assertIsInstance(deep_eval.get_model_def_script(), dict) if deep_eval.get_dim_fparam() > 0: fparam = np.ones((nframes, deep_eval.get_dim_fparam())) else: