We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 51985af commit 551721aCopy full SHA for 551721a
1 file changed
deepmd/jax/common.py
@@ -44,9 +44,12 @@ def to_jax_array(array: np.ndarray | None) -> jnp.ndarray | None:
44
return jnp.array(array)
45
46
47
+T = TypeVar('T')
48
+
49
50
def flax_module(
- module: type[NativeOP],
-) -> type[nnx.Module]:
51
+ module: type[T],
52
+) -> type[T]: # runtime: actually returns type[T & nnx.Module]
53
"""Convert a NativeOP to a Flax module.
54
55
Parameters
0 commit comments