Commit 4ddc37d

Copilot and njzjz authored
feat(pt_expt): auto-generate forward/forward_lower in torch_module decorator (#5246)
- [x] Understand the current implementation of the `torch_module` decorator
- [x] Identify all classes using `@torch_module` with manual `forward()` methods
- [x] Modify `torch_module` decorator to auto-generate `forward()` from `call()`
- [x] Modify `torch_module` decorator to auto-generate `forward_lower()` from `call_lower()`
- [x] Remove boilerplate `forward()` methods from descriptor classes (DescrptSeA, DescrptSeR, DescrptSeT, DescrptSeTTebd, DescrptBlockSeTTebd)
- [x] Remove boilerplate `forward()` methods from fitting classes (InvarFitting, EnergyFittingNet)
- [x] Remove boilerplate `forward()` method from NativeNet
- [x] Run linting and formatting
- [x] Create tests to verify auto-generated methods work correctly
- [x] Run targeted tests to verify changes (63 pt_expt tests pass)
- [x] Manually verify forward methods work with real data
- [x] Request code review and address feedback
- [x] Update docstring to document auto-generation behavior
- [x] Update tests to use module(...) invocation path
- [x] Ready for final approval

<details>
<summary>Original prompt</summary>

> <issue_title>pt_expt: auto-generate forward/forward_lower in torch_module decorator to avoid boilerplate</issue_title>
>
> ## Summary
>
> The `torch_module` decorator in `deepmd/pt_expt/common.py` currently handles `__init__`, `__call__`, and `__setattr__` automatically, but each wrapped class still needs to manually define `forward()` (and potentially `forward_lower()`) methods that simply redirect to `call()` (and `call_lower()`).
>
> ## Current situation
>
> Every descriptor class in `deepmd/pt_expt/descriptor/` repeats the same boilerplate pattern:
>
> ```python
> @torch_module
> class DescrptSeA(DescrptSeADP):
>     def forward(
>         self,
>         extended_coord: torch.Tensor,
>         extended_atype: torch.Tensor,
>         nlist: torch.Tensor,
>         mapping: torch.Tensor | None = None,
>     ) -> tuple[...]:
>         descrpt, rot_mat, g2, h2, sw = self.call(
>             extended_coord,
>             extended_atype,
>             nlist,
>             mapping=mapping,
>         )
>         return descrpt, rot_mat, g2, h2, sw
> ```
>
> This identical `forward → call` redirect is duplicated across `DescrptSeA`, `DescrptSeR`, `DescrptSeT`, `DescrptSeTTebd`, `DescrptBlockSeTTebd`, and similarly for fitting classes like `InvarFitting` (though `InvarFitting` does not use `torch_module`).
>
> ## Proposal
>
> Modify the `torch_module` decorator to **automatically generate `forward` and `forward_lower`** methods that delegate to `call` and `call_lower` respectively, if:
>
> 1. The wrapped class has a `call` method (from the dpmodel base) but does **not** define its own `forward`.
> 2. Similarly for `call_lower` → `forward_lower`.
>
> This could be implemented by adding something like the following inside `torch_module`:
>
> ```python
> def torch_module(module: type[NativeOP]) -> type[torch.nn.Module]:
>     @wraps(module, updated=())
>     class TorchModule(module, torch.nn.Module):
>         def __init__(self, *args, **kwargs):
>             torch.nn.Module.__init__(self)
>             module.__init__(self, *args, **kwargs)
>
>         def __call__(self, *args, **kwargs):
>             return torch.nn.Module.__call__(self, *args, **kwargs)
>
>         def __setattr__(self, name, value):
>             handled, value = dpmodel_setattr(self, name, value)
>             if not handled:
>                 super().__setattr__(name, value)
>
>     # Auto-generate forward -> call redirect
>     if hasattr(module, "call") and "forward" not in module.__dict__:
>         def forward(self, *args, **kwargs):
>             return self.call(*args, **kwargs)
>         TorchModule.forward = forward
>
>     # Auto-generate forward_lower -> call_lower redirect
>     if hasattr(module, "call_lower") and "forward_lower" not in module.__dict__:
>         def forward_lower(self, *args, **kwargs):
>             return self.call_lower(*args, **kwargs)
>         TorchModule.forward_lower = forward_lower
>
>     return TorchModule
> ```
>
> ## Benefits
>
> - **Eliminates boilerplate**: Descriptor/fitting wrapper classes become minimal (just the `@torch_module` decorator + `@BaseDescriptor.register` + class declaration).
> - **Reduces maintenance burden**: Adding new descriptors to `pt_expt` requires fewer lines and less copy-paste.
> - **Consistent behavior**: The `forward ↔ call` contract is enforced in one place rather than scattered across many files.
> - **Less error-prone**: No risk of forgetting to update `forward` when `call` signature changes.
>
> ## Example after this change
>
> ```python
> @BaseDescriptor.register("se_e2_a")
> @BaseDescriptor.register("se_a")
> @torch_module
> class DescrptSeA(DescrptSeADP):
>     pass  # forward() auto-generated from call()
> ```
>
> *Submitted by OpenClaw on behalf of @njzjz*

</details>

Fixes #5245

---

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent 09e40bb · commit 4ddc37d

10 files changed: 145 additions & 160 deletions

deepmd/pt_expt/common.py

Lines changed: 32 additions & 2 deletions

```diff
@@ -300,6 +300,20 @@ def torch_module(
 ) -> type[torch.nn.Module]:
     """Convert a NativeOP to a torch.nn.Module.
 
+    This decorator wraps a NativeOP class to make it a PyTorch module, handling
+    initialization, attribute setting, and method delegation automatically.
+
+    **Auto-generated methods:**
+
+    - If the wrapped class has a ``call()`` method but does not explicitly define
+      ``forward()``, a ``forward()`` method will be auto-generated that delegates
+      to ``call()``.
+    - If the wrapped class has a ``call_lower()`` method but does not explicitly
+      define ``forward_lower()``, a ``forward_lower()`` method will be auto-generated
+      that delegates to ``call_lower()``.
+    - Explicit ``forward()`` or ``forward_lower()`` definitions in the wrapped class
+      are always respected and will not be overridden.
+
     Parameters
     ----------
     module : type[NativeOP]
@@ -308,13 +322,13 @@ def torch_module(
     Returns
     -------
     type[torch.nn.Module]
-        The torch.nn.Module.
+        The torch.nn.Module with auto-generated delegation methods if applicable.
 
     Examples
     --------
     >>> @torch_module
     ... class MyModule(NativeOP):
-    ...     pass
+    ...     pass  # forward() auto-generated from call() if it exists
     """
 
     @wraps(module, updated=())
@@ -332,6 +346,22 @@ def __setattr__(self, name: str, value: Any) -> None:
             if not handled:
                 super().__setattr__(name, value)
 
+    # Auto-generate forward -> call redirect if not explicitly defined
+    if hasattr(module, "call") and "forward" not in module.__dict__:
+
+        def forward(self, *args: Any, **kwargs: Any) -> Any:  # noqa: ANN001
+            return self.call(*args, **kwargs)
+
+        TorchModule.forward = forward
+
+    # Auto-generate forward_lower -> call_lower redirect if not explicitly defined
+    if hasattr(module, "call_lower") and "forward_lower" not in module.__dict__:
+
+        def forward_lower(self, *args: Any, **kwargs: Any) -> Any:  # noqa: ANN001
+            return self.call_lower(*args, **kwargs)
+
+        TorchModule.forward_lower = forward_lower
+
     return TorchModule
```

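For readers skimming the diff, here is a minimal, self-contained sketch of the pattern the updated decorator implements. The `Scale` class below is a made-up stand-in for a NativeOP subclass, and the sketch deliberately omits the real decorator's `__setattr__`/`dpmodel_setattr` handling and the `forward_lower` branch:

```python
# Minimal sketch of the auto-generation pattern; `Scale` is hypothetical,
# not a deepmd-kit class.
from functools import wraps

import torch


def torch_module(module):
    @wraps(module, updated=())
    class TorchModule(module, torch.nn.Module):
        def __init__(self, *args, **kwargs):
            torch.nn.Module.__init__(self)
            module.__init__(self, *args, **kwargs)

        def __call__(self, *args, **kwargs):
            return torch.nn.Module.__call__(self, *args, **kwargs)

    # forward is attached only when the wrapped class defines call()
    # and does not define its own forward().
    if hasattr(module, "call") and "forward" not in module.__dict__:

        def forward(self, *args, **kwargs):
            return self.call(*args, **kwargs)

        TorchModule.forward = forward

    return TorchModule


@torch_module
class Scale:  # stand-in for a NativeOP subclass
    def __init__(self, factor: float) -> None:
        self.factor = factor

    def call(self, x: torch.Tensor) -> torch.Tensor:
        return self.factor * x


m = Scale(2.0)
print(m(torch.ones(3)))  # tensor([2., 2., 2.]) via the auto-generated forward()
```

Because `forward` is attached to `TorchModule` only when the wrapped class does not define its own, explicit `forward` implementations keep winning, which is exactly the contract the docstring above documents.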
deepmd/pt_expt/descriptor/se_e2_a.py

Lines changed: 1 addition & 22 deletions

```diff
@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 
-import torch
-
 from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP
 from deepmd.pt_expt.common import (
     torch_module,
@@ -15,23 +13,4 @@
 @BaseDescriptor.register("se_a")
 @torch_module
 class DescrptSeA(DescrptSeADP):
-    def forward(
-        self,
-        extended_coord: torch.Tensor,
-        extended_atype: torch.Tensor,
-        nlist: torch.Tensor,
-        mapping: torch.Tensor | None = None,
-    ) -> tuple[
-        torch.Tensor,
-        torch.Tensor | None,
-        torch.Tensor | None,
-        torch.Tensor | None,
-        torch.Tensor | None,
-    ]:
-        descrpt, rot_mat, g2, h2, sw = self.call(
-            extended_coord,
-            extended_atype,
-            nlist,
-            mapping=mapping,
-        )
-        return descrpt, rot_mat, g2, h2, sw
+    pass
```

deepmd/pt_expt/descriptor/se_r.py

Lines changed: 1 addition & 22 deletions

```diff
@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 
-import torch
-
 from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP
 from deepmd.pt_expt.common import (
     torch_module,
@@ -15,23 +13,4 @@
 @BaseDescriptor.register("se_r")
 @torch_module
 class DescrptSeR(DescrptSeRDP):
-    def forward(
-        self,
-        extended_coord: torch.Tensor,
-        extended_atype: torch.Tensor,
-        nlist: torch.Tensor,
-        mapping: torch.Tensor | None = None,
-    ) -> tuple[
-        torch.Tensor,
-        torch.Tensor | None,
-        torch.Tensor | None,
-        torch.Tensor | None,
-        torch.Tensor | None,
-    ]:
-        descrpt, rot_mat, g2, h2, sw = self.call(
-            extended_coord,
-            extended_atype,
-            nlist,
-            mapping=mapping,
-        )
-        return descrpt, rot_mat, g2, h2, sw
+    pass
```

deepmd/pt_expt/descriptor/se_t.py

Lines changed: 1 addition & 22 deletions

```diff
@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 
-import torch
-
 from deepmd.dpmodel.descriptor.se_t import DescrptSeT as DescrptSeTDP
 from deepmd.pt_expt.common import (
     torch_module,
@@ -16,23 +14,4 @@
 @BaseDescriptor.register("se_a_3be")
 @torch_module
 class DescrptSeT(DescrptSeTDP):
-    def forward(
-        self,
-        extended_coord: torch.Tensor,
-        extended_atype: torch.Tensor,
-        nlist: torch.Tensor,
-        mapping: torch.Tensor | None = None,
-    ) -> tuple[
-        torch.Tensor,
-        torch.Tensor | None,
-        torch.Tensor | None,
-        torch.Tensor | None,
-        torch.Tensor | None,
-    ]:
-        descrpt, rot_mat, g2, h2, sw = self.call(
-            extended_coord,
-            extended_atype,
-            nlist,
-            mapping=mapping,
-        )
-        return descrpt, rot_mat, g2, h2, sw
+    pass
```

deepmd/pt_expt/descriptor/se_t_tebd.py

Lines changed: 1 addition & 22 deletions

```diff
@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 
-import torch
-
 from deepmd.dpmodel.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdDP
 from deepmd.pt_expt.common import (
     torch_module,
@@ -14,23 +12,4 @@
 @BaseDescriptor.register("se_e3_tebd")
 @torch_module
 class DescrptSeTTebd(DescrptSeTTebdDP):
-    def forward(
-        self,
-        extended_coord: torch.Tensor,
-        extended_atype: torch.Tensor,
-        nlist: torch.Tensor,
-        mapping: torch.Tensor | None = None,
-    ) -> tuple[
-        torch.Tensor,
-        torch.Tensor | None,
-        torch.Tensor | None,
-        torch.Tensor | None,
-        torch.Tensor | None,
-    ]:
-        descrpt, rot_mat, g2, h2, sw = self.call(
-            extended_coord,
-            extended_atype,
-            nlist,
-            mapping=mapping,
-        )
-        return descrpt, rot_mat, g2, h2, sw
+    pass
```

deepmd/pt_expt/descriptor/se_t_tebd_block.py

Lines changed: 1 addition & 25 deletions

```diff
@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 
-import torch
-
 from deepmd.dpmodel.descriptor.se_t_tebd import (
     DescrptBlockSeTTebd as DescrptBlockSeTTebdDP,
 )
@@ -13,29 +11,7 @@
 
 @torch_module
 class DescrptBlockSeTTebd(DescrptBlockSeTTebdDP):
-    def forward(
-        self,
-        nlist: torch.Tensor,
-        coord_ext: torch.Tensor,
-        atype_ext: torch.Tensor,
-        atype_embd_ext: torch.Tensor | None = None,
-        mapping: torch.Tensor | None = None,
-        type_embedding: torch.Tensor | None = None,
-    ) -> tuple[
-        torch.Tensor,
-        torch.Tensor | None,
-        torch.Tensor | None,
-        torch.Tensor | None,
-        torch.Tensor | None,
-    ]:
-        return self.call(
-            nlist,
-            coord_ext,
-            atype_ext,
-            atype_embd_ext=atype_embd_ext,
-            mapping=mapping,
-            type_embedding=type_embedding,
-        )
+    pass
 
 
 register_dpmodel_mapping(
```

deepmd/pt_expt/fitting/ener_fitting.py

Lines changed: 1 addition & 21 deletions

```diff
@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 
-import torch
-
 from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP
 from deepmd.pt_expt.common import (
     register_dpmodel_mapping,
@@ -21,25 +19,7 @@ class EnergyFittingNet(EnergyFittingNetDP):
     This inherits from dpmodel EnergyFittingNet to get the correct serialize() method.
     """
 
-    def forward(
-        self,
-        descriptor: torch.Tensor,
-        atype: torch.Tensor,
-        gr: torch.Tensor | None = None,
-        g2: torch.Tensor | None = None,
-        h2: torch.Tensor | None = None,
-        fparam: torch.Tensor | None = None,
-        aparam: torch.Tensor | None = None,
-    ) -> dict[str, torch.Tensor]:
-        return self.call(
-            descriptor,
-            atype,
-            gr=gr,
-            g2=g2,
-            h2=h2,
-            fparam=fparam,
-            aparam=aparam,
-        )
+    pass
 
 
 register_dpmodel_mapping(
```

deepmd/pt_expt/fitting/invar_fitting.py

Lines changed: 1 addition & 21 deletions

```diff
@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 
-import torch
-
 from deepmd.dpmodel.fitting.invar_fitting import InvarFitting as InvarFittingDP
 from deepmd.pt_expt.common import (
     register_dpmodel_mapping,
@@ -15,25 +13,7 @@
 @BaseFitting.register("invar")
 @torch_module
 class InvarFitting(InvarFittingDP):
-    def forward(
-        self,
-        descriptor: torch.Tensor,
-        atype: torch.Tensor,
-        gr: torch.Tensor | None = None,
-        g2: torch.Tensor | None = None,
-        h2: torch.Tensor | None = None,
-        fparam: torch.Tensor | None = None,
-        aparam: torch.Tensor | None = None,
-    ) -> dict[str, torch.Tensor]:
-        return self.call(
-            descriptor,
-            atype,
-            gr=gr,
-            g2=g2,
-            h2=h2,
-            fparam=fparam,
-            aparam=aparam,
-        )
+    pass
 
 
 register_dpmodel_mapping(
```

deepmd/pt_expt/utils/network.py

Lines changed: 0 additions & 3 deletions

```diff
@@ -194,9 +194,6 @@ def __init__(self, layers: list[dict] | None = None) -> None:
         super().__init__(layers)
         self.layers = torch.nn.ModuleList(self.layers)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.call(x)
-
 
 class EmbeddingNet(EmbeddingNetDP, torch.nn.Module):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
```
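
The checklist above mentions tests that exercise the `module(...)` invocation path and verify that explicit `forward()` definitions are respected. As a hedged illustration of the contract being tested, here is a small check reusing the toy `torch_module` sketch from the common.py section above (`Negate` is hypothetical, not a deepmd-kit class):

```python
import torch

# Assumes the toy torch_module sketch from the common.py section is in scope.


@torch_module
class Negate:  # hypothetical class for illustration
    def call(self, x: torch.Tensor) -> torch.Tensor:
        return -x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Explicit definition: the decorator must leave this untouched.
        return self.call(x) + 1


m = Negate()
# module(...) dispatches through torch.nn.Module.__call__, which routes to the
# explicit forward() here, not to an auto-generated call() redirect: -1 + 1 == 0.
assert torch.equal(m(torch.ones(2)), torch.zeros(2))
```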
