|
20 | 20 | from absl import flags |
21 | 21 | import datetime |
22 | 22 | from etils import epath |
| 23 | +from flax import nnx |
23 | 24 | from flax.training import train_state |
24 | 25 | import jax |
25 | 26 | from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE |
@@ -532,7 +533,7 @@ def load_state_if_possible( |
532 | 533 | load_parameters_from_path: str, |
533 | 534 | load_full_state_from_path: str, |
534 | 535 | checkpoint_storage_concurrent_gb: int, |
535 | | - abstract_unboxed_pre_state: train_state.TrainState, |
| 536 | + abstract_unboxed_pre_state: train_state.TrainState | nnx.State, |
536 | 537 | enable_single_replica_ckpt_restoring: bool | None = False, |
537 | 538 | dataset_type: str | None = "tfds", |
538 | 539 | step: int = -1, # -1 means latest |
@@ -600,8 +601,13 @@ def map_to_pspec(data): |
600 | 601 | ) |
601 | 602 | ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True) |
602 | 603 |
|
603 | | - restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state) |
604 | | - checkpoint_args = ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args) |
| 604 | + # Convert nnx.State to pure dict to match how checkpoints are saved for NNX |
| 605 | + restore_target = abstract_unboxed_pre_state |
| 606 | + if isinstance(abstract_unboxed_pre_state, nnx.State): |
| 607 | + restore_target = abstract_unboxed_pre_state.to_pure_dict() |
| 608 | + |
| 609 | + restore_args = jax.tree_util.tree_map(map_to_pspec, restore_target) |
| 610 | + checkpoint_args = ocp.args.PyTreeRestore(item=restore_target, restore_args=restore_args) |
605 | 611 |
|
606 | 612 | match (checkpoint_manager, dataset_type, data_iterator): |
607 | 613 | # Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager |
@@ -636,9 +642,14 @@ def map_to_pspec(data): |
636 | 642 | return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None) |
637 | 643 |
|
638 | 644 | if load_parameters_from_path != "": |
| 645 | + if isinstance(abstract_unboxed_pre_state, nnx.State): |
| 646 | + _, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...) |
| 647 | + else: |
| 648 | + params = abstract_unboxed_pre_state.params |
| 649 | + |
639 | 650 | restored_params = load_params_from_path( |
640 | 651 | load_parameters_from_path, |
641 | | - abstract_unboxed_pre_state.params, |
| 652 | + params, |
642 | 653 | checkpoint_storage_concurrent_gb, |
643 | 654 | use_ocdbt=use_ocdbt, |
644 | 655 | use_zarr3=use_zarr3, |
@@ -730,7 +741,18 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step |
730 | 741 | # Determine the effective step for saving a checkpoint. |
731 | 742 | # If 'step' is not provided, this call is for a potential final checkpoint |
732 | 743 | # and use the last completed step from the state. |
733 | | - actual_step = (int(state.step) - 1) if step is None else int(step) |
| 744 | + if step is not None: |
| 745 | + actual_step = int(step) |
| 746 | + else: |
| 747 | + if config.pure_nnx: |
| 748 | + actual_step = int(state.optimizer.step) - 1 |
| 749 | + else: |
| 750 | + # Linen TrainState has .step attribute |
| 751 | + actual_step = int(state.step) - 1 |
| 752 | + |
| 753 | + if config.pure_nnx: |
| 754 | + # Convert nnx.State to a pure dict so the saved checkpoint has the same structure the restore path expects for NNX. |
| 755 | + state = state.to_pure_dict() |
734 | 756 |
|
735 | 757 | # Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic. |
736 | 758 | # This occurs if this function was called: |
|
0 commit comments