Skip to content

Commit 115ec93

Browse files
author
Han Wang
committed
fix issue of require grad
1 parent 3a286e5 commit 115ec93

1 file changed

Lines changed: 8 additions & 7 deletions

File tree

deepmd/pt_expt/common.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -189,18 +189,19 @@ def _try_convert_list(name: str, value: list) -> torch.nn.Module | None:
189189
)
190190
converted.append(c)
191191
return torch.nn.ModuleList(converted)
192-
# List of numpy arrays → ParameterList (non-trainable)
192+
# List of numpy arrays → ParameterList
193193
if all(isinstance(v, np.ndarray) for v in value):
194194
from deepmd.pt_expt.utils import env # deferred - avoids circular import
195195

196-
return torch.nn.ParameterList(
197-
[
196+
params = []
197+
for v in value:
198+
t = torch.as_tensor(v, device=env.DEVICE)
199+
params.append(
198200
torch.nn.Parameter(
199-
torch.as_tensor(v, device=env.DEVICE), requires_grad=False
201+
t, requires_grad=t.is_floating_point() or t.is_complex()
200202
)
201-
for v in value
202-
]
203-
)
203+
)
204+
return torch.nn.ParameterList(params)
204205
return None
205206

206207

0 commit comments

Comments
 (0)