File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 55 Optional ,
66)
77
8- import numpy as np
98import torch
109
1110from deepmd .dpmodel import (
@@ -366,7 +365,7 @@ def get_dim_fparam(self) -> int:
366365 def has_default_fparam (self ) -> bool :
367366 return self .fitting_net .has_default_fparam ()
368367
369- def get_default_fparam (self ) -> Optional [np . array ]:
368+ def get_default_fparam (self ) -> Optional [torch . Tensor ]:
370369 return self .fitting_net .get_default_fparam ()
371370
372371 def get_dim_aparam (self ) -> int :
Original file line number Diff line number Diff line change 33 Optional ,
44)
55
6- import numpy as np
76import torch
87
98from deepmd .dpmodel import (
@@ -531,8 +530,7 @@ def get_dim_fparam(self) -> int:
531530 def has_default_fparam (self ) -> bool :
532531 return self .atomic_model .has_default_fparam ()
533532
534- @torch .jit .export
535- def get_default_fparam (self ) -> Optional [np .array ]:
533+ def get_default_fparam (self ) -> Optional [torch .Tensor ]:
536534 return self .atomic_model .get_default_fparam ()
537535
538536 @torch .jit .export
Original file line number Diff line number Diff line change @@ -619,8 +619,8 @@ def get_dim_fparam(self) -> int:
619619 def has_default_fparam (self ) -> bool :
620620 return self .default_fparam is not None
621621
622- def get_default_fparam (self ) -> Optional [np . array ]:
623- return self .default_fparam_tensor . cpu (). numpy ()
622+ def get_default_fparam (self ) -> Optional [torch . Tensor ]:
623+ return self .default_fparam_tensor
624624
625625 def get_dim_aparam (self ) -> int :
626626 """Get the number (dimension) of atomic parameters of this atomic model."""
Original file line number Diff line number Diff line change @@ -1305,7 +1305,9 @@ def get_additional_data_requirement(_model):
13051305 additional_data_requirement = []
13061306 if _model .get_dim_fparam () > 0 :
13071307 _fparam_default = (
1308- _model .get_default_fparam () if _model .has_default_fparam () else 0.0
1308+ _model .get_default_fparam ().cpu ().numpy ()
1309+ if _model .has_default_fparam ()
1310+ else 0.0
13091311 )
13101312 fparam_requirement_items = [
13111313 DataRequirementItem (
You can’t perform that action at this time.
0 commit comments