Skip to content

Commit 33a8192

Browse files
BlueCrescentCopilot
andcommitted
test(checkpointing): Test that KeepEveryKStepsAndMMostRecentCheckpointingStrategy does not delete any required checkpoints.
Co-authored-by: Copilot <copilot@github.com>
1 parent d732db9 commit 33a8192

1 file changed

Lines changed: 37 additions & 1 deletion

File tree

tests/checkpointing/test_checkpoint_strategies.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ def test_checkpoint_strategy_k(
5959
(4, 3, 15),
6060
],
6161
)
62-
def test_keep_every_k_steps_keeps_every_k_steps(k: int, num_recent_checkpoints_to_keep: int, num_steps: int) -> None:
62+
def test_keep_every_k_strategy_has_no_unexpected_checkpoints(
63+
k: int, num_recent_checkpoints_to_keep: int, num_steps: int
64+
) -> None:
6365
checkpoint_strategy = KeepEveryKStepsAndMMostRecentCheckpointingStrategy(
6466
k=k, num_recent_checkpoints_to_keep=num_recent_checkpoints_to_keep
6567
)
@@ -83,6 +85,40 @@ def test_keep_every_k_steps_keeps_every_k_steps(k: int, num_recent_checkpoints_t
8385
assert ckpt.num_seen_steps_total % k == 0 or ckpt.num_seen_steps_total in last_checkpoints
8486

8587

88+
@pytest.mark.parametrize(
89+
"k, num_recent_checkpoints_to_keep, num_steps",
90+
[
91+
(3, 2, 11),
92+
(2, 1, 10),
93+
(4, 3, 15),
94+
],
95+
)
96+
def test_keep_every_k_strategy_has_no_unexpected_deletions(
97+
k: int, num_recent_checkpoints_to_keep: int, num_steps: int
98+
) -> None:
99+
checkpoint_strategy = KeepEveryKStepsAndMMostRecentCheckpointingStrategy(
100+
k=k, num_recent_checkpoints_to_keep=num_recent_checkpoints_to_keep
101+
)
102+
training_progress = TrainingProgress(
103+
num_seen_steps_current_run=0,
104+
num_seen_tokens_current_run=0,
105+
num_target_steps=20,
106+
num_target_tokens=40,
107+
)
108+
109+
# Simulate training progress and checkpointing
110+
simulator = _CheckpointSavingSimulator()
111+
for step in range(1, num_steps + 1):
112+
training_progress.num_seen_steps_current_run = step
113+
checkpoint_instruction = checkpoint_strategy.get_checkpoint_instruction(training_progress=training_progress)
114+
simulator.simulate_training_step(training_progress, checkpoint_instruction)
115+
116+
for i in range(1, num_steps + 1):
117+
# Check that checkpoints that are divisible by k or the most recent ones are not deleted.
118+
if i % k == 0 or i > num_steps - num_recent_checkpoints_to_keep:
119+
assert any(ckpt.num_seen_steps_total == i for ckpt in simulator.saved_checkpoints)
120+
121+
86122
def test_keep_every_k_steps_checkpointing_strategy_invalid_arguments() -> None:
87123
with pytest.raises(AssertionError):
88124
KeepEveryKStepsAndMMostRecentCheckpointingStrategy(k=0, num_recent_checkpoints_to_keep=1)

0 commit comments

Comments
 (0)