Commit 4ddc37d
feat(pt_expt): auto-generate forward/forward_lower in torch_module decorator (#5246)
- [x] Understand the current implementation of `torch_module` decorator
- [x] Identify all classes using `@torch_module` with manual `forward()`
methods
- [x] Modify `torch_module` decorator to auto-generate `forward()` from
`call()`
- [x] Modify `torch_module` decorator to auto-generate `forward_lower()`
from `call_lower()`
- [x] Remove boilerplate `forward()` methods from descriptor classes
(DescrptSeA, DescrptSeR, DescrptSeT, DescrptSeTTebd,
DescrptBlockSeTTebd)
- [x] Remove boilerplate `forward()` methods from fitting classes
(InvarFitting, EnergyFittingNet)
- [x] Remove boilerplate `forward()` method from NativeNet
- [x] Run linting and formatting
- [x] Create tests to verify auto-generated methods work correctly
- [x] Run targeted tests to verify changes (63 pt_expt tests pass)
- [x] Manually verify forward methods work with real data
- [x] Request code review and address feedback
- [x] Update docstring to document auto-generation behavior
- [x] Update tests to use module(...) invocation path
- [x] Ready for final approval
<!-- START COPILOT ORIGINAL PROMPT -->
<details>
<summary>Original prompt</summary>
>
> ----
>
> *This section details on the original issue you should resolve*
>
> <issue_title>pt_expt: auto-generate forward/forward_lower in
torch_module decorator to avoid boilerplate</issue_title>
> <issue_description>## Summary
>
> The `torch_module` decorator in `deepmd/pt_expt/common.py` currently
handles `__init__`, `__call__`, and `__setattr__` automatically, but
each wrapped class still needs to manually define `forward()` (and
potentially `forward_lower()`) methods that simply redirect to `call()`
(and `call_lower()`).
>
> ## Current situation
>
> Every descriptor class in `deepmd/pt_expt/descriptor/` repeats the
same boilerplate pattern:
>
> ```python
> @torch_module
> class DescrptSeA(DescrptSeADP):
> def forward(
> self,
> extended_coord: torch.Tensor,
> extended_atype: torch.Tensor,
> nlist: torch.Tensor,
> mapping: torch.Tensor | None = None,
> ) -> tuple[...]:
> descrpt, rot_mat, g2, h2, sw = self.call(
> extended_coord,
> extended_atype,
> nlist,
> mapping=mapping,
> )
> return descrpt, rot_mat, g2, h2, sw
> ```
>
> This identical `forward → call` redirect is duplicated across
`DescrptSeA`, `DescrptSeR`, `DescrptSeT`, `DescrptSeTTebd`,
`DescrptBlockSeTTebd`, and similarly for fitting classes like
`InvarFitting` (though `InvarFitting` does not use `torch_module`).
>
> ## Proposal
>
> Modify the `torch_module` decorator to **automatically generate
`forward` and `forward_lower`** methods that delegate to `call` and
`call_lower` respectively, if:
>
> 1. The wrapped class has a `call` method (from the dpmodel base) but
does **not** define its own `forward`.
> 2. Similarly for `call_lower` → `forward_lower`.
>
> This could be implemented by adding something like the following
inside `torch_module`:
>
> ```python
> def torch_module(module: type[NativeOP]) -> type[torch.nn.Module]:
> @wraps(module, updated=())
> class TorchModule(module, torch.nn.Module):
> def __init__(self, *args, **kwargs):
> torch.nn.Module.__init__(self)
> module.__init__(self, *args, **kwargs)
>
> def __call__(self, *args, **kwargs):
> return torch.nn.Module.__call__(self, *args, **kwargs)
>
> def __setattr__(self, name, value):
> handled, value = dpmodel_setattr(self, name, value)
> if not handled:
> super().__setattr__(name, value)
>
> # Auto-generate forward -> call redirect
> if hasattr(module, "call") and "forward" not in module.__dict__:
> def forward(self, *args, **kwargs):
> return self.call(*args, **kwargs)
> TorchModule.forward = forward
>
> # Auto-generate forward_lower -> call_lower redirect
> if hasattr(module, "call_lower") and "forward_lower" not in
module.__dict__:
> def forward_lower(self, *args, **kwargs):
> return self.call_lower(*args, **kwargs)
> TorchModule.forward_lower = forward_lower
>
> return TorchModule
> ```
>
> ## Benefits
>
> - **Eliminates boilerplate**: Descriptor/fitting wrapper classes
become minimal (just the `@torch_module` decorator +
`@BaseDescriptor.register` + class declaration).
> - **Reduces maintenance burden**: Adding new descriptors to `pt_expt`
requires fewer lines and less copy-paste.
> - **Consistent behavior**: The `forward ↔ call` contract is enforced
in one place rather than scattered across many files.
> - **Less error-prone**: No risk of forgetting to update `forward` when
`call` signature changes.
>
> ## Example after this change
>
> ```python
> @BaseDescriptor.register("se_e2_a")
> @BaseDescriptor.register("se_a")
> @torch_module
> class DescrptSeA(DescrptSeADP):
> pass # forward() auto-generated from call()
> ```
>
> ---
> *Submitted by OpenClaw on behalf of @njzjz*</issue_description>
>
> ## Comments on the Issue (you are @copilot in this section)
>
> <comments>
> </comments>
>
</details>
<!-- START COPILOT CODING AGENT SUFFIX -->
- Fixes #5245
<!-- START COPILOT CODING AGENT TIPS -->
---
💡 You can make Copilot smarter by setting up custom instructions,
customizing its development environment and configuring Model Context
Protocol (MCP) servers. Learn more [Copilot coding agent
tips](https://gh.io/copilot-coding-agent-tips) in the docs.
---------
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>1 parent 09e40bb commit 4ddc37d
10 files changed
Lines changed: 145 additions & 160 deletions
File tree
- deepmd/pt_expt
- descriptor
- fitting
- utils
- source/tests/pt_expt/utils
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
300 | 300 | | |
301 | 301 | | |
302 | 302 | | |
| 303 | + | |
| 304 | + | |
| 305 | + | |
| 306 | + | |
| 307 | + | |
| 308 | + | |
| 309 | + | |
| 310 | + | |
| 311 | + | |
| 312 | + | |
| 313 | + | |
| 314 | + | |
| 315 | + | |
| 316 | + | |
303 | 317 | | |
304 | 318 | | |
305 | 319 | | |
| |||
308 | 322 | | |
309 | 323 | | |
310 | 324 | | |
311 | | - | |
| 325 | + | |
312 | 326 | | |
313 | 327 | | |
314 | 328 | | |
315 | 329 | | |
316 | 330 | | |
317 | | - | |
| 331 | + | |
318 | 332 | | |
319 | 333 | | |
320 | 334 | | |
| |||
332 | 346 | | |
333 | 347 | | |
334 | 348 | | |
| 349 | + | |
| 350 | + | |
| 351 | + | |
| 352 | + | |
| 353 | + | |
| 354 | + | |
| 355 | + | |
| 356 | + | |
| 357 | + | |
| 358 | + | |
| 359 | + | |
| 360 | + | |
| 361 | + | |
| 362 | + | |
| 363 | + | |
| 364 | + | |
335 | 365 | | |
336 | 366 | | |
337 | 367 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
2 | 2 | | |
3 | | - | |
4 | | - | |
5 | 3 | | |
6 | 4 | | |
7 | 5 | | |
| |||
15 | 13 | | |
16 | 14 | | |
17 | 15 | | |
18 | | - | |
19 | | - | |
20 | | - | |
21 | | - | |
22 | | - | |
23 | | - | |
24 | | - | |
25 | | - | |
26 | | - | |
27 | | - | |
28 | | - | |
29 | | - | |
30 | | - | |
31 | | - | |
32 | | - | |
33 | | - | |
34 | | - | |
35 | | - | |
36 | | - | |
37 | | - | |
| 16 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
2 | 2 | | |
3 | | - | |
4 | | - | |
5 | 3 | | |
6 | 4 | | |
7 | 5 | | |
| |||
15 | 13 | | |
16 | 14 | | |
17 | 15 | | |
18 | | - | |
19 | | - | |
20 | | - | |
21 | | - | |
22 | | - | |
23 | | - | |
24 | | - | |
25 | | - | |
26 | | - | |
27 | | - | |
28 | | - | |
29 | | - | |
30 | | - | |
31 | | - | |
32 | | - | |
33 | | - | |
34 | | - | |
35 | | - | |
36 | | - | |
37 | | - | |
| 16 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
2 | 2 | | |
3 | | - | |
4 | | - | |
5 | 3 | | |
6 | 4 | | |
7 | 5 | | |
| |||
16 | 14 | | |
17 | 15 | | |
18 | 16 | | |
19 | | - | |
20 | | - | |
21 | | - | |
22 | | - | |
23 | | - | |
24 | | - | |
25 | | - | |
26 | | - | |
27 | | - | |
28 | | - | |
29 | | - | |
30 | | - | |
31 | | - | |
32 | | - | |
33 | | - | |
34 | | - | |
35 | | - | |
36 | | - | |
37 | | - | |
38 | | - | |
| 17 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
2 | 2 | | |
3 | | - | |
4 | | - | |
5 | 3 | | |
6 | 4 | | |
7 | 5 | | |
| |||
14 | 12 | | |
15 | 13 | | |
16 | 14 | | |
17 | | - | |
18 | | - | |
19 | | - | |
20 | | - | |
21 | | - | |
22 | | - | |
23 | | - | |
24 | | - | |
25 | | - | |
26 | | - | |
27 | | - | |
28 | | - | |
29 | | - | |
30 | | - | |
31 | | - | |
32 | | - | |
33 | | - | |
34 | | - | |
35 | | - | |
36 | | - | |
| 15 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
2 | 2 | | |
3 | | - | |
4 | | - | |
5 | 3 | | |
6 | 4 | | |
7 | 5 | | |
| |||
13 | 11 | | |
14 | 12 | | |
15 | 13 | | |
16 | | - | |
17 | | - | |
18 | | - | |
19 | | - | |
20 | | - | |
21 | | - | |
22 | | - | |
23 | | - | |
24 | | - | |
25 | | - | |
26 | | - | |
27 | | - | |
28 | | - | |
29 | | - | |
30 | | - | |
31 | | - | |
32 | | - | |
33 | | - | |
34 | | - | |
35 | | - | |
36 | | - | |
37 | | - | |
38 | | - | |
| 14 | + | |
39 | 15 | | |
40 | 16 | | |
41 | 17 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
2 | 2 | | |
3 | | - | |
4 | | - | |
5 | 3 | | |
6 | 4 | | |
7 | 5 | | |
| |||
21 | 19 | | |
22 | 20 | | |
23 | 21 | | |
24 | | - | |
25 | | - | |
26 | | - | |
27 | | - | |
28 | | - | |
29 | | - | |
30 | | - | |
31 | | - | |
32 | | - | |
33 | | - | |
34 | | - | |
35 | | - | |
36 | | - | |
37 | | - | |
38 | | - | |
39 | | - | |
40 | | - | |
41 | | - | |
42 | | - | |
| 22 | + | |
43 | 23 | | |
44 | 24 | | |
45 | 25 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
2 | 2 | | |
3 | | - | |
4 | | - | |
5 | 3 | | |
6 | 4 | | |
7 | 5 | | |
| |||
15 | 13 | | |
16 | 14 | | |
17 | 15 | | |
18 | | - | |
19 | | - | |
20 | | - | |
21 | | - | |
22 | | - | |
23 | | - | |
24 | | - | |
25 | | - | |
26 | | - | |
27 | | - | |
28 | | - | |
29 | | - | |
30 | | - | |
31 | | - | |
32 | | - | |
33 | | - | |
34 | | - | |
35 | | - | |
36 | | - | |
| 16 | + | |
37 | 17 | | |
38 | 18 | | |
39 | 19 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
194 | 194 | | |
195 | 195 | | |
196 | 196 | | |
197 | | - | |
198 | | - | |
199 | | - | |
200 | 197 | | |
201 | 198 | | |
202 | 199 | | |
| |||
0 commit comments