Skip to content

Commit 0625ec3

Browse files
Add elastic pause/resume functionality to MaxText.
PiperOrigin-RevId: 890594825
1 parent c30ada0 commit 0625ec3

5 files changed

Lines changed: 112 additions & 2 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,3 +1181,8 @@ distill_temperature: 1.0
11811181
# 0.0 value disables this feature.
11821182
distill_beta: 0.0
11831183
distill_layer_indices: None
1184+
1185+
1186+
##### Elastic training parameters
1187+
elastic_pause_resume: false
1188+
elastic_timeout: 300

src/maxtext/configs/types.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1516,6 +1516,15 @@ class Goodput(BaseModel):
15161516
enable_gcp_step_deviation_metrics: bool = Field(True, description="Enable GCP step deviation metrics.")
15171517

15181518

1519+
class ElasticTraining(BaseModel):
  """Configuration for elastic training and fault tolerance.

  Attributes:
    elastic_pause_resume: Whether to enable elastic pause and resume
      functionality.
    elastic_timeout: The timeout in seconds for elastic training operations.
  """

  elastic_pause_resume: bool = Field(
      False, description="Whether to enable elastic pause and resume functionality."
  )
  # Default matches `elastic_timeout: 300` in configs/base.yml so the schema
  # default and the YAML default do not disagree.
  elastic_timeout: int = Field(300, description="The timeout in seconds for elastic training operations.")
1526+
1527+
15191528
class GcpMonitoring(BaseModel):
15201529
"""Configuration for GCP-specific workload monitoring."""
15211530

@@ -1897,6 +1906,7 @@ class MaxTextConfig(
18971906
Checkpointing,
18981907
OrbaxStorage,
18991908
EmergencyCheckpointing,
1909+
ElasticTraining,
19001910
# Data Types and Quantization
19011911
DataTypes,
19021912
Quantization,

src/maxtext/trainers/pre_train/train.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from maxtext.configs import pyconfig
4242
from maxtext.common.common_types import ShardMode
4343
from maxtext.utils.globals import EPS
44+
import maxtext.utils.elastic_utils as elastic_utils
4445
# Placeholder: internal
4546

4647
# pylint: disable=too-many-positional-arguments
@@ -679,9 +680,18 @@ def run(config, recorder, diagnostic_config):
679680
def main(argv: Sequence[str]) -> None:
680681
config, recorder, diagnostic_config = initialize(argv)
681682
record_goodput(recorder, RECORD_JOB_START_TIME)
682-
with maybe_monitor_goodput(config):
683+
684+
def train_func():
685+
config, recorder, diagnostic_config = initialize(argv)
683686
run(config, recorder, diagnostic_config)
684687

688+
if config.elastic_pause_resume:
689+
max_logging.log("Elastic Pause and Resume Enabled.")
690+
train_func = elastic_utils.elastic_pause_resume(config)(train_func)
691+
692+
with maybe_monitor_goodput(config):
693+
train_func()
694+
685695

686696
if __name__ == "__main__":
687697
app.run(main)

src/maxtext/utils/elastic_utils.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Utility functions for Elastic Training."""
16+
17+
import functools
18+
import re
19+
import subprocess
20+
21+
import maxtext.utils.max_logging as max_logging
22+
import pathwaysutils
23+
from pathwaysutils.elastic import manager
24+
25+
26+
# Module-level handle to the pathwaysutils elastic manager shared by the
# helpers in this module; it is None until a manager instance is assigned.
elastic_manager: manager.Manager | None = None
27+
28+
29+
def elastic_mode_enabled(config) -> bool:
  """Reports whether elastic mode is active for this run.

  Elastic mode requires both that the Pathways backend is in use and that
  `config.elastic_pause_resume` is set.

  Args:
    config: Configuration object exposing `elastic_pause_resume`.

  Returns:
    The backend check result when it is falsy, otherwise
    `config.elastic_pause_resume`.
  """
  on_pathways = pathwaysutils.is_pathways_backend_used()
  if not on_pathways:
    # Short-circuit: the config flag is not consulted without Pathways.
    return on_pathways
  return config.elastic_pause_resume
33+
34+
35+
def _run_gsutil(args, **run_kwargs):
  """Runs `gsutil` with `args`; returns the CompletedProcess, or None if it could not be launched."""
  try:
    return subprocess.run(["gsutil", *args], check=False, **run_kwargs)
  except OSError as err:
    # e.g. gsutil is not installed or not executable. Cleanup is best-effort,
    # so report the problem instead of raising out of the elastic callback.
    max_logging.log(f"Elastic utils: unable to run gsutil ({err}).")
    return None


def clean_up_checkpoints(checkpoint_dir: str):
  """Cleans up an incomplete latest checkpoint after an elastic event.

  Lists `checkpoint_dir` with `gsutil ls`, picks the entry with the highest
  numeric step among those ending in `/<digits>/`, and deletes it when no
  `commit_success*` object is found inside it. Every failure is logged and the
  function returns normally; it never raises for gsutil problems.

  Args:
    checkpoint_dir: Directory (as understood by gsutil) holding one
      sub-directory per checkpoint step.
  """
  max_logging.log("Elastic utils: Checking for incomplete checkpoint after an elastic event...")

  # An empty path must never reach `gsutil ls` / `gsutil rm`.
  if not checkpoint_dir:
    max_logging.log("Failed to inspect checkpoint dir. Continuing")
    return

  # 1. List the directory.
  listing = _run_gsutil(["ls", checkpoint_dir], capture_output=True, text=True)
  if listing is None or listing.returncode != 0:
    max_logging.log("Failed to inspect checkpoint dir. Continuing")
    return

  # 2. Keep only entries that look like step directories (".../<digits>/").
  step_dir_pattern = re.compile(r"/(\d+)/$")
  step_dirs = []
  for entry in listing.stdout.splitlines():
    match = step_dir_pattern.search(entry)
    if match:
      step_dirs.append((int(match.group(1)), entry))

  if not step_dirs:
    max_logging.log("Found no existing checkpoints. Continuing")
    return

  # Order by the numeric step so that e.g. 1000 sorts after 999; the stable
  # sort keeps listing order for equal steps.
  step_dirs.sort(key=lambda item: item[0])
  latest_checkpoint = step_dirs[-1][1]

  max_logging.log(f"Checking latest checkpoint: {latest_checkpoint}")

  # 3. Check for a commit_success file.
  # `gsutil -q stat` returns 0 if found, non-zero if not.
  # NOTE(review): any non-zero exit is treated as "not committed"; if gsutil
  # can also exit non-zero for transient errors this may delete a valid
  # checkpoint — confirm against gsutil's exit-code behavior.
  stat_check = _run_gsutil(["-q", "stat", f"{latest_checkpoint}commit_success*"])
  if stat_check is None:
    max_logging.log(f"Could not check for commit_success file. Keeping {latest_checkpoint}.")
    return

  if stat_check.returncode == 0:
    max_logging.log(f"Found commit_success file. Keeping {latest_checkpoint}.")
    return

  max_logging.log(f"No commit_success file found. Deleting {latest_checkpoint}...")
  removal = _run_gsutil(["-m", "rm", "-rf", latest_checkpoint])
  if removal is None or removal.returncode != 0:
    max_logging.log(f"Failed to delete {latest_checkpoint}. Continuing")
69+
70+
71+
def elastic_pause_resume(config, callback_fn=None, *, max_retries=10, poll_interval=10):
  """Builds the elastic pause/resume retry wrapper for a training function.

  The module-level `elastic_manager` is created on first use, because it is
  declared as None and calling `elastic_retry` on None would raise
  AttributeError.

  Args:
    config: Configuration object exposing `checkpoint_dir` and
      `elastic_timeout`.
    callback_fn: Callable passed as `on_elastic_event_callback`. When None,
      defaults to `clean_up_checkpoints` bound to `config.checkpoint_dir`.
    max_retries: Forwarded to `elastic_retry` as `max_retries`.
    poll_interval: Forwarded to `elastic_retry` as `poll_interval`.

  Returns:
    Whatever `elastic_manager.elastic_retry(...)` returns.
    NOTE(review): presumably a decorator to apply to the training function —
    confirm against the pathwaysutils API.
  """
  global elastic_manager
  if elastic_manager is None:
    # NOTE(review): assumes `manager.Manager()` is constructible without
    # arguments — confirm against the installed pathwaysutils version.
    elastic_manager = manager.Manager()

  if callback_fn is None:
    callback_fn = functools.partial(clean_up_checkpoints, config.checkpoint_dir)

  return elastic_manager.elastic_retry(
      max_retries=max_retries,
      poll_interval=poll_interval,
      timeout=config.elastic_timeout,
      on_elastic_event_callback=callback_fn,
  )
83+
84+

src/maxtext/utils/maxtext_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from maxtext.utils import gcs_utils
4545
from maxtext.utils import max_logging
4646
from maxtext.utils import max_utils
47+
from maxtext.utils import elastic_utils
4748
from maxtext.utils import sharding
4849

4950
OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient"
@@ -1323,7 +1324,7 @@ def add_config_to_summary_writer(config, summary_writer):
13231324
def create_device_mesh(config, devices=None):
13241325
"""Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas"""
13251326
if devices is None:
1326-
devices = jax.devices()
1327+
devices = elastic_utils.live_devices() if not elastic_utils.elastic_mode_enabled(config) else jax.devices()
13271328
if config.subslice_shape and config.enable_single_controller and config.num_slices == 1:
13281329
max_logging.log(f"Trying to create a subslice with shape: {config.subslice_shape}")
13291330
subslice_shape = tuple(int(x) for x in config.subslice_shape.split(","))

0 commit comments

Comments
 (0)