Skip to content

Commit 559312a

Browse files
Add autocheckpoint feature to MaxText.
PiperOrigin-RevId: 892064416
1 parent 5478bad commit 559312a

4 files changed

Lines changed: 20 additions & 3 deletions

File tree

src/maxtext/common/checkpointing.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ def create_orbax_checkpoint_manager(
221221
enable_single_controller: bool = False,
222222
colocated_python_checkpointing: bool = False,
223223
enable_single_replica_ckpt_restoring: bool = False,
224+
enable_autocheckpoint: bool = False,
224225
):
225226
"""Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled."""
226227
if not enable_checkpointing:
@@ -248,11 +249,19 @@ def create_orbax_checkpoint_manager(
248249
# local storage checkpoint needs parent directory created
249250
p = gcs_utils.mkdir_and_check_permissions(checkpoint_dir)
250251
if enable_continuous_checkpointing:
252+
max_logging.log("Enabling policy for continuous checkpointing.")
251253
save_decision_policy = save_decision_policy_lib.ContinuousCheckpointingPolicy()
252-
preservation_policy = preservation_policy_lib.LatestN(max_num_checkpoints_to_keep)
254+
elif enable_autocheckpoint:
255+
max_logging.log("Enabling policy for autocheckpoint.")
256+
save_decision_policy = save_decision_policy_lib.AnySavePolicy([
257+
save_decision_policy_lib.PreemptionCheckpointingPolicy(),
258+
save_decision_policy_lib.FixedIntervalPolicy(save_interval_steps),
259+
])
253260
else:
261+
max_logging.log("Enabling policy for fixed interval checkpointing.")
254262
save_decision_policy = save_decision_policy_lib.FixedIntervalPolicy(interval=save_interval_steps)
255-
preservation_policy = preservation_policy_lib.LatestN(max_num_checkpoints_to_keep)
263+
preservation_policy = preservation_policy_lib.LatestN(max_num_checkpoints_to_keep)
264+
256265
async_options = None
257266
if enable_continuous_checkpointing:
258267
async_options = ocp.AsyncOptions(
@@ -752,6 +761,7 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=
752761
or (step % config.checkpoint_period == 0 and not config.enable_continuous_checkpointing)
753762
or (step % config.checkpoint_period == 0)
754763
or (config.enable_emergency_checkpoint and step % config.local_checkpoint_period == 0)
764+
or (config.enable_autocheckpoint and checkpoint_manager.reached_preemption(step))
755765
):
756766
blocking_until_ready_start = time.time()
757767
max_logging.log(f"Waiting for step {step} to finish before checkpoint...")

src/maxtext/configs/base.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ source_checkpoint_layout: "orbax"
8383

8484
# Only applicable to Single Controller/Pathways on Cloud. Experimental feature, under testing
8585
colocated_python_checkpointing: False
86+
87+
# Enables autocheckpoint: saves a checkpoint when a preemption is detected, in addition to the regular fixed-interval checkpoints.
88+
enable_autocheckpoint: False
8689
############################### end checkpointing ##################################
8790

8891

src/maxtext/configs/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,9 @@ class Checkpointing(BaseModel):
332332
False,
333333
description="If True, enables checkpointing from remote TPU VMs instead of head node on pathways.",
334334
)
335+
enable_autocheckpoint: bool = Field(
336+
False, description="If True, enables autocheckpoint or preemption induced checkpointing."
337+
)
335338

336339

337340
class OrbaxStorage(BaseModel):

src/maxtext/utils/train_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
# pylint: disable=bare-except, consider-using-generator
16-
""" Utils that are only interesting for training in MaxText. """
16+
"""Utils that are only interesting for training in MaxText."""
1717

1818
import os
1919
import jax
@@ -82,6 +82,7 @@ def create_training_tools(config, model, mesh):
8282
config.enable_single_controller,
8383
config.colocated_python_checkpointing,
8484
config.enable_single_replica_ckpt_restoring,
85+
config.enable_autocheckpoint,
8586
)
8687

8788
return init_rng, checkpoint_manager, learning_rate_schedule, tx

0 commit comments

Comments
 (0)