File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -393,7 +393,7 @@ def get_model_def_script(self) -> dict:
393393 """Get model definition script."""
394394 return json .loads (self .dp .get_model_def_script ())
395395
396- def get_model (self ):
396+ def get_model (self ) -> "BaseModel" :
397397 """Get the dpmodel BaseModel.
398398
399399 Returns
Original file line number Diff line number Diff line change @@ -342,7 +342,7 @@ def get_observed_types(self) -> dict:
342342 raise NotImplementedError ("Not implemented in this backend." )
343343
344344 @abstractmethod
345- def get_model (self ):
345+ def get_model (self ) -> Any :
346346 """Get the model module implemented by the deep learning framework.
347347
348348 For PyTorch, this returns the nn.Module. For Paddle, this returns
@@ -700,7 +700,7 @@ def get_observed_types(self) -> dict:
700700 """Get observed types (elements) of the model during data statistics."""
701701 return self .deep_eval .get_observed_types ()
702702
703- def get_model (self ):
703+ def get_model (self ) -> Any :
704704 """Get the model module implemented by the deep learning framework.
705705
706706 For PyTorch, this returns the nn.Module. For Paddle, this returns
Original file line number Diff line number Diff line change @@ -421,7 +421,7 @@ def get_model_def_script(self) -> dict:
421421 """Get model definition script."""
422422 return json .loads (self .dp .get_model_def_script ())
423423
424- def get_model (self ):
424+ def get_model (self ) -> Any :
425425 """Get the JAX model as BaseModel.
426426
427427 Returns
Original file line number Diff line number Diff line change 4646if TYPE_CHECKING :
4747 import ase .neighborlist
4848
49+ from deepmd .pd .model .model .model import (
50+ BaseModel ,
51+ )
52+
4953
5054class DeepEval (DeepEvalBackend ):
5155 """Paddle backend implementation of DeepEval.
@@ -506,7 +510,7 @@ def get_model_size(self) -> dict:
506510 "total" : sum_param_des + sum_param_fit ,
507511 }
508512
509- def get_model (self ):
513+ def get_model (self ) -> "BaseModel" :
510514 """Get the Paddle model.
511515
512516 Returns
Original file line number Diff line number Diff line change 7575if TYPE_CHECKING :
7676 import ase .neighborlist
7777
78+ from deepmd .pt .model .model .model import (
79+ BaseModel ,
80+ )
81+
7882log = logging .getLogger (__name__ )
7983
8084
@@ -706,7 +710,7 @@ def get_observed_types(self) -> dict:
706710 "observed_type" : sort_element_type (observed_type_list ),
707711 }
708712
709- def get_model (self ):
713+ def get_model (self ) -> "BaseModel" :
710714 """Get the PyTorch model.
711715
712716 Returns
Original file line number Diff line number Diff line change @@ -1126,7 +1126,7 @@ def get_model_def_script(self) -> dict:
11261126 model_def_script = script .decode ("utf-8" )
11271127 return json .loads (model_def_script )["model" ]
11281128
1129- def get_model (self ):
1129+ def get_model (self ) -> "tf.Graph" :
11301130 """Get the TensorFlow graph.
11311131
11321132 Returns
You can’t perform that action at this time.
0 commit comments