Skip to content

Commit 6703f13

Browse files
Copilotnjzjz
andcommitted
docs(pt_expt): improve torch_module docstring and update tests
- Enhanced torch_module decorator docstring to document auto-generation behavior - Updated tests to use module(...) instead of module.forward(...) to test the full invocation path through torch.nn.Module.__call__ - This ensures tests cover hooks, tracing, and export behavior properly Addresses review feedback from PR #5246 Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent 2153852 commit 6703f13

2 files changed

Lines changed: 18 additions & 4 deletions

File tree

deepmd/pt_expt/common.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,20 @@ def torch_module(
300300
) -> type[torch.nn.Module]:
301301
"""Convert a NativeOP to a torch.nn.Module.
302302
303+
This decorator wraps a NativeOP class to make it a PyTorch module, handling
304+
initialization, attribute setting, and method delegation automatically.
305+
306+
**Auto-generated methods:**
307+
308+
- If the wrapped class has a ``call()`` method but does not explicitly define
309+
``forward()``, a ``forward()`` method will be auto-generated that delegates
310+
to ``call()``.
311+
- If the wrapped class has a ``call_lower()`` method but does not explicitly
312+
define ``forward_lower()``, a ``forward_lower()`` method will be auto-generated
313+
that delegates to ``call_lower()``.
314+
- Explicit ``forward()`` or ``forward_lower()`` definitions in the wrapped class
315+
are always respected and will not be overridden.
316+
303317
Parameters
304318
----------
305319
module : type[NativeOP]
@@ -308,13 +322,13 @@ def torch_module(
308322
Returns
309323
-------
310324
type[torch.nn.Module]
311-
The torch.nn.Module.
325+
The torch.nn.Module with auto-generated delegation methods if applicable.
312326
313327
Examples
314328
--------
315329
>>> @torch_module
316330
... class MyModule(NativeOP):
317-
... pass
331+
... pass # forward() auto-generated from call() if it exists
318332
"""
319333

320334
@wraps(module, updated=())

source/tests/pt_expt/utils/test_common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class MockModule(MockNativeOP):
4040

4141
module = MockModule()
4242
input_tensor = torch.tensor([1.0, 2.0, 3.0], device=torch.device("cpu"))
43-
output = module.forward(input_tensor)
43+
output = module(input_tensor)
4444
expected = input_tensor * 2
4545
assert torch.allclose(output, expected)
4646

@@ -82,7 +82,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
8282

8383
module = MockModule()
8484
input_tensor = torch.tensor([1.0, 2.0, 3.0], device=torch.device("cpu"))
85-
output = module.forward(input_tensor)
85+
output = module(input_tensor)
8686
expected = input_tensor * 3 # Should use the explicit forward, not call()
8787
assert torch.allclose(output, expected)
8888

0 commit comments

Comments
 (0)