Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions deepmd/pt_expt/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,20 @@ def torch_module(
) -> type[torch.nn.Module]:
"""Convert a NativeOP to a torch.nn.Module.

This decorator wraps a NativeOP class to make it a PyTorch module, handling
initialization, attribute setting, and method delegation automatically.

**Auto-generated methods:**

- If the wrapped class has a ``call()`` method but does not explicitly define
``forward()``, a ``forward()`` method will be auto-generated that delegates
to ``call()``.
- If the wrapped class has a ``call_lower()`` method but does not explicitly
define ``forward_lower()``, a ``forward_lower()`` method will be auto-generated
that delegates to ``call_lower()``.
- Explicit ``forward()`` or ``forward_lower()`` definitions in the wrapped class
are always respected and will not be overridden.

Parameters
----------
module : type[NativeOP]
Expand All @@ -308,13 +322,13 @@ def torch_module(
Returns
-------
type[torch.nn.Module]
The torch.nn.Module.
The torch.nn.Module with auto-generated delegation methods if applicable.

Examples
--------
>>> @torch_module
... class MyModule(NativeOP):
... pass
... pass # forward() auto-generated from call() if it exists
"""

@wraps(module, updated=())
Expand All @@ -332,6 +346,22 @@ def __setattr__(self, name: str, value: Any) -> None:
if not handled:
super().__setattr__(name, value)

# Auto-generate forward -> call redirect if not explicitly defined
if hasattr(module, "call") and "forward" not in module.__dict__:

def forward(self, *args: Any, **kwargs: Any) -> Any: # noqa: ANN001
return self.call(*args, **kwargs)

TorchModule.forward = forward

# Auto-generate forward_lower -> call_lower redirect if not explicitly defined
if hasattr(module, "call_lower") and "forward_lower" not in module.__dict__:

def forward_lower(self, *args: Any, **kwargs: Any) -> Any: # noqa: ANN001
return self.call_lower(*args, **kwargs)

TorchModule.forward_lower = forward_lower
Comment thread
njzjz marked this conversation as resolved.

return TorchModule


Expand Down
23 changes: 1 addition & 22 deletions deepmd/pt_expt/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import torch

from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP
from deepmd.pt_expt.common import (
torch_module,
Expand All @@ -15,23 +13,4 @@
@BaseDescriptor.register("se_a")
@torch_module
class DescrptSeA(DescrptSeADP):
def forward(
self,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
nlist: torch.Tensor,
mapping: torch.Tensor | None = None,
) -> tuple[
torch.Tensor,
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor | None,
]:
descrpt, rot_mat, g2, h2, sw = self.call(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
)
return descrpt, rot_mat, g2, h2, sw
pass
23 changes: 1 addition & 22 deletions deepmd/pt_expt/descriptor/se_r.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import torch

from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP
from deepmd.pt_expt.common import (
torch_module,
Expand All @@ -15,23 +13,4 @@
@BaseDescriptor.register("se_r")
@torch_module
class DescrptSeR(DescrptSeRDP):
def forward(
self,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
nlist: torch.Tensor,
mapping: torch.Tensor | None = None,
) -> tuple[
torch.Tensor,
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor | None,
]:
descrpt, rot_mat, g2, h2, sw = self.call(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
)
return descrpt, rot_mat, g2, h2, sw
pass
23 changes: 1 addition & 22 deletions deepmd/pt_expt/descriptor/se_t.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import torch

from deepmd.dpmodel.descriptor.se_t import DescrptSeT as DescrptSeTDP
from deepmd.pt_expt.common import (
torch_module,
Expand All @@ -16,23 +14,4 @@
@BaseDescriptor.register("se_a_3be")
@torch_module
class DescrptSeT(DescrptSeTDP):
def forward(
self,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
nlist: torch.Tensor,
mapping: torch.Tensor | None = None,
) -> tuple[
torch.Tensor,
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor | None,
]:
descrpt, rot_mat, g2, h2, sw = self.call(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
)
return descrpt, rot_mat, g2, h2, sw
pass
23 changes: 1 addition & 22 deletions deepmd/pt_expt/descriptor/se_t_tebd.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import torch

from deepmd.dpmodel.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdDP
from deepmd.pt_expt.common import (
torch_module,
Expand All @@ -14,23 +12,4 @@
@BaseDescriptor.register("se_e3_tebd")
@torch_module
class DescrptSeTTebd(DescrptSeTTebdDP):
def forward(
self,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
nlist: torch.Tensor,
mapping: torch.Tensor | None = None,
) -> tuple[
torch.Tensor,
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor | None,
]:
descrpt, rot_mat, g2, h2, sw = self.call(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
)
return descrpt, rot_mat, g2, h2, sw
pass
26 changes: 1 addition & 25 deletions deepmd/pt_expt/descriptor/se_t_tebd_block.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import torch

from deepmd.dpmodel.descriptor.se_t_tebd import (
DescrptBlockSeTTebd as DescrptBlockSeTTebdDP,
)
Expand All @@ -13,29 +11,7 @@

@torch_module
class DescrptBlockSeTTebd(DescrptBlockSeTTebdDP):
def forward(
self,
nlist: torch.Tensor,
coord_ext: torch.Tensor,
atype_ext: torch.Tensor,
atype_embd_ext: torch.Tensor | None = None,
mapping: torch.Tensor | None = None,
type_embedding: torch.Tensor | None = None,
) -> tuple[
torch.Tensor,
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor | None,
]:
return self.call(
nlist,
coord_ext,
atype_ext,
atype_embd_ext=atype_embd_ext,
mapping=mapping,
type_embedding=type_embedding,
)
pass


register_dpmodel_mapping(
Expand Down
22 changes: 1 addition & 21 deletions deepmd/pt_expt/fitting/ener_fitting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import torch

from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP
from deepmd.pt_expt.common import (
register_dpmodel_mapping,
Expand All @@ -21,25 +19,7 @@ class EnergyFittingNet(EnergyFittingNetDP):
This inherits from dpmodel EnergyFittingNet to get the correct serialize() method.
"""

def forward(
self,
descriptor: torch.Tensor,
atype: torch.Tensor,
gr: torch.Tensor | None = None,
g2: torch.Tensor | None = None,
h2: torch.Tensor | None = None,
fparam: torch.Tensor | None = None,
aparam: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
return self.call(
descriptor,
atype,
gr=gr,
g2=g2,
h2=h2,
fparam=fparam,
aparam=aparam,
)
pass


register_dpmodel_mapping(
Expand Down
22 changes: 1 addition & 21 deletions deepmd/pt_expt/fitting/invar_fitting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import torch

from deepmd.dpmodel.fitting.invar_fitting import InvarFitting as InvarFittingDP
from deepmd.pt_expt.common import (
register_dpmodel_mapping,
Expand All @@ -15,25 +13,7 @@
@BaseFitting.register("invar")
@torch_module
class InvarFitting(InvarFittingDP):
def forward(
self,
descriptor: torch.Tensor,
atype: torch.Tensor,
gr: torch.Tensor | None = None,
g2: torch.Tensor | None = None,
h2: torch.Tensor | None = None,
fparam: torch.Tensor | None = None,
aparam: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
return self.call(
descriptor,
atype,
gr=gr,
g2=g2,
h2=h2,
fparam=fparam,
aparam=aparam,
)
pass


register_dpmodel_mapping(
Expand Down
3 changes: 0 additions & 3 deletions deepmd/pt_expt/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,6 @@ def __init__(self, layers: list[dict] | None = None) -> None:
super().__init__(layers)
self.layers = torch.nn.ModuleList(self.layers)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.call(x)


class EmbeddingNet(EmbeddingNetDP, torch.nn.Module):
def __init__(self, *args: Any, **kwargs: Any) -> None:
Expand Down
Loading
Loading