diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py
index c7375119e2..4d8b371b11 100644
--- a/deepmd/pt_expt/common.py
+++ b/deepmd/pt_expt/common.py
@@ -300,6 +300,20 @@ def torch_module(
 ) -> type[torch.nn.Module]:
     """Convert a NativeOP to a torch.nn.Module.
 
+    This decorator wraps a NativeOP class to make it a PyTorch module, handling
+    initialization, attribute setting, and method delegation automatically.
+
+    **Auto-generated methods:**
+
+    - If the wrapped class has a ``call()`` method but does not explicitly define
+      ``forward()``, a ``forward()`` method will be auto-generated that delegates
+      to ``call()``.
+    - If the wrapped class has a ``call_lower()`` method but does not explicitly
+      define ``forward_lower()``, a ``forward_lower()`` method will be auto-generated
+      that delegates to ``call_lower()``.
+    - Explicit ``forward()`` or ``forward_lower()`` definitions in the wrapped class
+      are always respected and will not be overridden.
+
     Parameters
     ----------
     module : type[NativeOP]
@@ -308,13 +322,13 @@ def torch_module(
     Returns
     -------
     type[torch.nn.Module]
-        The torch.nn.Module.
+        The torch.nn.Module with auto-generated delegation methods if applicable.
 
     Examples
     --------
     >>> @torch_module
     ... class MyModule(NativeOP):
-    ...     pass
+    ...     pass  # forward() auto-generated from call() if it exists
     """
 
     @wraps(module, updated=())
@@ -332,6 +346,22 @@ def __setattr__(self, name: str, value: Any) -> None:
             if not handled:
                 super().__setattr__(name, value)
 
+    # Auto-generate forward -> call redirect if not explicitly defined
+    if hasattr(module, "call") and "forward" not in module.__dict__:
+
+        def forward(self, *args: Any, **kwargs: Any) -> Any:  # noqa: ANN001
+            return self.call(*args, **kwargs)
+
+        TorchModule.forward = forward
+
+    # Auto-generate forward_lower -> call_lower redirect if not explicitly defined
+    if hasattr(module, "call_lower") and "forward_lower" not in module.__dict__:
+
+        def forward_lower(self, *args: Any, **kwargs: Any) -> Any:  # noqa: ANN001
+            return self.call_lower(*args, **kwargs)
+
+        TorchModule.forward_lower = forward_lower
+
     return TorchModule
diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py
index 894f175764..a140a5d346 100644
--- a/deepmd/pt_expt/descriptor/se_e2_a.py
+++ b/deepmd/pt_expt/descriptor/se_e2_a.py
@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
-import torch
-
 from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP
 from deepmd.pt_expt.common import (
     torch_module,
@@ -15,23 +13,4 @@
 @BaseDescriptor.register("se_a")
 @torch_module
 class DescrptSeA(DescrptSeADP):
-    def forward(
-        self,
-        extended_coord: torch.Tensor,
-        extended_atype: torch.Tensor,
-        nlist: torch.Tensor,
-        mapping: torch.Tensor | None = None,
-    ) -> tuple[
-        torch.Tensor,
-        torch.Tensor | None,
-        torch.Tensor | None,
-        torch.Tensor | None,
-        torch.Tensor | None,
-    ]:
-        descrpt, rot_mat, g2, h2, sw = self.call(
-            extended_coord,
-            extended_atype,
-            nlist,
-            mapping=mapping,
-        )
-        return descrpt, rot_mat, g2, h2, sw
+    pass
diff --git a/deepmd/pt_expt/descriptor/se_r.py b/deepmd/pt_expt/descriptor/se_r.py
index ff5111a77a..a449614f47 100644
--- a/deepmd/pt_expt/descriptor/se_r.py
+++ b/deepmd/pt_expt/descriptor/se_r.py
@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
-import torch
-
 from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP
 from deepmd.pt_expt.common import (
     torch_module,
@@ -15,23 +13,4 @@
 @BaseDescriptor.register("se_r")
 @torch_module
 class DescrptSeR(DescrptSeRDP):
-    def forward(
-        self,
-        extended_coord: torch.Tensor,
-        extended_atype: torch.Tensor,
-        nlist: torch.Tensor,
-        mapping: torch.Tensor | None = None,
-    ) -> tuple[
-        torch.Tensor,
-        torch.Tensor | None,
-        torch.Tensor | None,
-        torch.Tensor | None,
-        torch.Tensor | None,
-    ]:
-        descrpt, rot_mat, g2, h2, sw = self.call(
-            extended_coord,
-            extended_atype,
-            nlist,
-            mapping=mapping,
-        )
-        return descrpt, rot_mat, g2, h2, sw
+    pass
diff --git a/deepmd/pt_expt/descriptor/se_t.py b/deepmd/pt_expt/descriptor/se_t.py
index 9879a73a81..de76b4ecf7 100644
--- a/deepmd/pt_expt/descriptor/se_t.py
+++ b/deepmd/pt_expt/descriptor/se_t.py
@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
-import torch
-
 from deepmd.dpmodel.descriptor.se_t import DescrptSeT as DescrptSeTDP
 from deepmd.pt_expt.common import (
     torch_module,
@@ -16,23 +14,4 @@
 @BaseDescriptor.register("se_a_3be")
 @torch_module
 class DescrptSeT(DescrptSeTDP):
-    def forward(
-        self,
-        extended_coord: torch.Tensor,
-        extended_atype: torch.Tensor,
-        nlist: torch.Tensor,
-        mapping: torch.Tensor | None = None,
-    ) -> tuple[
-        torch.Tensor,
-        torch.Tensor | None,
-        torch.Tensor | None,
-        torch.Tensor | None,
-        torch.Tensor | None,
-    ]:
-        descrpt, rot_mat, g2, h2, sw = self.call(
-            extended_coord,
-            extended_atype,
-            nlist,
-            mapping=mapping,
-        )
-        return descrpt, rot_mat, g2, h2, sw
+    pass
diff --git a/deepmd/pt_expt/descriptor/se_t_tebd.py b/deepmd/pt_expt/descriptor/se_t_tebd.py
index 65a6659972..995dc24c3b 100644
--- a/deepmd/pt_expt/descriptor/se_t_tebd.py
+++ b/deepmd/pt_expt/descriptor/se_t_tebd.py
@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
-import torch
-
 from deepmd.dpmodel.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdDP
 from deepmd.pt_expt.common import (
     torch_module,
@@ -14,23 +12,4 @@
 @BaseDescriptor.register("se_e3_tebd")
 @torch_module
 class DescrptSeTTebd(DescrptSeTTebdDP):
-    def forward(
-        self,
-        extended_coord: torch.Tensor,
-        extended_atype: torch.Tensor,
-        nlist: torch.Tensor,
-        mapping: torch.Tensor | None = None,
-    ) -> tuple[
-        torch.Tensor,
-        torch.Tensor | None,
-        torch.Tensor | None,
-        torch.Tensor | None,
-        torch.Tensor | None,
-    ]:
-        descrpt, rot_mat, g2, h2, sw = self.call(
-            extended_coord,
-            extended_atype,
-            nlist,
-            mapping=mapping,
-        )
-        return descrpt, rot_mat, g2, h2, sw
+    pass
diff --git a/deepmd/pt_expt/descriptor/se_t_tebd_block.py b/deepmd/pt_expt/descriptor/se_t_tebd_block.py
index 5676286cbe..07d6fc23fc 100644
--- a/deepmd/pt_expt/descriptor/se_t_tebd_block.py
+++ b/deepmd/pt_expt/descriptor/se_t_tebd_block.py
@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
-import torch
-
 from deepmd.dpmodel.descriptor.se_t_tebd import (
     DescrptBlockSeTTebd as DescrptBlockSeTTebdDP,
 )
@@ -13,29 +11,7 @@
 
 @torch_module
 class DescrptBlockSeTTebd(DescrptBlockSeTTebdDP):
-    def forward(
-        self,
-        nlist: torch.Tensor,
-        coord_ext: torch.Tensor,
-        atype_ext: torch.Tensor,
-        atype_embd_ext: torch.Tensor | None = None,
-        mapping: torch.Tensor | None = None,
-        type_embedding: torch.Tensor | None = None,
-    ) -> tuple[
-        torch.Tensor,
-        torch.Tensor | None,
-        torch.Tensor | None,
-        torch.Tensor | None,
-        torch.Tensor | None,
-    ]:
-        return self.call(
-            nlist,
-            coord_ext,
-            atype_ext,
-            atype_embd_ext=atype_embd_ext,
-            mapping=mapping,
-            type_embedding=type_embedding,
-        )
+    pass
 
 
 register_dpmodel_mapping(
diff --git a/deepmd/pt_expt/fitting/ener_fitting.py b/deepmd/pt_expt/fitting/ener_fitting.py
index 1c91f09526..f9779e44af 100644
--- a/deepmd/pt_expt/fitting/ener_fitting.py
+++ b/deepmd/pt_expt/fitting/ener_fitting.py
@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
-import torch
-
 from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP
 from deepmd.pt_expt.common import (
     register_dpmodel_mapping,
@@ -21,25 +19,7 @@ class EnergyFittingNet(EnergyFittingNetDP):
     This inherits from dpmodel EnergyFittingNet to get the correct serialize() method.
     """
 
-    def forward(
-        self,
-        descriptor: torch.Tensor,
-        atype: torch.Tensor,
-        gr: torch.Tensor | None = None,
-        g2: torch.Tensor | None = None,
-        h2: torch.Tensor | None = None,
-        fparam: torch.Tensor | None = None,
-        aparam: torch.Tensor | None = None,
-    ) -> dict[str, torch.Tensor]:
-        return self.call(
-            descriptor,
-            atype,
-            gr=gr,
-            g2=g2,
-            h2=h2,
-            fparam=fparam,
-            aparam=aparam,
-        )
+    pass
 
 
 register_dpmodel_mapping(
diff --git a/deepmd/pt_expt/fitting/invar_fitting.py b/deepmd/pt_expt/fitting/invar_fitting.py
index 640afe232e..ab908ebe0d 100644
--- a/deepmd/pt_expt/fitting/invar_fitting.py
+++ b/deepmd/pt_expt/fitting/invar_fitting.py
@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
-import torch
-
 from deepmd.dpmodel.fitting.invar_fitting import InvarFitting as InvarFittingDP
 from deepmd.pt_expt.common import (
     register_dpmodel_mapping,
@@ -15,25 +13,7 @@
 @BaseFitting.register("invar")
 @torch_module
 class InvarFitting(InvarFittingDP):
-    def forward(
-        self,
-        descriptor: torch.Tensor,
-        atype: torch.Tensor,
-        gr: torch.Tensor | None = None,
-        g2: torch.Tensor | None = None,
-        h2: torch.Tensor | None = None,
-        fparam: torch.Tensor | None = None,
-        aparam: torch.Tensor | None = None,
-    ) -> dict[str, torch.Tensor]:
-        return self.call(
-            descriptor,
-            atype,
-            gr=gr,
-            g2=g2,
-            h2=h2,
-            fparam=fparam,
-            aparam=aparam,
-        )
+    pass
 
 
 register_dpmodel_mapping(
diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py
index 929907c2f3..e3f493de9a 100644
--- a/deepmd/pt_expt/utils/network.py
+++ b/deepmd/pt_expt/utils/network.py
@@ -194,9 +194,6 @@ def __init__(self, layers: list[dict] | None = None) -> None:
         super().__init__(layers)
         self.layers = torch.nn.ModuleList(self.layers)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.call(x)
-
 
 class EmbeddingNet(EmbeddingNetDP, torch.nn.Module):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
diff --git a/source/tests/pt_expt/utils/test_common.py b/source/tests/pt_expt/utils/test_common.py
index ee8a7ca324..57ee153e54 100644
--- a/source/tests/pt_expt/utils/test_common.py
+++ b/source/tests/pt_expt/utils/test_common.py
@@ -3,8 +3,12 @@
 import numpy as np
 import torch
 
+from deepmd.dpmodel.common import (
+    NativeOP,
+)
 from deepmd.pt_expt.common import (
     to_torch_array,
+    torch_module,
 )
 from deepmd.pt_expt.utils import (
     env,
@@ -21,3 +25,105 @@ def test_to_torch_array_moves_device() -> None:
     output_tensor = to_torch_array(input_tensor)
     assert torch.is_tensor(output_tensor)
     assert output_tensor.device == env.DEVICE
+
+
+def test_torch_module_auto_generates_forward() -> None:
+    """Test that torch_module auto-generates forward() from call()."""
+
+    class MockNativeOP(NativeOP):
+        def call(self, x: np.ndarray) -> np.ndarray:
+            return x * 2
+
+    @torch_module
+    class MockModule(MockNativeOP):
+        pass
+
+    module = MockModule()
+    input_tensor = torch.tensor([1.0, 2.0, 3.0], device=torch.device("cpu"))
+    output = module(input_tensor)
+    expected = input_tensor * 2
+    assert torch.allclose(output, expected)
+
+
+def test_torch_module_auto_generates_forward_lower() -> None:
+    """Test that torch_module auto-generates forward_lower() from call_lower()."""
+
+    class MockNativeOP(NativeOP):
+        def call(self, x: np.ndarray) -> np.ndarray:
+            return x
+
+        def call_lower(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
+            return x + y
+
+    @torch_module
+    class MockModule(MockNativeOP):
+        pass
+
+    module = MockModule()
+    input_x = torch.tensor([1.0, 2.0, 3.0], device=torch.device("cpu"))
+    input_y = torch.tensor([4.0, 5.0, 6.0], device=torch.device("cpu"))
+    output = module.forward_lower(input_x, input_y)
+    expected = input_x + input_y
+    assert torch.allclose(output, expected)
+
+
+def test_torch_module_respects_explicit_forward() -> None:
+    """Test that torch_module doesn't override an explicitly defined forward()."""
+
+    class MockNativeOP(NativeOP):
+        def call(self, x: np.ndarray) -> np.ndarray:
+            return x * 2
+
+    @torch_module
+    class MockModule(MockNativeOP):
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            # This should override the auto-generated forward
+            return x * 3
+
+    module = MockModule()
+    input_tensor = torch.tensor([1.0, 2.0, 3.0], device=torch.device("cpu"))
+    output = module(input_tensor)
+    expected = input_tensor * 3  # Should use the explicit forward, not call()
+    assert torch.allclose(output, expected)
+
+
+def test_torch_module_respects_explicit_forward_lower() -> None:
+    """Test that torch_module doesn't override an explicitly defined forward_lower()."""
+
+    class MockNativeOP(NativeOP):
+        def call(self, x: np.ndarray) -> np.ndarray:
+            return x
+
+        def call_lower(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
+            return x + y
+
+    @torch_module
+    class MockModule(MockNativeOP):
+        def forward_lower(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            # This should override the auto-generated forward_lower
+            return x - y
+
+    module = MockModule()
+    input_x = torch.tensor([1.0, 2.0, 3.0], device=torch.device("cpu"))
+    input_y = torch.tensor([4.0, 5.0, 6.0], device=torch.device("cpu"))
+    output = module.forward_lower(input_x, input_y)
+    expected = input_x - input_y  # Should use the explicit forward_lower
+    assert torch.allclose(output, expected)
+
+
+def test_torch_module_handles_no_call_lower_method() -> None:
+    """Test that torch_module works even when there's no call_lower() method."""
+
+    class MockNativeOP(NativeOP):
+        def call(self, x: np.ndarray) -> np.ndarray:
+            return x * 2
+
+    @torch_module
+    class MockModule(MockNativeOP):
+        pass
+
+    module = MockModule()
+    # Should have forward method since there's a call method
+    assert hasattr(module, "forward")
+    # Should not have forward_lower method since there's no call_lower method
+    assert not hasattr(module, "forward_lower")
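
For illustration, below is a minimal standalone sketch of the delegation pattern this change introduces. The names `_torch_module_sketch` and `Doubler` are hypothetical and exist only for this example; the real decorator in `deepmd.pt_expt.common` additionally handles initialization, attribute setting, and the serialization machinery shown in the diff.

# Minimal standalone sketch of the forward/forward_lower delegation pattern
# introduced above. `_torch_module_sketch` and `Doubler` are hypothetical
# names for illustration only; the real torch_module decorator does more.
from typing import Any

import torch


def _torch_module_sketch(module: type) -> type[torch.nn.Module]:
    class TorchModule(module, torch.nn.Module):
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            torch.nn.Module.__init__(self)
            module.__init__(self, *args, **kwargs)

    # Same guard as the diff: generate forward() only when the wrapped class
    # defines call() and does not define forward() in its own __dict__.
    if hasattr(module, "call") and "forward" not in module.__dict__:

        def forward(self, *args: Any, **kwargs: Any) -> Any:
            return self.call(*args, **kwargs)

        TorchModule.forward = forward

    # Likewise for forward_lower() delegating to call_lower().
    if hasattr(module, "call_lower") and "forward_lower" not in module.__dict__:

        def forward_lower(self, *args: Any, **kwargs: Any) -> Any:
            return self.call_lower(*args, **kwargs)

        TorchModule.forward_lower = forward_lower

    return TorchModule


@_torch_module_sketch
class Doubler:
    def call(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


m = Doubler()
assert torch.allclose(m(torch.ones(3)), torch.full((3,), 2.0))  # forward -> call
assert not hasattr(m, "forward_lower")  # no call_lower, so nothing is generated

Note the design choice the guard encodes: checking `"forward" not in module.__dict__` rather than `hasattr(module, "forward")` means only a `forward()` written directly on the decorated class suppresses auto-generation, which is exactly the behavior the new tests in test_common.py pin down.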