Skip to content

Commit aeef15a

Browse files
author
Han Wang
committed
fix(pt-expt): clear params on None
1 parent d8b2cf4 commit aeef15a

2 files changed

Lines changed: 27 additions & 0 deletions

File tree

deepmd/pt_expt/utils/network.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@ def __setattr__(self, name: str, value: Any) -> None:
4848
if name in {"w", "b", "idt"} and "_parameters" in self.__dict__:
4949
val = to_torch_array(value)
5050
if val is None:
51+
if name in self._parameters:
52+
self._parameters[name] = None
53+
return
54+
if name in self._buffers:
55+
self._buffers[name] = None
56+
return
5157
return super().__setattr__(name, None)
5258
if getattr(self, "trainable", False):
5359
param = (
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
3+
from deepmd.pt_expt.utils.network import (
4+
NativeLayer,
5+
)
6+
7+
8+
def test_native_layer_clears_parameter_on_none() -> None:
9+
layer = NativeLayer(2, 3, trainable=True)
10+
assert layer.w is not None
11+
layer.w = None
12+
assert layer.w is None
13+
assert layer._parameters.get("w") is None
14+
15+
16+
def test_native_layer_clears_buffer_on_none() -> None:
17+
layer = NativeLayer(2, 3, trainable=False)
18+
assert layer.w is not None
19+
layer.w = None
20+
assert layer.w is None
21+
assert layer._buffers.get("w") is None

0 commit comments

Comments
 (0)