Skip to content

Commit 4457061

Browse files
Authored by njzjz; co-authored by Copilot and gemini-code-assist[bot]
refact(pt_expt): add decorator to simplify the module (#5213)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added a public adapter to expose DP classes as PyTorch modules. * **Refactor** * Switched multiple descriptor, network and utility classes to a decorator-based PyTorch integration for consistent parameter/buffer handling. * **Breaking Changes** * Several descriptor forward signatures expanded to accept extended topology and embedding inputs (adjust call sites accordingly). <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <njzjz@qq.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent cd67bbe commit 4457061

5 files changed

Lines changed: 63 additions & 74 deletions

File tree

deepmd/pt_expt/common.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from collections.abc import (
1818
Callable,
1919
)
20+
from functools import (
21+
wraps,
22+
)
2023
from typing import (
2124
Any,
2225
overload,
@@ -292,6 +295,46 @@ def to_torch_array(array: Any) -> torch.Tensor | None:
292295
return torch.as_tensor(array, device=env.DEVICE)
293296

294297

298+
def torch_module(
    module: type[NativeOP],
) -> type[torch.nn.Module]:
    """Wrap a NativeOP class so that it also behaves as a torch.nn.Module.

    The returned subclass sets up torch.nn.Module state before running the
    wrapped class' own ``__init__``, routes calls through
    ``torch.nn.Module.__call__`` (so ``forward()`` is driven during
    export/tracing), and funnels attribute assignment through
    ``dpmodel_setattr`` for parameter/buffer handling.

    Parameters
    ----------
    module : type[NativeOP]
        The NativeOP class to convert.

    Returns
    -------
    type[torch.nn.Module]
        A subclass of both ``module`` and ``torch.nn.Module``.

    Examples
    --------
    >>> @torch_module
    ... class MyModule(NativeOP):
    ...     pass
    """

    # ``updated=()`` stops ``wraps`` from merging ``__dict__``; only the
    # identity metadata (__name__, __qualname__, __doc__, ...) is copied.
    @wraps(module, updated=())
    class _TorchAdapter(module, torch.nn.Module):
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            # torch.nn.Module must be initialized first so that parameter
            # and buffer registration works inside the wrapped __init__.
            torch.nn.Module.__init__(self)
            module.__init__(self, *args, **kwargs)

        def __call__(self, *args: Any, **kwargs: Any) -> Any:
            # Ensure torch.nn.Module.__call__ drives forward() for export/tracing.
            return torch.nn.Module.__call__(self, *args, **kwargs)

        def __setattr__(self, name: str, value: Any) -> None:
            # dpmodel_setattr may translate the value (e.g. array -> tensor)
            # and report whether it already stored the attribute.
            handled, value = dpmodel_setattr(self, name, value)
            if handled:
                return
            super().__setattr__(name, value)

    return _TorchAdapter
336+
337+
295338
# Import utils to trigger dpmodel→pt_expt converter registrations
296339
# This must happen after the functions above are defined to avoid circular imports
297340
def _ensure_registrations() -> None:

deepmd/pt_expt/descriptor/se_e2_a.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
from typing import (
3-
Any,
4-
)
52

63
import torch
74

85
from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP
96
from deepmd.pt_expt.common import (
10-
dpmodel_setattr,
7+
torch_module,
118
)
129
from deepmd.pt_expt.descriptor.base_descriptor import (
1310
BaseDescriptor,
@@ -16,20 +13,8 @@
1613

1714
@BaseDescriptor.register("se_e2_a_expt")
1815
@BaseDescriptor.register("se_a_expt")
19-
class DescrptSeA(DescrptSeADP, torch.nn.Module):
20-
def __init__(self, *args: Any, **kwargs: Any) -> None:
21-
torch.nn.Module.__init__(self)
22-
DescrptSeADP.__init__(self, *args, **kwargs)
23-
24-
def __call__(self, *args: Any, **kwargs: Any) -> Any:
25-
# Ensure torch.nn.Module.__call__ drives forward() for export/tracing.
26-
return torch.nn.Module.__call__(self, *args, **kwargs)
27-
28-
def __setattr__(self, name: str, value: Any) -> None:
29-
handled, value = dpmodel_setattr(self, name, value)
30-
if not handled:
31-
super().__setattr__(name, value)
32-
16+
@torch_module
17+
class DescrptSeA(DescrptSeADP):
3318
def forward(
3419
self,
3520
extended_coord: torch.Tensor,

deepmd/pt_expt/descriptor/se_r.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
from typing import (
3-
Any,
4-
)
52

63
import torch
74

85
from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP
96
from deepmd.pt_expt.common import (
10-
dpmodel_setattr,
7+
torch_module,
118
)
129
from deepmd.pt_expt.descriptor.base_descriptor import (
1310
BaseDescriptor,
@@ -16,20 +13,8 @@
1613

1714
@BaseDescriptor.register("se_e2_r_expt")
1815
@BaseDescriptor.register("se_r_expt")
19-
class DescrptSeR(DescrptSeRDP, torch.nn.Module):
20-
def __init__(self, *args: Any, **kwargs: Any) -> None:
21-
torch.nn.Module.__init__(self)
22-
DescrptSeRDP.__init__(self, *args, **kwargs)
23-
24-
def __call__(self, *args: Any, **kwargs: Any) -> Any:
25-
# Ensure torch.nn.Module.__call__ drives forward() for export/tracing.
26-
return torch.nn.Module.__call__(self, *args, **kwargs)
27-
28-
def __setattr__(self, name: str, value: Any) -> None:
29-
handled, value = dpmodel_setattr(self, name, value)
30-
if not handled:
31-
super().__setattr__(name, value)
32-
16+
@torch_module
17+
class DescrptSeR(DescrptSeRDP):
3318
def forward(
3419
self,
3520
extended_coord: torch.Tensor,

deepmd/pt_expt/utils/exclude_mask.py

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,17 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
from typing import (
3-
Any,
4-
)
52

6-
import torch
73

84
from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP
95
from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP
106
from deepmd.pt_expt.common import (
11-
dpmodel_setattr,
127
register_dpmodel_mapping,
8+
torch_module,
139
)
1410

1511

16-
class AtomExcludeMask(AtomExcludeMaskDP, torch.nn.Module):
17-
def __init__(self, *args: Any, **kwargs: Any) -> None:
18-
torch.nn.Module.__init__(self)
19-
AtomExcludeMaskDP.__init__(self, *args, **kwargs)
20-
21-
def __setattr__(self, name: str, value: Any) -> None:
22-
handled, value = dpmodel_setattr(self, name, value)
23-
if not handled:
24-
super().__setattr__(name, value)
12+
@torch_module
class AtomExcludeMask(AtomExcludeMaskDP):
    # Behavior comes entirely from the dpmodel base class; the torch_module
    # decorator layers in the torch.nn.Module integration (init/__call__/
    # __setattr__ handling).
    pass
2515

2616

2717
register_dpmodel_mapping(
@@ -30,15 +20,9 @@ def __setattr__(self, name: str, value: Any) -> None:
3020
)
3121

3222

33-
class PairExcludeMask(PairExcludeMaskDP, torch.nn.Module):
34-
def __init__(self, *args: Any, **kwargs: Any) -> None:
35-
torch.nn.Module.__init__(self)
36-
PairExcludeMaskDP.__init__(self, *args, **kwargs)
37-
38-
def __setattr__(self, name: str, value: Any) -> None:
39-
handled, value = dpmodel_setattr(self, name, value)
40-
if not handled:
41-
super().__setattr__(name, value)
23+
@torch_module
class PairExcludeMask(PairExcludeMaskDP):
    # Behavior comes entirely from the dpmodel base class; the torch_module
    # decorator layers in the torch.nn.Module integration (init/__call__/
    # __setattr__ handling).
    pass
4226

4327

4428
register_dpmodel_mapping(

deepmd/pt_expt/utils/network.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from deepmd.pt_expt.common import (
2222
register_dpmodel_mapping,
2323
to_torch_array,
24+
torch_module,
2425
)
2526

2627

@@ -37,14 +38,8 @@ def __array__(self, dtype: Any | None = None) -> np.ndarray:
3738
return arr.astype(dtype)
3839

3940

40-
class NativeLayer(NativeLayerDP, torch.nn.Module):
41-
def __init__(self, *args: Any, **kwargs: Any) -> None:
42-
torch.nn.Module.__init__(self)
43-
NativeLayerDP.__init__(self, *args, **kwargs)
44-
45-
def __call__(self, *args: Any, **kwargs: Any) -> Any:
46-
return torch.nn.Module.__call__(self, *args, **kwargs)
47-
41+
@torch_module
42+
class NativeLayer(NativeLayerDP):
4843
def __setattr__(self, name: str, value: Any) -> None:
4944
if name in {"w", "b", "idt"} and "_parameters" in self.__dict__:
5045
val = to_torch_array(value)
@@ -78,15 +73,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
7873
return self.call(x)
7974

8075

81-
class NativeNet(make_multilayer_network(NativeLayer, NativeOP), torch.nn.Module):
76+
@torch_module
class NativeNet(make_multilayer_network(NativeLayer, NativeOP)):
    # Multilayer network adapted to torch via the torch_module decorator,
    # which supplies the torch.nn.Module __init__/__call__/__setattr__ glue.
    def __init__(self, layers: list[dict] | None = None) -> None:
        super().__init__(layers)
        # Re-wrap the plain list in a ModuleList so each layer is registered
        # as a torch submodule (parameters become visible to torch).
        self.layers = torch.nn.ModuleList(self.layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Delegate to the dpmodel call() implementation; torch.nn.Module's
        # __call__ machinery invokes this forward() during export/tracing.
        return self.call(x)
9284

@@ -118,15 +110,15 @@ class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)):
118110
pass
119111

120112

121-
class NetworkCollection(NetworkCollectionDP, torch.nn.Module):
113+
@torch_module
114+
class NetworkCollection(NetworkCollectionDP):
122115
NETWORK_TYPE_MAP: ClassVar[dict[str, type]] = {
123116
"network": NativeNet,
124117
"embedding_network": EmbeddingNet,
125118
"fitting_network": FittingNet,
126119
}
127120

128121
def __init__(self, *args: Any, **kwargs: Any) -> None:
129-
torch.nn.Module.__init__(self)
130122
self._module_networks = torch.nn.ModuleDict()
131123
super().__init__(*args, **kwargs)
132124

0 commit comments

Comments
 (0)