Skip to content

Commit fe937f6

Browse files
author
Han Wang
committed
merge master
2 parents a3054f2 + 4ddc37d commit fe937f6

15 files changed

Lines changed: 150 additions & 185 deletions

File tree

deepmd/pt_expt/common.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,20 @@ def torch_module(
385385
) -> type[torch.nn.Module]:
386386
"""Convert a NativeOP to a torch.nn.Module.
387387
388+
This decorator wraps a NativeOP class to make it a PyTorch module, handling
389+
initialization, attribute setting, and method delegation automatically.
390+
391+
**Auto-generated methods:**
392+
393+
- If the wrapped class has a ``call()`` method but does not explicitly define
394+
``forward()``, a ``forward()`` method will be auto-generated that delegates
395+
to ``call()``.
396+
- If the wrapped class has a ``call_lower()`` method but does not explicitly
397+
define ``forward_lower()``, a ``forward_lower()`` method will be auto-generated
398+
that delegates to ``call_lower()``.
399+
- Explicit ``forward()`` or ``forward_lower()`` definitions in the wrapped class
400+
are always respected and will not be overridden.
401+
388402
Parameters
389403
----------
390404
module : type[NativeOP]
@@ -393,13 +407,13 @@ def torch_module(
393407
Returns
394408
-------
395409
type[torch.nn.Module]
396-
The torch.nn.Module.
410+
The torch.nn.Module with auto-generated delegation methods if applicable.
397411
398412
Examples
399413
--------
400414
>>> @torch_module
401415
... class MyModule(NativeOP):
402-
... pass
416+
... pass # forward() auto-generated from call() if it exists
403417
"""
404418

405419
@wraps(module, updated=())
@@ -426,6 +440,22 @@ def __setattr__(self, name: str, value: Any) -> None:
426440
if not handled:
427441
super().__setattr__(name, value)
428442

443+
# Auto-generate forward -> call redirect if not explicitly defined
444+
if hasattr(module, "call") and "forward" not in module.__dict__:
445+
446+
def forward(self, *args: Any, **kwargs: Any) -> Any: # noqa: ANN001
447+
return self.call(*args, **kwargs)
448+
449+
TorchModule.forward = forward
450+
451+
# Auto-generate forward_lower -> call_lower redirect if not explicitly defined
452+
if hasattr(module, "call_lower") and "forward_lower" not in module.__dict__:
453+
454+
def forward_lower(self, *args: Any, **kwargs: Any) -> Any: # noqa: ANN001
455+
return self.call_lower(*args, **kwargs)
456+
457+
TorchModule.forward_lower = forward_lower
458+
429459
return TorchModule
430460

431461

deepmd/pt_expt/descriptor/dpa1.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
from typing import (
3-
Any,
4-
)
52

63
from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP
74
from deepmd.pt_expt.common import (
@@ -16,5 +13,4 @@
1613
@BaseDescriptor.register("dpa1")
1714
@torch_module
1815
class DescrptDPA1(DescrptDPA1DP):
19-
def forward(self, *args: Any, **kwargs: Any) -> Any:
20-
return self.call(*args, **kwargs)
16+
pass

deepmd/pt_expt/descriptor/dpa2.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
from typing import (
3-
Any,
4-
)
52

63
from deepmd.dpmodel.descriptor.dpa2 import DescrptDPA2 as DescrptDPA2DP
74
from deepmd.pt_expt.common import (
@@ -15,5 +12,4 @@
1512
@BaseDescriptor.register("dpa2")
1613
@torch_module
1714
class DescrptDPA2(DescrptDPA2DP):
18-
def forward(self, *args: Any, **kwargs: Any) -> Any:
19-
return self.call(*args, **kwargs)
15+
pass

deepmd/pt_expt/descriptor/dpa3.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
from typing import (
3-
Any,
4-
)
52

63
from deepmd.dpmodel.descriptor.dpa3 import DescrptDPA3 as DescrptDPA3DP
74
from deepmd.pt_expt.common import (
@@ -15,5 +12,4 @@
1512
@BaseDescriptor.register("dpa3")
1613
@torch_module
1714
class DescrptDPA3(DescrptDPA3DP):
18-
def forward(self, *args: Any, **kwargs: Any) -> Any:
19-
return self.call(*args, **kwargs)
15+
pass

deepmd/pt_expt/descriptor/hybrid.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
from typing import (
3-
Any,
4-
)
52

63
from deepmd.dpmodel.descriptor.hybrid import DescrptHybrid as DescrptHybridDP
74
from deepmd.pt_expt.common import (
@@ -15,5 +12,4 @@
1512
@BaseDescriptor.register("hybrid")
1613
@torch_module
1714
class DescrptHybrid(DescrptHybridDP):
18-
def forward(self, *args: Any, **kwargs: Any) -> Any:
19-
return self.call(*args, **kwargs)
15+
pass

deepmd/pt_expt/descriptor/se_atten_v2.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
from typing import (
3-
Any,
4-
)
52

63
from deepmd.dpmodel.descriptor.se_atten_v2 import DescrptSeAttenV2 as DescrptSeAttenV2DP
74
from deepmd.pt_expt.common import (
@@ -15,5 +12,4 @@
1512
@BaseDescriptor.register("se_atten_v2")
1613
@torch_module
1714
class DescrptSeAttenV2(DescrptSeAttenV2DP):
18-
def forward(self, *args: Any, **kwargs: Any) -> Any:
19-
return self.call(*args, **kwargs)
15+
pass
Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22

3-
import torch
4-
53
from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeA as DescrptSeADP
64
from deepmd.pt_expt.common import (
75
torch_module,
@@ -15,23 +13,4 @@
1513
@BaseDescriptor.register("se_a")
1614
@torch_module
1715
class DescrptSeA(DescrptSeADP):
18-
def forward(
19-
self,
20-
extended_coord: torch.Tensor,
21-
extended_atype: torch.Tensor,
22-
nlist: torch.Tensor,
23-
mapping: torch.Tensor | None = None,
24-
) -> tuple[
25-
torch.Tensor,
26-
torch.Tensor | None,
27-
torch.Tensor | None,
28-
torch.Tensor | None,
29-
torch.Tensor | None,
30-
]:
31-
descrpt, rot_mat, g2, h2, sw = self.call(
32-
extended_coord,
33-
extended_atype,
34-
nlist,
35-
mapping=mapping,
36-
)
37-
return descrpt, rot_mat, g2, h2, sw
16+
pass

deepmd/pt_expt/descriptor/se_r.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22

3-
import torch
4-
53
from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP
64
from deepmd.pt_expt.common import (
75
torch_module,
@@ -15,23 +13,4 @@
1513
@BaseDescriptor.register("se_r")
1614
@torch_module
1715
class DescrptSeR(DescrptSeRDP):
18-
def forward(
19-
self,
20-
extended_coord: torch.Tensor,
21-
extended_atype: torch.Tensor,
22-
nlist: torch.Tensor,
23-
mapping: torch.Tensor | None = None,
24-
) -> tuple[
25-
torch.Tensor,
26-
torch.Tensor | None,
27-
torch.Tensor | None,
28-
torch.Tensor | None,
29-
torch.Tensor | None,
30-
]:
31-
descrpt, rot_mat, g2, h2, sw = self.call(
32-
extended_coord,
33-
extended_atype,
34-
nlist,
35-
mapping=mapping,
36-
)
37-
return descrpt, rot_mat, g2, h2, sw
16+
pass

deepmd/pt_expt/descriptor/se_t.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22

3-
import torch
4-
53
from deepmd.dpmodel.descriptor.se_t import DescrptSeT as DescrptSeTDP
64
from deepmd.pt_expt.common import (
75
torch_module,
@@ -16,23 +14,4 @@
1614
@BaseDescriptor.register("se_a_3be")
1715
@torch_module
1816
class DescrptSeT(DescrptSeTDP):
19-
def forward(
20-
self,
21-
extended_coord: torch.Tensor,
22-
extended_atype: torch.Tensor,
23-
nlist: torch.Tensor,
24-
mapping: torch.Tensor | None = None,
25-
) -> tuple[
26-
torch.Tensor,
27-
torch.Tensor | None,
28-
torch.Tensor | None,
29-
torch.Tensor | None,
30-
torch.Tensor | None,
31-
]:
32-
descrpt, rot_mat, g2, h2, sw = self.call(
33-
extended_coord,
34-
extended_atype,
35-
nlist,
36-
mapping=mapping,
37-
)
38-
return descrpt, rot_mat, g2, h2, sw
17+
pass
Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22

3-
import torch
4-
53
from deepmd.dpmodel.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdDP
64
from deepmd.pt_expt.common import (
75
torch_module,
@@ -14,23 +12,4 @@
1412
@BaseDescriptor.register("se_e3_tebd")
1513
@torch_module
1614
class DescrptSeTTebd(DescrptSeTTebdDP):
17-
def forward(
18-
self,
19-
extended_coord: torch.Tensor,
20-
extended_atype: torch.Tensor,
21-
nlist: torch.Tensor,
22-
mapping: torch.Tensor | None = None,
23-
) -> tuple[
24-
torch.Tensor,
25-
torch.Tensor | None,
26-
torch.Tensor | None,
27-
torch.Tensor | None,
28-
torch.Tensor | None,
29-
]:
30-
descrpt, rot_mat, g2, h2, sw = self.call(
31-
extended_coord,
32-
extended_atype,
33-
nlist,
34-
mapping=mapping,
35-
)
36-
return descrpt, rot_mat, g2, h2, sw
15+
pass

0 commit comments

Comments
 (0)