Skip to content

Commit ff5ee4e

Browse files
Add autocheckpoint feature to MaxText.
PiperOrigin-RevId: 892064416
1 parent 4910293 commit ff5ee4e

5 files changed

Lines changed: 112 additions & 3 deletions

File tree

src/maxtext/common/checkpointing.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ def create_orbax_checkpoint_manager(
221221
enable_single_controller: bool = False,
222222
colocated_python_checkpointing: bool = False,
223223
enable_single_replica_ckpt_restoring: bool = False,
224+
enable_autocheckpoint: bool = False,
224225
):
225226
"""Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled."""
226227
if not enable_checkpointing:
@@ -248,11 +249,21 @@ def create_orbax_checkpoint_manager(
248249
# local storage checkpoint needs parent directory created
249250
p = gcs_utils.mkdir_and_check_permissions(checkpoint_dir)
250251
if enable_continuous_checkpointing:
252+
max_logging.log("Enabling policy for continuous checkpointing.")
251253
save_decision_policy = save_decision_policy_lib.ContinuousCheckpointingPolicy()
252-
preservation_policy = preservation_policy_lib.LatestN(max_num_checkpoints_to_keep)
254+
elif enable_autocheckpoint:
255+
max_logging.log("Enabling policy for autocheckpoint.")
256+
save_decision_policy = save_decision_policy_lib.AnySavePolicy(
257+
[
258+
save_decision_policy_lib.PreemptionCheckpointingPolicy(),
259+
save_decision_policy_lib.FixedIntervalPolicy(save_interval_steps),
260+
]
261+
)
253262
else:
263+
max_logging.log("Enabling policy for fixed interval checkpointing.")
254264
save_decision_policy = save_decision_policy_lib.FixedIntervalPolicy(interval=save_interval_steps)
255-
preservation_policy = preservation_policy_lib.LatestN(max_num_checkpoints_to_keep)
265+
preservation_policy = preservation_policy_lib.LatestN(max_num_checkpoints_to_keep)
266+
256267
async_options = None
257268
if enable_continuous_checkpointing:
258269
async_options = ocp.AsyncOptions(
@@ -752,6 +763,7 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=
752763
or (step % config.checkpoint_period == 0 and not config.enable_continuous_checkpointing)
753764
or (step % config.checkpoint_period == 0)
754765
or (config.enable_emergency_checkpoint and step % config.local_checkpoint_period == 0)
766+
or (config.enable_autocheckpoint and checkpoint_manager.reached_preemption(step))
755767
):
756768
blocking_until_ready_start = time.time()
757769
max_logging.log(f"Waiting for step {step} to finish before checkpoint...")

src/maxtext/configs/base.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ source_checkpoint_layout: "orbax"
8383

8484
# Only applicable to Single Controller/Pathways on Cloud. Experimental feature, under testing
8585
colocated_python_checkpointing: False
86+
87+
# enables autocheckpoint, which saves a checkpoint at the preemption step.
88+
enable_autocheckpoint: False
8689
############################### end checkpointing ##################################
8790

8891

src/maxtext/configs/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,9 @@ class Checkpointing(BaseModel):
333333
False,
334334
description="If True, enables checkpointing from remote TPU VMs instead of head node on pathways.",
335335
)
336+
enable_autocheckpoint: bool = Field(
337+
False, description="If True, enables autocheckpoint or preemption induced checkpointing."
338+
)
336339

337340

338341
class OrbaxStorage(BaseModel):

src/maxtext/utils/train_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
# pylint: disable=bare-except, consider-using-generator
16-
""" Utils that are only interesting for training in MaxText. """
16+
"""Utils that are only interesting for training in MaxText."""
1717

1818
import os
1919
import jax
@@ -82,6 +82,7 @@ def create_training_tools(config, model, mesh):
8282
config.enable_single_controller,
8383
config.colocated_python_checkpointing,
8484
config.enable_single_replica_ckpt_restoring,
85+
config.enable_autocheckpoint,
8586
)
8687

8788
return init_rng, checkpoint_manager, learning_rate_schedule, tx

tests/autocheckpoint_test.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for autocheckpoint feature."""
16+
17+
from unittest import mock
18+
from absl.testing import absltest
19+
from maxtext.common import checkpointing
20+
from maxtext.utils import exceptions
21+
import orbax.checkpoint as ocp
22+
from orbax.checkpoint._src.checkpoint_managers import save_decision_policy as save_decision_policy_lib
23+
24+
25+
class AutocheckpointTest(absltest.TestCase):
26+
27+
@mock.patch("maxtext.src.maxtext.common.checkpointing.PyTreeCheckpointHandler")
28+
@mock.patch("maxtext.src.maxtext.common.checkpointing.CheckpointManager")
29+
@mock.patch("maxtext.src.maxtext.utils.gcs_utils.mkdir_and_check_permissions")
30+
def test_create_checkpoint_manager_with_autocheckpoint(self, mock_mkdir, mock_manager_cls, mock_handler_cls):
31+
mock_mkdir.return_value = "/tmp/checkpoint"
32+
33+
manager = checkpointing.create_orbax_checkpoint_manager(
34+
checkpoint_dir="/tmp/checkpoint",
35+
enable_checkpointing=True,
36+
use_async=False,
37+
save_interval_steps=100,
38+
enable_autocheckpoint=True,
39+
)
40+
41+
self.assertIsNotNone(manager)
42+
mock_manager_cls.assert_called_once()
43+
_, kwargs = mock_manager_cls.call_args
44+
options = kwargs.get("options")
45+
self.assertIsInstance(options.save_decision_policy, save_decision_policy_lib.AnySavePolicy)
46+
# AnySavePolicy internally has policies
47+
policies = options.save_decision_policy.policies
48+
self.assertTrue(any(isinstance(p, save_decision_policy_lib.PreemptionCheckpointingPolicy) for p in policies))
49+
self.assertTrue(any(isinstance(p, save_decision_policy_lib.FixedIntervalPolicy) for p in policies))
50+
51+
@mock.patch("jax.block_until_ready")
52+
@mock.patch("maxtext.src.maxtext.common.checkpointing.ocp.args.PyTreeSave")
53+
def test_save_checkpoint_triggers_on_preemption(self, mock_save_args, mock_block):
54+
mock_manager = mock.MagicMock(spec=ocp.CheckpointManager)
55+
mock_manager.reached_preemption.return_value = True
56+
57+
config = mock.MagicMock()
58+
config.enable_checkpointing = True
59+
config.checkpoint_period = 1000
60+
config.enable_autocheckpoint = True
61+
config.checkpoint_storage_target_data_file_size_bytes = 1024
62+
63+
state = mock.MagicMock()
64+
65+
# Step 5 is not a multiple of checkpoint_period (1000), but reached_preemption is True
66+
checkpointing.save_checkpoint(mock_manager, 5, state, config=config)
67+
68+
mock_manager.save.assert_called_once()
69+
mock_manager.reached_preemption.assert_called_with(5)
70+
71+
def test_maybe_save_checkpoint_handles_preemption(self):
72+
mock_manager = mock.MagicMock(spec=ocp.CheckpointManager)
73+
mock_manager.reached_preemption.return_value = True
74+
75+
config = mock.MagicMock()
76+
config.checkpoint_period = 1000
77+
config.enable_autocheckpoint = True
78+
79+
state = mock.MagicMock()
80+
state.step = 6
81+
82+
with self.assertRaisesRegex(exceptions.StopTraining, "Job is preempted."):
83+
# step=None means it will use state.step - 1 = 5
84+
checkpointing.maybe_save_checkpoint(mock_manager, state, config, data_iterator=None, step=None)
85+
86+
mock_manager.wait_until_finished.assert_called_once()
87+
88+
89+
if __name__ == "__main__":
90+
absltest.main()

0 commit comments

Comments
 (0)