|
12 | 12 | ) |
13 | 13 |
|
14 | 14 |
|
15 | | -class AtomExcludeMask(AtomExcludeMaskDP): |
| 15 | +class AtomExcludeMask(AtomExcludeMaskDP, torch.nn.Module): |
| 16 | + def __init__(self, *args: Any, **kwargs: Any) -> None: |
| 17 | + torch.nn.Module.__init__(self) |
| 18 | + AtomExcludeMaskDP.__init__(self, *args, **kwargs) |
| 19 | + |
16 | 20 | def __setattr__(self, name: str, value: Any) -> None: |
17 | | - if name == "type_mask": |
| 21 | + if name == "type_mask" and "_buffers" in self.__dict__: |
18 | 22 | value = None if value is None else torch.as_tensor(value, device=env.DEVICE) |
| 23 | + if name in self._buffers: |
| 24 | + self._buffers[name] = value |
| 25 | + return |
| 26 | + self.register_buffer(name, value) |
| 27 | + return |
19 | 28 | return super().__setattr__(name, value) |
20 | 29 |
|
21 | 30 |
|
22 | | -class PairExcludeMask(PairExcludeMaskDP): |
| 31 | +class PairExcludeMask(PairExcludeMaskDP, torch.nn.Module): |
| 32 | + def __init__(self, *args: Any, **kwargs: Any) -> None: |
| 33 | + torch.nn.Module.__init__(self) |
| 34 | + PairExcludeMaskDP.__init__(self, *args, **kwargs) |
| 35 | + |
23 | 36 | def __setattr__(self, name: str, value: Any) -> None: |
24 | | - if name == "type_mask": |
| 37 | + if name == "type_mask" and "_buffers" in self.__dict__: |
25 | 38 | value = None if value is None else torch.as_tensor(value, device=env.DEVICE) |
| 39 | + if name in self._buffers: |
| 40 | + self._buffers[name] = value |
| 41 | + return |
| 42 | + self.register_buffer(name, value) |
| 43 | + return |
26 | 44 | return super().__setattr__(name, value) |
0 commit comments