Commit 28b1e4a
committed
NNX: add sharding tools, Linen<->NNX checkpoint utilities, and post-training fixes
Part 1 — sharding diagnostics and Linen<->NNX checkpoint utilities:
- modify print_shardings_params to support NNX (maxtext_utils.py)
- add --pure_nnx flag to run_sharding_dump.py
- add bidirectional Linen<->NNX checkpoint conversion utility (linen_nnx_converter.py)
- add checkpoint comparison utility for Linen vs NNX validation (compare_linen_nnx_checkpoint.py)
Part 2 — post-training bug fixes:
- models.py: unpack MultimodalInput before passing to NNXDecoder (was passing the
whole object as multimodal_input= kwarg; NNXDecoder only accepts individual fields)
- optimizers.py: guard adam_pax against scalar LR from optax.inject_hyperparams
(callable() check before invoking learning_rate_fn)
- train_distill.py: fix nested NNX transform issue (nnx.value_and_grad inside nnx.jit
raises conflicting outer_index error); refactored to jax.value_and_grad + explicit
nnx.split/merge pattern; teacher inference moved outside value_and_grad1 parent 2e9d0e9 commit 28b1e4a
14 files changed
Lines changed: 2863 additions & 89 deletions
File tree
- src/maxtext
- checkpoint_conversion
- models
- optimizers
- trainers/post_train
- distillation
- rl
- sft
- utils
- tests
- post_training/unit
- unit
- utils
Lines changed: 609 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 581 additions & 0 deletions
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
520 | 520 | | |
521 | 521 | | |
522 | 522 | | |
523 | | - | |
| 523 | + | |
| 524 | + | |
| 525 | + | |
| 526 | + | |
| 527 | + | |
524 | 528 | | |
525 | 529 | | |
526 | 530 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
336 | 336 | | |
337 | 337 | | |
338 | 338 | | |
339 | | - | |
| 339 | + | |
| 340 | + | |
| 341 | + | |
340 | 342 | | |
341 | 343 | | |
342 | 344 | | |
| |||
Lines changed: 45 additions & 33 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
251 | 251 | | |
252 | 252 | | |
253 | 253 | | |
254 | | - | |
| 254 | + | |
255 | 255 | | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
| 261 | + | |
256 | 262 | | |
| 263 | + | |
| 264 | + | |
257 | 265 | | |
258 | 266 | | |
259 | | - | |
260 | | - | |
261 | | - | |
262 | | - | |
263 | | - | |
264 | | - | |
265 | | - | |
266 | | - | |
267 | | - | |
268 | | - | |
269 | | - | |
270 | | - | |
271 | | - | |
272 | | - | |
| 267 | + | |
| 268 | + | |
| 269 | + | |
| 270 | + | |
| 271 | + | |
| 272 | + | |
| 273 | + | |
| 274 | + | |
| 275 | + | |
| 276 | + | |
| 277 | + | |
| 278 | + | |
| 279 | + | |
| 280 | + | |
| 281 | + | |
| 282 | + | |
| 283 | + | |
273 | 284 | | |
274 | | - | |
| 285 | + | |
| 286 | + | |
| 287 | + | |
275 | 288 | | |
| 289 | + | |
| 290 | + | |
276 | 291 | | |
277 | | - | |
| 292 | + | |
278 | 293 | | |
279 | 294 | | |
280 | 295 | | |
| |||
283 | 298 | | |
284 | 299 | | |
285 | 300 | | |
286 | | - | |
287 | 301 | | |
288 | | - | |
289 | | - | |
290 | | - | |
291 | | - | |
292 | | - | |
293 | | - | |
294 | | - | |
295 | | - | |
296 | | - | |
| 302 | + | |
| 303 | + | |
| 304 | + | |
| 305 | + | |
297 | 306 | | |
298 | | - | |
| 307 | + | |
| 308 | + | |
| 309 | + | |
| 310 | + | |
| 311 | + | |
| 312 | + | |
| 313 | + | |
299 | 314 | | |
300 | 315 | | |
301 | 316 | | |
302 | 317 | | |
303 | 318 | | |
304 | | - | |
305 | | - | |
306 | | - | |
307 | 319 | | |
308 | | - | |
309 | | - | |
| 320 | + | |
| 321 | + | |
310 | 322 | | |
311 | 323 | | |
312 | 324 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
55 | 55 | | |
56 | 56 | | |
57 | 57 | | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
58 | 94 | | |
59 | 95 | | |
60 | 96 | | |
| |||
421 | 457 | | |
422 | 458 | | |
423 | 459 | | |
| 460 | + | |
| 461 | + | |
424 | 462 | | |
425 | 463 | | |
426 | 464 | | |
| |||
568 | 606 | | |
569 | 607 | | |
570 | 608 | | |
571 | | - | |
| 609 | + | |
| 610 | + | |
| 611 | + | |
| 612 | + | |
572 | 613 | | |
573 | 614 | | |
574 | 615 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
35 | 35 | | |
36 | 36 | | |
37 | 37 | | |
38 | | - | |
| 38 | + | |
39 | 39 | | |
40 | 40 | | |
41 | 41 | | |
42 | 42 | | |
43 | 43 | | |
44 | 44 | | |
45 | 45 | | |
| 46 | + | |
46 | 47 | | |
47 | 48 | | |
48 | 49 | | |
| |||
68 | 69 | | |
69 | 70 | | |
70 | 71 | | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
| 115 | + | |
| 116 | + | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
| 123 | + | |
| 124 | + | |
| 125 | + | |
| 126 | + | |
| 127 | + | |
| 128 | + | |
| 129 | + | |
| 130 | + | |
| 131 | + | |
| 132 | + | |
| 133 | + | |
| 134 | + | |
| 135 | + | |
71 | 136 | | |
72 | 137 | | |
73 | 138 | | |
| |||
109 | 174 | | |
110 | 175 | | |
111 | 176 | | |
| 177 | + | |
112 | 178 | | |
113 | 179 | | |
114 | 180 | | |
| |||
162 | 228 | | |
163 | 229 | | |
164 | 230 | | |
165 | | - | |
| 231 | + | |
166 | 232 | | |
167 | 233 | | |
168 | 234 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1852 | 1852 | | |
1853 | 1853 | | |
1854 | 1854 | | |
1855 | | - | |
1856 | | - | |
1857 | | - | |
1858 | | - | |
1859 | | - | |
1860 | | - | |
| 1855 | + | |
| 1856 | + | |
| 1857 | + | |
| 1858 | + | |
| 1859 | + | |
| 1860 | + | |
| 1861 | + | |
1861 | 1862 | | |
1862 | 1863 | | |
1863 | 1864 | | |
1864 | | - | |
1865 | 1865 | | |
1866 | | - | |
1867 | | - | |
1868 | | - | |
1869 | | - | |
1870 | | - | |
1871 | | - | |
1872 | | - | |
1873 | | - | |
1874 | | - | |
| 1866 | + | |
| 1867 | + | |
| 1868 | + | |
| 1869 | + | |
| 1870 | + | |
| 1871 | + | |
| 1872 | + | |
| 1873 | + | |
| 1874 | + | |
| 1875 | + | |
| 1876 | + | |
| 1877 | + | |
| 1878 | + | |
| 1879 | + | |
| 1880 | + | |
| 1881 | + | |
| 1882 | + | |
| 1883 | + | |
| 1884 | + | |
| 1885 | + | |
| 1886 | + | |
| 1887 | + | |
| 1888 | + | |
| 1889 | + | |
1875 | 1890 | | |
1876 | 1891 | | |
1877 | 1892 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
546 | 546 | | |
547 | 547 | | |
548 | 548 | | |
| 549 | + | |
| 550 | + | |
| 551 | + | |
| 552 | + | |
| 553 | + | |
| 554 | + | |
| 555 | + | |
549 | 556 | | |
550 | 557 | | |
551 | 558 | | |
| |||
0 commit comments