File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -189,18 +189,19 @@ def _try_convert_list(name: str, value: list) -> torch.nn.Module | None:
189189 )
190190 converted .append (c )
191191 return torch .nn .ModuleList (converted )
192- # List of numpy arrays → ParameterList (non-trainable)
192+ # List of numpy arrays → ParameterList
193193 if all (isinstance (v , np .ndarray ) for v in value ):
194194 from deepmd .pt_expt .utils import env # deferred - avoids circular import
195195
196- return torch .nn .ParameterList (
197- [
196+ params = []
197+ for v in value :
198+ t = torch .as_tensor (v , device = env .DEVICE )
199+ params .append (
198200 torch .nn .Parameter (
199- torch . as_tensor ( v , device = env . DEVICE ), requires_grad = False
201+ t , requires_grad = t . is_floating_point () or t . is_complex ()
200202 )
201- for v in value
202- ]
203- )
203+ )
204+ return torch .nn .ParameterList (params )
204205 return None
205206
206207
You can’t perform that action at this time.
0 commit comments