|
20 | 20 | from absl import flags |
21 | 21 | import datetime |
22 | 22 | from etils import epath |
| 23 | +from flax import nnx |
23 | 24 | from flax.training import train_state |
24 | 25 | import jax |
25 | 26 | from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE |
@@ -536,7 +537,7 @@ def load_state_if_possible( |
536 | 537 | load_parameters_from_path: str, |
537 | 538 | load_full_state_from_path: str, |
538 | 539 | checkpoint_storage_concurrent_gb: int, |
539 | | - abstract_unboxed_pre_state: train_state.TrainState, |
| 540 | + abstract_unboxed_pre_state: train_state.TrainState | nnx.State, |
540 | 541 | enable_single_replica_ckpt_restoring: bool | None = False, |
541 | 542 | dataset_type: str | None = "tfds", |
542 | 543 | step: int = -1, # -1 means latest |
@@ -604,9 +605,14 @@ def map_to_pspec(data): |
604 | 605 | ) |
605 | 606 | ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True) |
606 | 607 |
|
607 | | - restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state) |
| 608 | + # Convert nnx.State to pure dict to match how checkpoints are saved for NNX |
| 609 | + restore_target = abstract_unboxed_pre_state |
| 610 | + if isinstance(abstract_unboxed_pre_state, nnx.State): |
| 611 | + restore_target = abstract_unboxed_pre_state.to_pure_dict() |
| 612 | + |
| 613 | + restore_args = jax.tree_util.tree_map(map_to_pspec, restore_target) |
608 | 614 | checkpoint_args = ocp.args.PyTreeRestore( |
609 | | - item=abstract_unboxed_pre_state, |
| 615 | + item=restore_target, |
610 | 616 | restore_args=restore_args, |
611 | 617 | partial_restore=True, |
612 | 618 | ) |
@@ -647,9 +653,14 @@ def map_to_pspec(data): |
647 | 653 | return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None) |
648 | 654 |
|
649 | 655 | if load_parameters_from_path != "": |
| 656 | + if isinstance(abstract_unboxed_pre_state, nnx.State): |
| 657 | + _, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...) |
| 658 | + else: |
| 659 | + params = abstract_unboxed_pre_state.params |
| 660 | + |
650 | 661 | restored_params = load_params_from_path( |
651 | 662 | load_parameters_from_path, |
652 | | - abstract_unboxed_pre_state.params, |
| 663 | + params, |
653 | 664 | checkpoint_storage_concurrent_gb, |
654 | 665 | use_ocdbt=use_ocdbt, |
655 | 666 | use_zarr3=use_zarr3, |
@@ -741,7 +752,18 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step |
741 | 752 | # Determine the effective step for saving a checkpoint. |
742 | 753 | # If 'step' is not provided, this call is for a potential final checkpoint |
743 | 754 | # and use the last completed step from the state. |
744 | | - actual_step = (int(state.step) - 1) if step is None else int(step) |
| 755 | + if step is not None: |
| 756 | + actual_step = int(step) |
| 757 | + else: |
| 758 | + if config.pure_nnx: |
| 759 | + actual_step = int(state.optimizer.step) - 1 |
| 760 | + else: |
| 761 | + # Linen TrainState has .step attribute |
| 762 | + actual_step = int(state.step) - 1 |
| 763 | + |
| 764 | + if config.pure_nnx: |
| 765 | + # Convert nnx.State to dict. |
| 766 | + state = state.to_pure_dict() |
745 | 767 |
|
746 | 768 | # Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic. |
747 | 769 | # This occurs if this function was called: |
|
0 commit comments