From a0f41e2ab5a3e0e2f10cab00fd5417076027b2da Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 16 Feb 2026 19:25:54 +0000 Subject: [PATCH 1/4] Initial plan From 17777ef8a544ab28d732c96f299eb8304cee781b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 16 Feb 2026 19:32:53 +0000 Subject: [PATCH 2/4] feat(pt_expt): auto-generate forward/forward_lower in torch_module decorator Modified torch_module decorator to automatically generate forward() and forward_lower() methods that delegate to call() and call_lower() if they exist and aren't explicitly defined. Removed boilerplate forward methods from all descriptor and fitting classes. Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com> --- deepmd/pt_expt/common.py | 16 ++++++++++++ deepmd/pt_expt/descriptor/se_e2_a.py | 23 +---------------- deepmd/pt_expt/descriptor/se_r.py | 23 +---------------- deepmd/pt_expt/descriptor/se_t.py | 23 +---------------- deepmd/pt_expt/descriptor/se_t_tebd.py | 23 +---------------- deepmd/pt_expt/descriptor/se_t_tebd_block.py | 26 +------------------- deepmd/pt_expt/fitting/ener_fitting.py | 22 +---------------- deepmd/pt_expt/fitting/invar_fitting.py | 22 +---------------- deepmd/pt_expt/utils/network.py | 3 --- 9 files changed, 23 insertions(+), 158 deletions(-) diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py index c7375119e2..b47484f827 100644 --- a/deepmd/pt_expt/common.py +++ b/deepmd/pt_expt/common.py @@ -332,6 +332,22 @@ def __setattr__(self, name: str, value: Any) -> None: if not handled: super().__setattr__(name, value) + # Auto-generate forward -> call redirect if not explicitly defined + if hasattr(module, "call") and "forward" not in module.__dict__: + + def forward(self, *args: Any, **kwargs: Any) -> Any: # noqa: ANN001 + return self.call(*args, **kwargs) + + TorchModule.forward = forward + + # Auto-generate forward_lower -> call_lower redirect if not explicitly defined + if hasattr(module, "call_lower") and "forward_lower" not in module.__dict__: + + def forward_lower(self, *args: Any, **kwargs: Any) -> Any: # noqa: ANN001 + return self.call_lower(*args, **kwargs) + + TorchModule.forward_lower = forward_lower + return TorchModule diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 894f175764..a140a5d346 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import torch - from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP from deepmd.pt_expt.common import ( torch_module, @@ -15,23 +13,4 @@ @BaseDescriptor.register("se_a") @torch_module class DescrptSeA(DescrptSeADP): - def forward( - self, - extended_coord: torch.Tensor, - extended_atype: torch.Tensor, - nlist: torch.Tensor, - mapping: torch.Tensor | None = None, - ) -> tuple[ - torch.Tensor, - torch.Tensor | None, - torch.Tensor | None, - torch.Tensor | None, - torch.Tensor | None, - ]: - descrpt, rot_mat, g2, h2, sw = self.call( - extended_coord, - extended_atype, - nlist, - mapping=mapping, - ) - return descrpt, rot_mat, g2, h2, sw + pass diff --git a/deepmd/pt_expt/descriptor/se_r.py b/deepmd/pt_expt/descriptor/se_r.py index ff5111a77a..a449614f47 100644 --- a/deepmd/pt_expt/descriptor/se_r.py +++ b/deepmd/pt_expt/descriptor/se_r.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import torch - from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP from deepmd.pt_expt.common import ( torch_module, @@ -15,23 +13,4 @@ @BaseDescriptor.register("se_r") @torch_module class DescrptSeR(DescrptSeRDP): - def forward( - self, - extended_coord: torch.Tensor, - extended_atype: torch.Tensor, - nlist: torch.Tensor, - mapping: torch.Tensor | None = None, - ) -> tuple[ - torch.Tensor, - torch.Tensor | None, - torch.Tensor | None, - torch.Tensor | None, - torch.Tensor | None, - ]: - descrpt, rot_mat, g2, h2, sw = self.call( - extended_coord, - extended_atype, - nlist, - mapping=mapping, - ) - return descrpt, rot_mat, g2, h2, sw + pass diff --git a/deepmd/pt_expt/descriptor/se_t.py b/deepmd/pt_expt/descriptor/se_t.py index 9879a73a81..de76b4ecf7 100644 --- a/deepmd/pt_expt/descriptor/se_t.py +++ b/deepmd/pt_expt/descriptor/se_t.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import torch - from deepmd.dpmodel.descriptor.se_t import DescrptSeT as DescrptSeTDP from deepmd.pt_expt.common import ( torch_module, @@ -16,23 +14,4 @@ @BaseDescriptor.register("se_a_3be") @torch_module class DescrptSeT(DescrptSeTDP): - def forward( - self, - extended_coord: torch.Tensor, - extended_atype: torch.Tensor, - nlist: torch.Tensor, - mapping: torch.Tensor | None = None, - ) -> tuple[ - torch.Tensor, - torch.Tensor | None, - torch.Tensor | None, - torch.Tensor | None, - torch.Tensor | None, - ]: - descrpt, rot_mat, g2, h2, sw = self.call( - extended_coord, - extended_atype, - nlist, - mapping=mapping, - ) - return descrpt, rot_mat, g2, h2, sw + pass diff --git a/deepmd/pt_expt/descriptor/se_t_tebd.py b/deepmd/pt_expt/descriptor/se_t_tebd.py index 65a6659972..995dc24c3b 100644 --- a/deepmd/pt_expt/descriptor/se_t_tebd.py +++ b/deepmd/pt_expt/descriptor/se_t_tebd.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import torch - from deepmd.dpmodel.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdDP from deepmd.pt_expt.common import ( torch_module, @@ -14,23 +12,4 @@ @BaseDescriptor.register("se_e3_tebd") @torch_module class DescrptSeTTebd(DescrptSeTTebdDP): - def forward( - self, - extended_coord: torch.Tensor, - extended_atype: torch.Tensor, - nlist: torch.Tensor, - mapping: torch.Tensor | None = None, - ) -> tuple[ - torch.Tensor, - torch.Tensor | None, - torch.Tensor | None, - torch.Tensor | None, - torch.Tensor | None, - ]: - descrpt, rot_mat, g2, h2, sw = self.call( - extended_coord, - extended_atype, - nlist, - mapping=mapping, - ) - return descrpt, rot_mat, g2, h2, sw + pass diff --git a/deepmd/pt_expt/descriptor/se_t_tebd_block.py b/deepmd/pt_expt/descriptor/se_t_tebd_block.py index 5676286cbe..07d6fc23fc 100644 --- a/deepmd/pt_expt/descriptor/se_t_tebd_block.py +++ b/deepmd/pt_expt/descriptor/se_t_tebd_block.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import torch - from deepmd.dpmodel.descriptor.se_t_tebd import ( DescrptBlockSeTTebd as DescrptBlockSeTTebdDP, ) @@ -13,29 +11,7 @@ @torch_module class DescrptBlockSeTTebd(DescrptBlockSeTTebdDP): - def forward( - self, - nlist: torch.Tensor, - coord_ext: torch.Tensor, - atype_ext: torch.Tensor, - atype_embd_ext: torch.Tensor | None = None, - mapping: torch.Tensor | None = None, - type_embedding: torch.Tensor | None = None, - ) -> tuple[ - torch.Tensor, - torch.Tensor | None, - torch.Tensor | None, - torch.Tensor | None, - torch.Tensor | None, - ]: - return self.call( - nlist, - coord_ext, - atype_ext, - atype_embd_ext=atype_embd_ext, - mapping=mapping, - type_embedding=type_embedding, - ) + pass register_dpmodel_mapping( diff --git a/deepmd/pt_expt/fitting/ener_fitting.py b/deepmd/pt_expt/fitting/ener_fitting.py index 1c91f09526..f9779e44af 100644 --- a/deepmd/pt_expt/fitting/ener_fitting.py +++ b/deepmd/pt_expt/fitting/ener_fitting.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import torch - from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP from deepmd.pt_expt.common import ( register_dpmodel_mapping, @@ -21,25 +19,7 @@ class EnergyFittingNet(EnergyFittingNetDP): This inherits from dpmodel EnergyFittingNet to get the correct serialize() method. """ - def forward( - self, - descriptor: torch.Tensor, - atype: torch.Tensor, - gr: torch.Tensor | None = None, - g2: torch.Tensor | None = None, - h2: torch.Tensor | None = None, - fparam: torch.Tensor | None = None, - aparam: torch.Tensor | None = None, - ) -> dict[str, torch.Tensor]: - return self.call( - descriptor, - atype, - gr=gr, - g2=g2, - h2=h2, - fparam=fparam, - aparam=aparam, - ) + pass register_dpmodel_mapping( diff --git a/deepmd/pt_expt/fitting/invar_fitting.py b/deepmd/pt_expt/fitting/invar_fitting.py index 640afe232e..ab908ebe0d 100644 --- a/deepmd/pt_expt/fitting/invar_fitting.py +++ b/deepmd/pt_expt/fitting/invar_fitting.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import torch - from deepmd.dpmodel.fitting.invar_fitting import InvarFitting as InvarFittingDP from deepmd.pt_expt.common import ( register_dpmodel_mapping, @@ -15,25 +13,7 @@ @BaseFitting.register("invar") @torch_module class InvarFitting(InvarFittingDP): - def forward( - self, - descriptor: torch.Tensor, - atype: torch.Tensor, - gr: torch.Tensor | None = None, - g2: torch.Tensor | None = None, - h2: torch.Tensor | None = None, - fparam: torch.Tensor | None = None, - aparam: torch.Tensor | None = None, - ) -> dict[str, torch.Tensor]: - return self.call( - descriptor, - atype, - gr=gr, - g2=g2, - h2=h2, - fparam=fparam, - aparam=aparam, - ) + pass register_dpmodel_mapping( diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 929907c2f3..e3f493de9a 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -194,9 +194,6 @@ def __init__(self, layers: list[dict] | None = None) -> None: super().__init__(layers) self.layers = torch.nn.ModuleList(self.layers) - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.call(x) - class EmbeddingNet(EmbeddingNetDP, torch.nn.Module): def __init__(self, *args: Any, **kwargs: Any) -> None: From 2153852bafba7c01d19cc362a4c119875c654652 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 16 Feb 2026 19:36:05 +0000 Subject: [PATCH 3/4] test(pt_expt): add tests for torch_module auto-generation Added comprehensive tests to verify torch_module decorator correctly auto-generates forward() and forward_lower() methods, respects explicit overrides, and handles edge cases properly. Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com> --- source/tests/pt_expt/utils/test_common.py | 106 ++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/source/tests/pt_expt/utils/test_common.py b/source/tests/pt_expt/utils/test_common.py index ee8a7ca324..37e6797c50 100644 --- a/source/tests/pt_expt/utils/test_common.py +++ b/source/tests/pt_expt/utils/test_common.py @@ -3,8 +3,12 @@ import numpy as np import torch +from deepmd.dpmodel.common import ( + NativeOP, +) from deepmd.pt_expt.common import ( to_torch_array, + torch_module, ) from deepmd.pt_expt.utils import ( env, @@ -21,3 +25,105 @@ def test_to_torch_array_moves_device() -> None: output_tensor = to_torch_array(input_tensor) assert torch.is_tensor(output_tensor) assert output_tensor.device == env.DEVICE + + +def test_torch_module_auto_generates_forward() -> None: + """Test that torch_module auto-generates forward() from call().""" + + class MockNativeOP(NativeOP): + def call(self, x: np.ndarray) -> np.ndarray: + return x * 2 + + @torch_module + class MockModule(MockNativeOP): + pass + + module = MockModule() + input_tensor = torch.tensor([1.0, 2.0, 3.0], device=torch.device("cpu")) + output = module.forward(input_tensor) + expected = input_tensor * 2 + assert torch.allclose(output, expected) + + +def test_torch_module_auto_generates_forward_lower() -> None: + """Test that torch_module auto-generates forward_lower() from call_lower().""" + + class MockNativeOP(NativeOP): + def call(self, x: np.ndarray) -> np.ndarray: + return x + + def call_lower(self, x: np.ndarray, y: np.ndarray) -> np.ndarray: + return x + y + + @torch_module + class MockModule(MockNativeOP): + pass + + module = MockModule() + input_x = torch.tensor([1.0, 2.0, 3.0], device=torch.device("cpu")) + input_y = torch.tensor([4.0, 5.0, 6.0], device=torch.device("cpu")) + output = module.forward_lower(input_x, input_y) + expected = input_x + input_y + assert torch.allclose(output, expected) + + +def test_torch_module_respects_explicit_forward() -> None: + """Test that torch_module doesn't override an explicitly defined forward().""" + + class MockNativeOP(NativeOP): + def call(self, x: np.ndarray) -> np.ndarray: + return x * 2 + + @torch_module + class MockModule(MockNativeOP): + def forward(self, x: torch.Tensor) -> torch.Tensor: + # This should override the auto-generated forward + return x * 3 + + module = MockModule() + input_tensor = torch.tensor([1.0, 2.0, 3.0], device=torch.device("cpu")) + output = module.forward(input_tensor) + expected = input_tensor * 3 # Should use the explicit forward, not call() + assert torch.allclose(output, expected) + + +def test_torch_module_respects_explicit_forward_lower() -> None: + """Test that torch_module doesn't override an explicitly defined forward_lower().""" + + class MockNativeOP(NativeOP): + def call(self, x: np.ndarray) -> np.ndarray: + return x + + def call_lower(self, x: np.ndarray, y: np.ndarray) -> np.ndarray: + return x + y + + @torch_module + class MockModule(MockNativeOP): + def forward_lower(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # This should override the auto-generated forward_lower + return x - y + + module = MockModule() + input_x = torch.tensor([1.0, 2.0, 3.0], device=torch.device("cpu")) + input_y = torch.tensor([4.0, 5.0, 6.0], device=torch.device("cpu")) + output = module.forward_lower(input_x, input_y) + expected = input_x - input_y # Should use the explicit forward_lower + assert torch.allclose(output, expected) + + +def test_torch_module_handles_no_call_lower_method() -> None: + """Test that torch_module works even when there's no call_lower() method.""" + + class MockNativeOP(NativeOP): + def call(self, x: np.ndarray) -> np.ndarray: + return x * 2 + + @torch_module + class MockModule(MockNativeOP): + pass + + module = MockModule() + # Should have forward method since there's a call method + assert hasattr(module, "forward") + # Should not have forward_lower method since there's no call_lower method + assert not hasattr(module, "forward_lower") From 6703f13e276f45f798f0b18e23deee4bafdfb3d6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 16 Feb 2026 21:58:01 +0000 Subject: [PATCH 4/4] docs(pt_expt): improve torch_module docstring and update tests - Enhanced torch_module decorator docstring to document auto-generation behavior - Updated tests to use module(...) instead of module.forward(...) to test the full invocation path through torch.nn.Module.__call__ - This ensures tests cover hooks, tracing, and export behavior properly Addresses review feedback from PR #5246 Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com> --- deepmd/pt_expt/common.py | 18 ++++++++++++++++-- source/tests/pt_expt/utils/test_common.py | 4 ++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py index b47484f827..4d8b371b11 100644 --- a/deepmd/pt_expt/common.py +++ b/deepmd/pt_expt/common.py @@ -300,6 +300,20 @@ def torch_module( ) -> type[torch.nn.Module]: """Convert a NativeOP to a torch.nn.Module. + This decorator wraps a NativeOP class to make it a PyTorch module, handling + initialization, attribute setting, and method delegation automatically. + + **Auto-generated methods:** + + - If the wrapped class has a ``call()`` method but does not explicitly define + ``forward()``, a ``forward()`` method will be auto-generated that delegates + to ``call()``. + - If the wrapped class has a ``call_lower()`` method but does not explicitly + define ``forward_lower()``, a ``forward_lower()`` method will be auto-generated + that delegates to ``call_lower()``. + - Explicit ``forward()`` or ``forward_lower()`` definitions in the wrapped class + are always respected and will not be overridden. + Parameters ---------- module : type[NativeOP] @@ -308,13 +322,13 @@ def torch_module( Returns ------- type[torch.nn.Module] - The torch.nn.Module. + The torch.nn.Module with auto-generated delegation methods if applicable. Examples -------- >>> @torch_module ... class MyModule(NativeOP): - ... pass + ... pass # forward() auto-generated from call() if it exists """ @wraps(module, updated=()) diff --git a/source/tests/pt_expt/utils/test_common.py b/source/tests/pt_expt/utils/test_common.py index 37e6797c50..57ee153e54 100644 --- a/source/tests/pt_expt/utils/test_common.py +++ b/source/tests/pt_expt/utils/test_common.py @@ -40,7 +40,7 @@ class MockModule(MockNativeOP): module = MockModule() input_tensor = torch.tensor([1.0, 2.0, 3.0], device=torch.device("cpu")) - output = module.forward(input_tensor) + output = module(input_tensor) expected = input_tensor * 2 assert torch.allclose(output, expected) @@ -82,7 +82,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: module = MockModule() input_tensor = torch.tensor([1.0, 2.0, 3.0], device=torch.device("cpu")) - output = module.forward(input_tensor) + output = module(input_tensor) expected = input_tensor * 3 # Should use the explicit forward, not call() assert torch.allclose(output, expected)