|
20 | 20 | from absl import flags |
21 | 21 | import datetime |
22 | 22 | from etils import epath |
| 23 | +from flax import nnx |
23 | 24 | from flax.training import train_state |
24 | 25 | import jax |
25 | 26 | from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE |
@@ -536,7 +537,7 @@ def load_state_if_possible( |
536 | 537 | load_parameters_from_path: str, |
537 | 538 | load_full_state_from_path: str, |
538 | 539 | checkpoint_storage_concurrent_gb: int, |
539 | | - abstract_unboxed_pre_state: train_state.TrainState, |
| 540 | + abstract_unboxed_pre_state: train_state.TrainState | nnx.State, |
540 | 541 | enable_single_replica_ckpt_restoring: bool | None = False, |
541 | 542 | dataset_type: str | None = "tfds", |
542 | 543 | step: int = -1, # -1 means latest |
@@ -604,8 +605,13 @@ def map_to_pspec(data): |
604 | 605 | ) |
605 | 606 | ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True) |
606 | 607 |
|
607 | | - restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state) |
608 | | - checkpoint_args = ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args) |
| 608 | + # Convert nnx.State to pure dict to match how checkpoints are saved for NNX |
| 609 | + restore_target = abstract_unboxed_pre_state |
| 610 | + if isinstance(abstract_unboxed_pre_state, nnx.State): |
| 611 | + restore_target = abstract_unboxed_pre_state.to_pure_dict() |
| 612 | + |
| 613 | + restore_args = jax.tree_util.tree_map(map_to_pspec, restore_target) |
| 614 | + checkpoint_args = ocp.args.PyTreeRestore(item=restore_target, restore_args=restore_args) |
609 | 615 |
|
610 | 616 | match (checkpoint_manager, dataset_type, data_iterator): |
611 | 617 | # Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager |
@@ -640,9 +646,14 @@ def map_to_pspec(data): |
640 | 646 | return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None) |
641 | 647 |
|
642 | 648 | if load_parameters_from_path != "": |
| 649 | + if isinstance(abstract_unboxed_pre_state, nnx.State): |
| 650 | + _, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...) |
| 651 | + else: |
| 652 | + params = abstract_unboxed_pre_state.params |
| 653 | + |
643 | 654 | restored_params = load_params_from_path( |
644 | 655 | load_parameters_from_path, |
645 | | - abstract_unboxed_pre_state.params, |
| 656 | + params, |
646 | 657 | checkpoint_storage_concurrent_gb, |
647 | 658 | use_ocdbt=use_ocdbt, |
648 | 659 | use_zarr3=use_zarr3, |
@@ -734,7 +745,18 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step |
734 | 745 | # Determine the effective step for saving a checkpoint. |
735 | 746 | # If 'step' is not provided, this call is for a potential final checkpoint |
736 | 747 | # and use the last completed step from the state. |
737 | | - actual_step = (int(state.step) - 1) if step is None else int(step) |
| 748 | + if step is not None: |
| 749 | + actual_step = int(step) |
| 750 | + else: |
| 751 | + if config.pure_nnx: |
| 752 | + actual_step = int(state.optimizer.step) - 1 |
| 753 | + else: |
| 754 | + # Linen TrainState has .step attribute |
| 755 | + actual_step = int(state.step) - 1 |
| 756 | + |
| 757 | + if config.pure_nnx: |
| 758 | + # Convert nnx.State to dict. |
| 759 | + state = state.to_pure_dict() |
738 | 760 |
|
739 | 761 | # Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic. |
740 | 762 | # This occurs if this function was called: |
|
0 commit comments