Commit 6b6e61b
committed
NNX: add sharding tools, Linen<->NNX checkpoint utilities, and post-training fixes
Part 1 — sharding diagnostics and Linen<->NNX checkpoint utilities:
- modify print_shardings_params to support NNX (maxtext_utils.py)
- add --pure_nnx flag to run_sharding_dump.py
- add bidirectional Linen<->NNX checkpoint conversion utility (linen_nnx_converter.py)
- add checkpoint comparison utility for Linen vs NNX validation (compare_linen_nnx_checkpoint.py)
Part 2 — post-training bug fixes:
- models.py: unpack MultimodalInput before passing to NNXDecoder (was passing the
whole object as multimodal_input= kwarg; NNXDecoder only accepts individual fields)
- optimizers.py: guard adam_pax against scalar LR from optax.inject_hyperparams
(callable() check before invoking learning_rate_fn)
- train_distill.py: fix nested NNX transform issue (nnx.value_and_grad inside nnx.jit
raises conflicting outer_index error); refactored to jax.value_and_grad + explicit
nnx.split/merge pattern; teacher inference moved outside value_and_grad1 parent fe7d3e5 commit 6b6e61b
13 files changed
Lines changed: 2823 additions & 81 deletions
File tree
- src/maxtext
- checkpoint_conversion
- models
- optimizers
- trainers/post_train
- distillation
- rl
- sft
- utils
- tests
- post_training/unit
- unit
- utils
Lines changed: 609 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 581 additions & 0 deletions
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
520 | 520 | | |
521 | 521 | | |
522 | 522 | | |
523 | | - | |
| 523 | + | |
| 524 | + | |
| 525 | + | |
| 526 | + | |
| 527 | + | |
524 | 528 | | |
525 | 529 | | |
526 | 530 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
336 | 336 | | |
337 | 337 | | |
338 | 338 | | |
339 | | - | |
| 339 | + | |
| 340 | + | |
| 341 | + | |
340 | 342 | | |
341 | 343 | | |
342 | 344 | | |
| |||
Lines changed: 46 additions & 35 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
267 | 267 | | |
268 | 268 | | |
269 | 269 | | |
270 | | - | |
| 270 | + | |
271 | 271 | | |
| 272 | + | |
| 273 | + | |
| 274 | + | |
| 275 | + | |
| 276 | + | |
| 277 | + | |
272 | 278 | | |
273 | | - | |
274 | | - | |
275 | | - | |
276 | | - | |
277 | | - | |
278 | | - | |
279 | | - | |
280 | | - | |
281 | | - | |
282 | | - | |
283 | | - | |
284 | | - | |
285 | | - | |
286 | | - | |
287 | | - | |
288 | | - | |
| 279 | + | |
| 280 | + | |
| 281 | + | |
| 282 | + | |
| 283 | + | |
| 284 | + | |
| 285 | + | |
| 286 | + | |
| 287 | + | |
| 288 | + | |
| 289 | + | |
| 290 | + | |
| 291 | + | |
| 292 | + | |
| 293 | + | |
| 294 | + | |
| 295 | + | |
| 296 | + | |
| 297 | + | |
| 298 | + | |
289 | 299 | | |
290 | | - | |
| 300 | + | |
| 301 | + | |
| 302 | + | |
291 | 303 | | |
| 304 | + | |
| 305 | + | |
292 | 306 | | |
293 | | - | |
| 307 | + | |
294 | 308 | | |
295 | 309 | | |
296 | 310 | | |
| |||
299 | 313 | | |
300 | 314 | | |
301 | 315 | | |
302 | | - | |
303 | 316 | | |
304 | | - | |
305 | | - | |
306 | | - | |
307 | | - | |
308 | | - | |
309 | | - | |
310 | | - | |
311 | | - | |
312 | | - | |
| 317 | + | |
| 318 | + | |
| 319 | + | |
| 320 | + | |
313 | 321 | | |
314 | | - | |
| 322 | + | |
| 323 | + | |
| 324 | + | |
| 325 | + | |
| 326 | + | |
| 327 | + | |
| 328 | + | |
315 | 329 | | |
316 | 330 | | |
317 | 331 | | |
318 | 332 | | |
319 | 333 | | |
320 | | - | |
321 | | - | |
322 | | - | |
323 | 334 | | |
324 | | - | |
325 | | - | |
| 335 | + | |
| 336 | + | |
326 | 337 | | |
327 | 338 | | |
328 | 339 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
56 | 56 | | |
57 | 57 | | |
58 | 58 | | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
59 | 95 | | |
60 | 96 | | |
61 | 97 | | |
| |||
543 | 579 | | |
544 | 580 | | |
545 | 581 | | |
| 582 | + | |
| 583 | + | |
546 | 584 | | |
547 | 585 | | |
548 | 586 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
35 | 35 | | |
36 | 36 | | |
37 | 37 | | |
38 | | - | |
| 38 | + | |
39 | 39 | | |
40 | 40 | | |
41 | 41 | | |
42 | 42 | | |
43 | 43 | | |
44 | 44 | | |
45 | 45 | | |
| 46 | + | |
46 | 47 | | |
47 | 48 | | |
48 | 49 | | |
| |||
68 | 69 | | |
69 | 70 | | |
70 | 71 | | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
| 115 | + | |
| 116 | + | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
| 123 | + | |
| 124 | + | |
| 125 | + | |
| 126 | + | |
| 127 | + | |
| 128 | + | |
| 129 | + | |
| 130 | + | |
| 131 | + | |
| 132 | + | |
| 133 | + | |
| 134 | + | |
71 | 135 | | |
72 | 136 | | |
73 | 137 | | |
| |||
161 | 225 | | |
162 | 226 | | |
163 | 227 | | |
164 | | - | |
| 228 | + | |
165 | 229 | | |
166 | 230 | | |
167 | 231 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1800 | 1800 | | |
1801 | 1801 | | |
1802 | 1802 | | |
1803 | | - | |
1804 | | - | |
1805 | | - | |
1806 | | - | |
1807 | | - | |
1808 | | - | |
| 1803 | + | |
| 1804 | + | |
| 1805 | + | |
| 1806 | + | |
| 1807 | + | |
| 1808 | + | |
| 1809 | + | |
1809 | 1810 | | |
1810 | 1811 | | |
1811 | 1812 | | |
1812 | | - | |
1813 | 1813 | | |
1814 | | - | |
1815 | | - | |
1816 | | - | |
1817 | | - | |
1818 | | - | |
1819 | | - | |
1820 | | - | |
1821 | | - | |
1822 | | - | |
| 1814 | + | |
| 1815 | + | |
| 1816 | + | |
| 1817 | + | |
| 1818 | + | |
| 1819 | + | |
| 1820 | + | |
| 1821 | + | |
| 1822 | + | |
| 1823 | + | |
| 1824 | + | |
| 1825 | + | |
| 1826 | + | |
| 1827 | + | |
| 1828 | + | |
| 1829 | + | |
| 1830 | + | |
| 1831 | + | |
| 1832 | + | |
| 1833 | + | |
| 1834 | + | |
| 1835 | + | |
| 1836 | + | |
| 1837 | + | |
1823 | 1838 | | |
1824 | 1839 | | |
1825 | 1840 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
372 | 372 | | |
373 | 373 | | |
374 | 374 | | |
| 375 | + | |
| 376 | + | |
| 377 | + | |
| 378 | + | |
| 379 | + | |
| 380 | + | |
| 381 | + | |
375 | 382 | | |
376 | 383 | | |
377 | 384 | | |
| |||
0 commit comments