|
20 | 20 | from absl import flags |
21 | 21 | import datetime |
22 | 22 | from etils import epath |
| 23 | +from flax import nnx |
23 | 24 | from flax.training import train_state |
24 | 25 | import jax |
25 | 26 | from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE |
@@ -521,7 +522,7 @@ def load_state_if_possible( |
521 | 522 | load_parameters_from_path: str, |
522 | 523 | load_full_state_from_path: str, |
523 | 524 | checkpoint_storage_concurrent_gb: int, |
524 | | - abstract_unboxed_pre_state: train_state.TrainState, |
| 525 | + abstract_unboxed_pre_state: train_state.TrainState | nnx.State, |
525 | 526 | enable_single_replica_ckpt_restoring: bool | None = False, |
526 | 527 | dataset_type: str | None = "tfds", |
527 | 528 | step: int = -1, # -1 means latest |
@@ -589,8 +590,13 @@ def map_to_pspec(data): |
589 | 590 | ) |
590 | 591 | ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True) |
591 | 592 |
|
592 | | - restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state) |
593 | | - checkpoint_args = ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args) |
| 593 | + # Convert nnx.State to pure dict to match how checkpoints are saved for NNX |
| 594 | + restore_target = abstract_unboxed_pre_state |
| 595 | + if isinstance(abstract_unboxed_pre_state, nnx.State): |
| 596 | + restore_target = abstract_unboxed_pre_state.to_pure_dict() |
| 597 | + |
| 598 | + restore_args = jax.tree_util.tree_map(map_to_pspec, restore_target) |
| 599 | + checkpoint_args = ocp.args.PyTreeRestore(item=restore_target, restore_args=restore_args) |
594 | 600 |
|
595 | 601 | match (checkpoint_manager, dataset_type, data_iterator): |
596 | 602 | # Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager |
@@ -625,9 +631,14 @@ def map_to_pspec(data): |
625 | 631 | return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None) |
626 | 632 |
|
627 | 633 | if load_parameters_from_path != "": |
| 634 | + if isinstance(abstract_unboxed_pre_state, nnx.State): |
| 635 | + _, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...) |
| 636 | + else: |
| 637 | + params = abstract_unboxed_pre_state.params |
| 638 | + |
628 | 639 | restored_params = load_params_from_path( |
629 | 640 | load_parameters_from_path, |
630 | | - abstract_unboxed_pre_state.params, |
| 641 | + params, |
631 | 642 | checkpoint_storage_concurrent_gb, |
632 | 643 | use_ocdbt=use_ocdbt, |
633 | 644 | use_zarr3=use_zarr3, |
@@ -719,7 +730,18 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step |
719 | 730 | # Determine the effective step for saving a checkpoint. |
720 | 731 | # If 'step' is not provided, this call is for a potential final checkpoint, |
721 | 732 | # so the last completed step recorded in the state is used instead. |
722 | | - actual_step = (int(state.step) - 1) if step is None else int(step) |
| 733 | + if step is not None: |
| 734 | + actual_step = int(step) |
| 735 | + else: |
| 736 | + if config.pure_nnx: |
| 737 | + actual_step = int(state.optimizer.step) - 1 |
| 738 | + else: |
| 739 | + # Linen TrainState has .step attribute |
| 740 | + actual_step = int(state.step) - 1 |
| 741 | + |
| 742 | + if config.pure_nnx: |
| 743 | + # Convert nnx.State to dict. |
| 744 | + state = state.to_pure_dict() |
723 | 745 |
|
724 | 746 | # Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic. |
725 | 747 | # This occurs if this function was called: |
|
0 commit comments