Commit 02bd1fc
fix(jax): improve JAX modules' names (#5214)
Use `wraps` to keep the modules' names, so they won't be `FlaxModule`,
which cannot be regonized. I realized it when implementing #5213.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **Refactor**
* Improved type annotations and type-safety for module wrapping to aid
correctness and tooling.
* Preserved original module metadata when creating wrapped modules,
improving introspection and debugging.
* Updated docstrings and signatures to reflect the new typing and
metadata behavior.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>1 parent 3948cc7 commit 02bd1fc
1 file changed
Lines changed: 12 additions & 7 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
2 | 5 | | |
3 | 6 | | |
| 7 | + | |
4 | 8 | | |
5 | 9 | | |
6 | 10 | | |
7 | 11 | | |
8 | 12 | | |
9 | | - | |
10 | | - | |
11 | | - | |
12 | 13 | | |
13 | 14 | | |
14 | 15 | | |
| |||
41 | 42 | | |
42 | 43 | | |
43 | 44 | | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
44 | 48 | | |
45 | | - | |
46 | | - | |
| 49 | + | |
| 50 | + | |
47 | 51 | | |
48 | 52 | | |
49 | 53 | | |
50 | 54 | | |
51 | | - | |
| 55 | + | |
52 | 56 | | |
53 | 57 | | |
54 | 58 | | |
55 | 59 | | |
56 | | - | |
| 60 | + | |
57 | 61 | | |
58 | 62 | | |
59 | 63 | | |
| |||
72 | 76 | | |
73 | 77 | | |
74 | 78 | | |
| 79 | + | |
75 | 80 | | |
76 | 81 | | |
77 | 82 | | |
| |||
0 commit comments