Commit 5f067bb
NNX: add sharding tools, Linen<->NNX checkpoint utilities, and post-training fixes
Part 1 — sharding diagnostics and Linen<->NNX checkpoint utilities:
- modify print_shardings_params to support NNX (maxtext_utils.py)
- add --pure_nnx flag to run_sharding_dump.py
- add bidirectional Linen<->NNX checkpoint conversion utility (linen_nnx_converter.py)
- add checkpoint comparison utility for Linen vs NNX validation (compare_linen_nnx_checkpoint.py)
Part 2 — post-training bug fixes:
- models.py: unpack MultimodalInput before passing to NNXDecoder (was passing the
whole object as multimodal_input= kwarg; NNXDecoder only accepts individual fields)
- optimizers.py: guard adam_pax against scalar LR from optax.inject_hyperparams
(callable() check before invoking learning_rate_fn)
- train_distill.py: fix nested NNX transform issue (nnx.value_and_grad inside nnx.jit
raises conflicting outer_index error); refactored to jax.value_and_grad + explicit
nnx.split/merge pattern; teacher inference moved outside value_and_grad
1 parent 29bbe47 commit 5f067bb
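The optimizers.py guard described above (a callable() check before invoking learning_rate_fn) can be sketched roughly as follows. This is a minimal illustration, not the actual adam_pax code: the helper name `resolve_learning_rate` and its signature are hypothetical. The only thing it assumes from the commit message is that the learning rate may arrive either as a schedule function or as a plain scalar.

```python
def resolve_learning_rate(learning_rate, step):
    """Return the learning rate for `step`.

    Hypothetical helper (not taken from optimizers.py): `learning_rate` may be
    a schedule, i.e. a callable mapping step -> value, or an already-evaluated
    scalar such as the one a hyperparameter-injecting wrapper passes in.
    """
    if callable(learning_rate):
        # Schedule case: evaluate it at the current step.
        return learning_rate(step)
    # Scalar case: calling it would raise TypeError, so use it directly.
    return learning_rate


# A linear warmup schedule and a constant scalar both go through the same path.
warmup = lambda step: 1e-3 * min(step, 100) / 100
lr_from_schedule = resolve_learning_rate(warmup, 50)
lr_from_scalar = resolve_learning_rate(3e-4, 50)
```

Without the callable() branch, the scalar case would fail with "'float' object is not callable", which is presumably the failure the commit guards against.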
16 files changed
Lines changed: 2871 additions & 92 deletions
File tree
- src
  - dependencies/extra_deps
  - maxtext
    - checkpoint_conversion
    - models
    - optimizers
    - trainers/post_train
      - distillation
      - rl
      - sft
    - utils
- tests
  - post_training/unit
  - unit
  - utils