Commit 82798d8
fix(mtp): resolve parameter scoping in Multi-Token Prediction block
The MTP transformer layer parameters were incorrectly attaching to the top-level Linen scope (`mtp_block`) instead of the nested NNX scope (`mtp_layer_1`). This occurred because the parent model passed `ToLinen` wrappers into the MTP block, causing the Flax NNX tracer to ignore the nested layers and default to Linen's automatic scoping.
Changes:
* src/maxtext/layers/nnx_wrappers.py: Exposed the underlying native NNX class in the `to_linen_class` factory via a new `module_class` attribute.
* src/maxtext/models/models.py: Dynamically unwrapped the transformer layer using `getattr(mtp_layer_linen, "module_class")` before injecting it into the MTP block to preserve NNX tracing.
* src/maxtext/layers/multi_token_prediction.py: Removed the legacy Linen `name=` argument from the transformer layer instantiation and added required native NNX arguments (`quant=None`, `layer_idx=-1`).

1 parent 84bf0df commit 82798d8
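The unwrap pattern described in the changes can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the MaxText or Flax implementation: `DecoderLayer`, `ToLinenWrapper`, and `unwrap` are hypothetical stand-ins, and the `to_linen_class` shown here only mimics the one behavior the commit message mentions (exposing the wrapped class as `module_class`). The commit message shows `getattr(mtp_layer_linen, "module_class")` with no default; the fallback default below is an addition for illustration.

```python
# Illustrative stand-ins only: these are NOT the real Flax or MaxText classes.


class DecoderLayer:
  """Stand-in for a native NNX-style transformer layer (hypothetical)."""

  def __init__(self, quant=None, layer_idx=-1):
    self.quant = quant
    self.layer_idx = layer_idx


def to_linen_class(module_class):
  """Hypothetical factory that builds a Linen-style wrapper class."""

  class ToLinenWrapper:
    """Stand-in wrapper; accepts the Linen-style `name=` argument."""

    def __init__(self, name=None, **kwargs):
      self.name = name
      self.kwargs = kwargs

  # Expose the wrapped native class so callers can recover it later.
  ToLinenWrapper.module_class = module_class
  return ToLinenWrapper


def unwrap(layer_cls):
  """Returns the native class if `layer_cls` is a wrapper, else `layer_cls`."""
  return getattr(layer_cls, "module_class", layer_cls)


# The parent model holds the wrapper and unwraps it before handing it on.
mtp_layer_linen = to_linen_class(DecoderLayer)
mtp_layer_nnx = unwrap(mtp_layer_linen)

# The native class is built without `name=`, using the NNX-style arguments.
layer = mtp_layer_nnx(quant=None, layer_idx=-1)
```

In this sketch the consumer never sees the wrapper, so it instantiates the native class directly, which mirrors the intent of passing the unwrapped layer into the MTP block.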
3 files changed
Lines changed: 7 additions & 5 deletions