Skip to content

Commit 51985af

Browse files
committed
fix(jax): improve JAX modules' names
Use `wraps` to keep the modules' names, so they won't be `FlaxModule`, which cannot be regonized. I realized it when implementing #5213.
1 parent 97d8ded commit 51985af

1 file changed

Lines changed: 8 additions & 4 deletions

File tree

deepmd/jax/common.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from functools import (
3+
wraps,
4+
)
25
from typing import (
36
Any,
47
overload,
@@ -42,18 +45,18 @@ def to_jax_array(array: np.ndarray | None) -> jnp.ndarray | None:
4245

4346

4447
def flax_module(
45-
module: NativeOP,
46-
) -> nnx.Module:
48+
module: type[NativeOP],
49+
) -> type[nnx.Module]:
4750
"""Convert a NativeOP to a Flax module.
4851
4952
Parameters
5053
----------
51-
module : NativeOP
54+
module : type[NativeOP]
5255
The NativeOP to convert.
5356
5457
Returns
5558
-------
56-
flax.nnx.Module
59+
type[flax.nnx.Module]
5760
The Flax module.
5861
5962
Examples
@@ -72,6 +75,7 @@ class MixedMetaClass(*metas):
7275
def __call__(self, *args: Any, **kwargs: Any) -> Any:
7376
return type(nnx.Module).__call__(self, *args, **kwargs)
7477

78+
@wraps(module, updated=())
7579
class FlaxModule(module, nnx.Module, metaclass=MixedMetaClass):
7680
def __init_subclass__(cls, **kwargs: Any) -> None:
7781
return super().__init_subclass__(**kwargs)

0 commit comments

Comments
 (0)