1919 make_fitting_network ,
2020 make_multilayer_network ,
2121)
22- from deepmd .pt_expt .utils import (
23- env ,
22+ from deepmd .pt_expt .common import (
23+ to_torch_array ,
2424)
2525
2626torch = importlib .import_module ("torch" )
2727
2828
29- def _to_torch_array (value : Any ) -> torch .Tensor | None :
30- if value is None :
31- return None
32- if torch .is_tensor (value ):
33- return value
34- return torch .as_tensor (value , device = env .DEVICE )
35-
36-
3729class TorchArrayParam (torch .nn .Parameter ):
3830 def __new__ (cls , data : Any = None , requires_grad : bool = True ) -> Self :
3931 return torch .nn .Parameter .__new__ (cls , data , requires_grad )
@@ -52,7 +44,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
5244 for name in ("w" , "b" , "idt" ):
5345 if name in self ._parameters or name in self ._buffers :
5446 continue
55- val = _to_torch_array (getattr (self , name ))
47+ val = to_torch_array (getattr (self , name ))
5648 if val is None :
5749 continue
5850 if self .trainable :
@@ -66,7 +58,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
6658
6759 def __setattr__ (self , name : str , value : Any ) -> None :
6860 if name in {"w" , "b" , "idt" } and "_parameters" in self .__dict__ :
69- val = _to_torch_array (value )
61+ val = to_torch_array (value )
7062 if val is None :
7163 return super ().__setattr__ (name , None )
7264 if getattr (self , "trainable" , False ):
0 commit comments