
Commit 5f067bb

NNX: add sharding tools, Linen<->NNX checkpoint utilities, and post-training fixes
Part 1 (sharding diagnostics and Linen<->NNX checkpoint utilities):
- modify print_shardings_params to support NNX (maxtext_utils.py)
- add --pure_nnx flag to run_sharding_dump.py
- add bidirectional Linen<->NNX checkpoint conversion utility (linen_nnx_converter.py)
- add checkpoint comparison utility for Linen vs NNX validation (compare_linen_nnx_checkpoint.py)

Part 2 (post-training bug fixes):
- models.py: unpack MultimodalInput before passing to NNXDecoder (was passing the whole object as multimodal_input= kwarg; NNXDecoder only accepts individual fields)
- optimizers.py: guard adam_pax against scalar LR from optax.inject_hyperparams (callable() check before invoking learning_rate_fn)
- train_distill.py: fix nested NNX transform issue (nnx.value_and_grad inside nnx.jit raises conflicting outer_index error); refactored to jax.value_and_grad + explicit nnx.split/merge pattern; teacher inference moved outside value_and_grad
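
The optimizers.py fix can be illustrated in isolation. The sketch below is not the MaxText code: `resolve_learning_rate` and `linear_decay` are hypothetical names used only for illustration. It assumes the situation described in the commit message, where the optimizer may receive either a schedule function or an already-evaluated scalar (as happens when a wrapper such as optax.inject_hyperparams evaluates the schedule first), and shows the callable() check that avoids calling a number.

```python
from typing import Callable, Union

# A learning rate is either a plain number or a schedule mapping step -> rate.
Schedule = Union[float, Callable[[int], float]]


def resolve_learning_rate(learning_rate: Schedule, step: int) -> float:
    """Return the rate for `step`, accepting a schedule or a scalar.

    Hypothetical helper: only invoke the argument when it is callable,
    otherwise treat it as an already-evaluated scalar.
    """
    if callable(learning_rate):
        return float(learning_rate(step))
    return float(learning_rate)


def linear_decay(step: int) -> float:
    """Illustrative schedule: decay linearly from 0.1 to 0 over 100 steps."""
    return 0.1 * max(0.0, 1.0 - step / 100)


if __name__ == "__main__":
    # Schedule case: the function is called with the current step.
    print(resolve_learning_rate(linear_decay, 50))
    # Scalar case: the value is used directly instead of being called.
    print(resolve_learning_rate(3e-4, 50))
```

Without the check, the scalar case would fail with a TypeError because a float is not callable.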
1 parent 29bbe47 commit 5f067bb

16 files changed

Lines changed: 2871 additions & 92 deletions


Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-google-tunix @ https://github.com/google/tunix/archive/336d102fe32ca0edbe42a8f66ff0fd533cebdf52.zip
+google-tunix @ https://github.com/google/tunix/archive/110932a8395086511228483312131841521695c1.zip
