File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 11# SPDX-License-Identifier: LGPL-3.0-or-later
2+ from functools import (
3+ wraps ,
4+ )
25from typing import (
36 Any ,
47 overload ,
@@ -42,18 +45,18 @@ def to_jax_array(array: np.ndarray | None) -> jnp.ndarray | None:
4245
4346
4447def flax_module (
45- module : NativeOP ,
46- ) -> nnx .Module :
48+ module : type [ NativeOP ] ,
49+ ) -> type [ nnx .Module ] :
4750 """Convert a NativeOP to a Flax module.
4851
4952 Parameters
5053 ----------
51- module : NativeOP
54+ module : type[ NativeOP]
5255 The NativeOP to convert.
5356
5457 Returns
5558 -------
56- flax.nnx.Module
59+ type[ flax.nnx.Module]
5760 The Flax module.
5861
5962 Examples
@@ -72,6 +75,7 @@ class MixedMetaClass(*metas):
7275 def __call__ (self , * args : Any , ** kwargs : Any ) -> Any :
7376 return type (nnx .Module ).__call__ (self , * args , ** kwargs )
7477
78+ @wraps (module , updated = ())
7579 class FlaxModule (module , nnx .Module , metaclass = MixedMetaClass ):
7680 def __init_subclass__ (cls , ** kwargs : Any ) -> None :
7781 return super ().__init_subclass__ (** kwargs )
You can’t perform that action at this time.
0 commit comments