Skip to content

Commit 66ddc48

Browse files
committed
fix: raise RuntimeError when checkpoint step >= config.steps
When a user sets steps=x and the latest saved checkpoint already covers x or more completed steps (that is, the latest checkpoint is at step x-1 or later), the job should fail with a clear error message instead of performing no computation or failing with a confusing profiling error. We add an early check in setup_train_loop (train_utils.py), which runs right after the checkpoint manager is created and before the checkpoint is loaded, and a fallback check in train_loop (train.py), which runs before the expensive TPU compilation step. Both checks use a shared validation helper, validate_completed_steps. Unit tests are added to verify the validation logic. TAG=agy CONV=88c01cb5-28b2-4b67-8895-4a290d332d3f
1 parent fe529ee commit 66ddc48

3 files changed

Lines changed: 47 additions & 5 deletions

File tree

src/maxtext/trainers/pre_train/train.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,9 @@ def train_loop(config, recorder, state=None):
639639
state,
640640
) = train_utils.setup_train_loop(config, recorder)
641641

642+
start_step = get_first_step(model, state) # this is the start_step for training
643+
train_utils.validate_completed_steps(start_step, config.steps)
644+
642645
if isinstance(model, nn.Module):
643646
if config.use_dpo:
644647
if "reference_params" not in state.params:
@@ -682,8 +685,6 @@ def train_loop(config, recorder, state=None):
682685
compiled = p_train_step.lower(*lower_args).compile(compiler_options=compiler_options)
683686
compiled_stats = compiled.memory_analysis()
684687
max_utils.print_compiled_memory_stats(compiled_stats)
685-
686-
start_step = get_first_step(model, state) # this is the start_step for training
687688
prof = profiler.Profiler(config, offset_step=start_step)
688689
metric_logger_instance = metric_logger.MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)
689690

src/maxtext/utils/train_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,11 @@ def create_train_state_fn():
240240
else:
241241
init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, is_training, init_rng)
242242
checkpoint_manager = create_checkpoint_manager(config, mesh, init_state_fn)
243+
if checkpoint_manager is not None:
244+
checkpoint_step = checkpoint_manager.latest_step()
245+
if checkpoint_step is not None:
246+
validate_completed_steps(checkpoint_step + 1, config.steps)
247+
243248

244249
with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION):
245250
data_iterator, eval_data_iterator = create_data_iterator(config, mesh)
@@ -405,3 +410,15 @@ def validate_train_config(config):
405410
"WARNING: Sequence packing is essentially ignored for synthetic data. "
406411
"Please use a real dataset to use sequence packing."
407412
)
413+
414+
415+
def validate_completed_steps(completed_steps: int, config_steps: int) -> None:
  """Fail fast when the requested step budget leaves nothing left to train.

  Args:
    completed_steps: Number of training steps already finished. Callers pass
      either the latest checkpoint step plus one, or the step at which the
      training loop would start.
    config_steps: Total number of training steps requested (``config.steps``).

  Raises:
    RuntimeError: If ``completed_steps >= config_steps``, i.e. there are no
      remaining steps to run.
  """
  if completed_steps < config_steps:
    return

  if completed_steps <= 0:
    # Nothing has been trained yet, so there is no checkpoint to point at.
    # Reporting "a checkpoint already exists at step -1" would be misleading,
    # so describe the real problem: the requested step count is not positive.
    raise RuntimeError(
        f"Requested training up to step {config_steps}, but no training steps would be run "
        f"({completed_steps} steps have been completed). "
        f"Did you mean to set steps > 0?"
    )

  # A checkpoint saved at step N means N + 1 steps are done, hence the "- 1".
  raise RuntimeError(
      f"Requested training up to step {config_steps}, but a checkpoint already exists at step {completed_steps - 1} "
      f"(which means {completed_steps} steps have been completed). "
      f"Did you mean to continue training past step {completed_steps} (you should set steps > {completed_steps}) "
      f"or to not load the checkpoint (use enable_checkpointing=False?)"
  )
424+

tests/unit/train_utils_test.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@
1818
from dataclasses import dataclass
1919
from unittest.mock import MagicMock
2020

21-
from maxtext.utils.train_utils import validate_train_config, create_training_optimizer
21+
from maxtext.utils.train_utils import (
22+
validate_train_config,
23+
create_training_optimizer,
24+
validate_completed_steps,
25+
)
2226

2327

2428
@dataclass
@@ -185,12 +189,32 @@ def test_sgd_optimizer_returns_tx(self):
185189
config.learning_rate_schedule_steps = 100
186190
config.lr_schedule_type = "cosine"
187191
config.use_iota_embed = False
188-
189192
_, tx = create_training_optimizer(config, model=None)
190-
191193
self.assertIsNotNone(tx)
192194
self.assertTrue(hasattr(tx, "init"))
193195

194196

197+
class TestValidateCompletedSteps(unittest.TestCase):
  """Unit tests covering the validate_completed_steps helper."""

  def test_under_steps_passes(self):
    """Returns without raising while completed_steps is below config_steps."""
    # The call itself is the assertion: any exception fails the test.
    validate_completed_steps(completed_steps=50, config_steps=100)

  def test_equal_steps_raises(self):
    """Raises RuntimeError once completed_steps reaches config_steps exactly."""
    expected_fragment = "Requested training up to step 100, but a checkpoint already exists at step 99"
    with self.assertRaises(RuntimeError) as raised:
      validate_completed_steps(completed_steps=100, config_steps=100)
    error_text = str(raised.exception)
    self.assertIn(expected_fragment, error_text)

  def test_over_steps_raises(self):
    """Raises RuntimeError when completed_steps exceeds config_steps."""
    expected_fragment = "Requested training up to step 100, but a checkpoint already exists at step 104"
    with self.assertRaises(RuntimeError) as raised:
      validate_completed_steps(completed_steps=105, config_steps=100)
    error_text = str(raised.exception)
    self.assertIn(expected_fragment, error_text)
216+
217+
195218
if __name__ == "__main__":
196219
unittest.main()
220+

0 commit comments

Comments
 (0)