Skip to content

Commit d8b2cf4

Browse files
author
Han Wang
committed
make exclusion mask modules
1 parent eedcbaf commit d8b2cf4

2 files changed

Lines changed: 30 additions & 4 deletions

File tree

deepmd/pt_expt/utils/exclude_mask.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,33 @@
1212
)
1313

1414

15-
class AtomExcludeMask(AtomExcludeMaskDP):
15+
class AtomExcludeMask(AtomExcludeMaskDP, torch.nn.Module):
16+
def __init__(self, *args: Any, **kwargs: Any) -> None:
17+
torch.nn.Module.__init__(self)
18+
AtomExcludeMaskDP.__init__(self, *args, **kwargs)
19+
1620
def __setattr__(self, name: str, value: Any) -> None:
17-
if name == "type_mask":
21+
if name == "type_mask" and "_buffers" in self.__dict__:
1822
value = None if value is None else torch.as_tensor(value, device=env.DEVICE)
23+
if name in self._buffers:
24+
self._buffers[name] = value
25+
return
26+
self.register_buffer(name, value)
27+
return
1928
return super().__setattr__(name, value)
2029

2130

22-
class PairExcludeMask(PairExcludeMaskDP):
31+
class PairExcludeMask(PairExcludeMaskDP, torch.nn.Module):
32+
def __init__(self, *args: Any, **kwargs: Any) -> None:
33+
torch.nn.Module.__init__(self)
34+
PairExcludeMaskDP.__init__(self, *args, **kwargs)
35+
2336
def __setattr__(self, name: str, value: Any) -> None:
24-
if name == "type_mask":
37+
if name == "type_mask" and "_buffers" in self.__dict__:
2538
value = None if value is None else torch.as_tensor(value, device=env.DEVICE)
39+
if name in self._buffers:
40+
self._buffers[name] = value
41+
return
42+
self.register_buffer(name, value)
43+
return
2644
return super().__setattr__(name, value)

source/tests/pt_expt/utils/test_exclusion_mask.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ def test_build_type_exclude_mask(self) -> None:
3939
mask = des.build_type_exclude_mask(torch.as_tensor(atype, device=env.DEVICE))
4040
np.testing.assert_equal(mask.detach().cpu().numpy(), expected_mask)
4141

42+
def test_type_mask_is_buffer(self) -> None:
43+
des = AtomExcludeMask(3, exclude_types=[0])
44+
assert "type_mask" in des.state_dict()
45+
4246

4347
class TestPairExcludeMask(unittest.TestCase, TestCaseSingleFrameWithNlist):
4448
def setUp(self) -> None:
@@ -62,3 +66,7 @@ def test_build_type_exclude_mask(self) -> None:
6266
torch.as_tensor(self.atype_ext, device=env.DEVICE),
6367
)
6468
np.testing.assert_equal(mask.detach().cpu().numpy(), expected_mask)
69+
70+
def test_type_mask_is_buffer(self) -> None:
71+
des = PairExcludeMask(self.nt, exclude_types=[[0, 1]])
72+
assert "type_mask" in des.state_dict()

0 commit comments

Comments
 (0)