Skip to content

Commit 02bd1fc

Browse files
fix(jax): improve JAX modules' names (#5214)
Use `wraps` to keep the modules' names, so they won't be `FlaxModule`, which cannot be regonized. I realized it when implementing #5213. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Refactor** * Improved type annotations and type-safety for module wrapping to aid correctness and tooling. * Preserved original module metadata when creating wrapped modules, improving introspection and debugging. * Updated docstrings and signatures to reflect the new typing and metadata behavior. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 3948cc7 commit 02bd1fc

1 file changed

Lines changed: 12 additions & 7 deletions

File tree

deepmd/jax/common.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from functools import (
3+
wraps,
4+
)
25
from typing import (
36
Any,
7+
TypeVar,
48
overload,
59
)
610

711
import numpy as np
812

9-
from deepmd.dpmodel.common import (
10-
NativeOP,
11-
)
1213
from deepmd.jax.env import (
1314
jnp,
1415
nnx,
@@ -41,19 +42,22 @@ def to_jax_array(array: np.ndarray | None) -> jnp.ndarray | None:
4142
return jnp.array(array)
4243

4344

45+
T = TypeVar("T")
46+
47+
4448
def flax_module(
45-
module: NativeOP,
46-
) -> nnx.Module:
49+
module: type[T],
50+
) -> type[T]: # runtime: actually returns type[T & nnx.Module]
4751
"""Convert a NativeOP to a Flax module.
4852
4953
Parameters
5054
----------
51-
module : NativeOP
55+
module : type[NativeOP]
5256
The NativeOP to convert.
5357
5458
Returns
5559
-------
56-
flax.nnx.Module
60+
type[flax.nnx.Module]
5761
The Flax module.
5862
5963
Examples
@@ -72,6 +76,7 @@ class MixedMetaClass(*metas):
7276
def __call__(self, *args: Any, **kwargs: Any) -> Any:
7377
return type(nnx.Module).__call__(self, *args, **kwargs)
7478

79+
@wraps(module, updated=())
7580
class FlaxModule(module, nnx.Module, metaclass=MixedMetaClass):
7681
def __init_subclass__(cls, **kwargs: Any) -> None:
7782
return super().__init_subclass__(**kwargs)

0 commit comments

Comments
 (0)