diff --git a/src/maxtext/trainers/post_train/distillation/distillation_utils.py b/src/maxtext/trainers/post_train/distillation/distillation_utils.py index 36504fad74..189207835c 100644 --- a/src/maxtext/trainers/post_train/distillation/distillation_utils.py +++ b/src/maxtext/trainers/post_train/distillation/distillation_utils.py @@ -658,14 +658,17 @@ def __init__( super().__init__(root_directory=root_directory, options=options) self.student_config = student_config self._iterator = raw_iterator + self._checkpoint_manager: checkpoint.CheckpointManager | None = None # Re-initialize internal Orbax manager with MaxText's Grain handler # pylint: disable=access-member-before-definition - if self._checkpoint_manager is not None: - root_directory = self._checkpoint_manager.directory + if self._checkpointer is not None: + root_directory = self._checkpointer.directory if options is None: - options = getattr(self._checkpoint_manager, "options", None) + options = getattr(self._checkpointer, "options", None) or getattr( + self._checkpointer._manager, "options", None + ) item_handlers = { "model_params": checkpoint.PyTreeCheckpointHandler(), @@ -675,12 +678,13 @@ def __init__( "iter": GrainCheckpointHandler(), } - self._checkpoint_manager.close() - self._checkpoint_manager = checkpoint.CheckpointManager( + self._checkpointer._manager.close() + self._checkpointer._manager = checkpoint.CheckpointManager( root_directory, item_handlers=item_handlers, options=options, ) + self._checkpoint_manager = self._checkpointer._manager # pylint: enable=access-member-before-definition def save(