Skip to content

Commit 7ba9fad

Browse files
Copilotnjzjz
andcommitted
Addressing PR comments
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent 233e0ea commit 7ba9fad

6 files changed

Lines changed: 15 additions & 7 deletions

File tree

deepmd/dpmodel/infer/deep_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ def get_model_def_script(self) -> dict:
393393
"""Get model definition script."""
394394
return json.loads(self.dp.get_model_def_script())
395395

396-
def get_model(self):
396+
def get_model(self) -> "BaseModel":
397397
"""Get the dpmodel BaseModel.
398398
399399
Returns

deepmd/infer/deep_eval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def get_observed_types(self) -> dict:
342342
raise NotImplementedError("Not implemented in this backend.")
343343

344344
@abstractmethod
345-
def get_model(self):
345+
def get_model(self) -> Any:
346346
"""Get the model module implemented by the deep learning framework.
347347
348348
For PyTorch, this returns the nn.Module. For Paddle, this returns
@@ -700,7 +700,7 @@ def get_observed_types(self) -> dict:
700700
"""Get observed types (elements) of the model during data statistics."""
701701
return self.deep_eval.get_observed_types()
702702

703-
def get_model(self):
703+
def get_model(self) -> Any:
704704
"""Get the model module implemented by the deep learning framework.
705705
706706
For PyTorch, this returns the nn.Module. For Paddle, this returns

deepmd/jax/infer/deep_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ def get_model_def_script(self) -> dict:
421421
"""Get model definition script."""
422422
return json.loads(self.dp.get_model_def_script())
423423

424-
def get_model(self):
424+
def get_model(self) -> Any:
425425
"""Get the JAX model as BaseModel.
426426
427427
Returns

deepmd/pd/infer/deep_eval.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@
4646
if TYPE_CHECKING:
4747
import ase.neighborlist
4848

49+
from deepmd.pd.model.model.model import (
50+
BaseModel,
51+
)
52+
4953

5054
class DeepEval(DeepEvalBackend):
5155
"""Paddle backend implementation of DeepEval.
@@ -506,7 +510,7 @@ def get_model_size(self) -> dict:
506510
"total": sum_param_des + sum_param_fit,
507511
}
508512

509-
def get_model(self):
513+
def get_model(self) -> "BaseModel":
510514
"""Get the Paddle model.
511515
512516
Returns

deepmd/pt/infer/deep_eval.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@
7575
if TYPE_CHECKING:
7676
import ase.neighborlist
7777

78+
from deepmd.pt.model.model.model import (
79+
BaseModel,
80+
)
81+
7882
log = logging.getLogger(__name__)
7983

8084

@@ -706,7 +710,7 @@ def get_observed_types(self) -> dict:
706710
"observed_type": sort_element_type(observed_type_list),
707711
}
708712

709-
def get_model(self):
713+
def get_model(self) -> "BaseModel":
710714
"""Get the PyTorch model.
711715
712716
Returns

deepmd/tf/infer/deep_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1126,7 +1126,7 @@ def get_model_def_script(self) -> dict:
11261126
model_def_script = script.decode("utf-8")
11271127
return json.loads(model_def_script)["model"]
11281128

1129-
def get_model(self):
1129+
def get_model(self) -> "tf.Graph":
11301130
"""Get the TensorFlow graph.
11311131
11321132
Returns

0 commit comments

Comments
 (0)