@@ -385,6 +385,20 @@ def torch_module(
385385) -> type [torch .nn .Module ]:
386386 """Convert a NativeOP to a torch.nn.Module.
387387
388+ This decorator wraps a NativeOP class to make it a PyTorch module, handling
389+ initialization, attribute setting, and method delegation automatically.
390+
391+ **Auto-generated methods:**
392+
393+ - If the wrapped class has a ``call()`` method but does not explicitly define
394+ ``forward()``, a ``forward()`` method will be auto-generated that delegates
395+ to ``call()``.
396+ - If the wrapped class has a ``call_lower()`` method but does not explicitly
397+ define ``forward_lower()``, a ``forward_lower()`` method will be auto-generated
398+ that delegates to ``call_lower()``.
399+ - Explicit ``forward()`` or ``forward_lower()`` definitions in the wrapped class
400+ are always respected and will not be overridden.
401+
388402 Parameters
389403 ----------
390404 module : type[NativeOP]
@@ -393,13 +407,13 @@ def torch_module(
393407 Returns
394408 -------
395409 type[torch.nn.Module]
396- The torch.nn.Module.
410+ The torch.nn.Module with auto-generated delegation methods if applicable.
397411
398412 Examples
399413 --------
400414 >>> @torch_module
401415 ... class MyModule(NativeOP):
402- ... pass
416+ ... pass # forward() auto-generated from call() if it exists
403417 """
404418
405419 @wraps (module , updated = ())
@@ -426,6 +440,22 @@ def __setattr__(self, name: str, value: Any) -> None:
426440 if not handled :
427441 super ().__setattr__ (name , value )
428442
443+ # Auto-generate forward -> call redirect if not explicitly defined
444+ if hasattr (module , "call" ) and "forward" not in module .__dict__ :
445+
446+ def forward (self , * args : Any , ** kwargs : Any ) -> Any : # noqa: ANN001
447+ return self .call (* args , ** kwargs )
448+
449+ TorchModule .forward = forward
450+
451+ # Auto-generate forward_lower -> call_lower redirect if not explicitly defined
452+ if hasattr (module , "call_lower" ) and "forward_lower" not in module .__dict__ :
453+
454+ def forward_lower (self , * args : Any , ** kwargs : Any ) -> Any : # noqa: ANN001
455+ return self .call_lower (* args , ** kwargs )
456+
457+ TorchModule .forward_lower = forward_lower
458+
429459 return TorchModule
430460
431461
0 commit comments