Skip to content

Commit 7c7e44c

Browse files
BlueCrescent and Copilot committed
test(checkpointing): Added num_seen_steps_previous_run to keep every k tests.
Co-authored-by: Copilot <copilot@github.com>
1 parent 33a8192 commit 7c7e44c

1 file changed

Lines changed: 31 additions & 14 deletions

File tree

tests/checkpointing/test_checkpoint_strategies.py

Lines changed: 31 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -52,15 +52,19 @@ def test_checkpoint_strategy_k(
5252

5353

5454
@pytest.mark.parametrize(
55-
"k, num_recent_checkpoints_to_keep, num_steps",
55+
"k, num_recent_checkpoints_to_keep, num_steps, num_seen_steps_previous_run, num_seen_tokens_previous_run",
5656
[
57-
(3, 2, 11),
58-
(2, 1, 10),
59-
(4, 3, 15),
57+
(3, 2, 11, 0, 0),
58+
(2, 1, 10, 2, 4),
59+
(4, 3, 15, 3, 6),
6060
],
6161
)
6262
def test_keep_every_k_strategy_has_no_unexpected_checkpoints(
63-
k: int, num_recent_checkpoints_to_keep: int, num_steps: int
63+
k: int,
64+
num_recent_checkpoints_to_keep: int,
65+
num_steps: int,
66+
num_seen_steps_previous_run: int,
67+
num_seen_tokens_previous_run: int,
6468
) -> None:
6569
checkpoint_strategy = KeepEveryKStepsAndMMostRecentCheckpointingStrategy(
6670
k=k, num_recent_checkpoints_to_keep=num_recent_checkpoints_to_keep
@@ -70,31 +74,42 @@ def test_keep_every_k_strategy_has_no_unexpected_checkpoints(
7074
num_seen_tokens_current_run=0,
7175
num_target_steps=20,
7276
num_target_tokens=40,
77+
num_seen_steps_previous_run=num_seen_steps_previous_run,
78+
num_seen_tokens_previous_run=num_seen_tokens_previous_run,
7379
)
7480

7581
# Simulate training progress and checkpointing
7682
simulator = _CheckpointSavingSimulator()
77-
for step in range(1, num_steps + 1):
83+
for step in range(num_steps + 1):
7884
training_progress.num_seen_steps_current_run = step
7985
checkpoint_instruction = checkpoint_strategy.get_checkpoint_instruction(training_progress=training_progress)
8086
simulator.simulate_training_step(training_progress, checkpoint_instruction)
8187

8288
for ckpt in simulator.saved_checkpoints:
8389
# Check that only checkpoints that are divisible by k or the most recent ones are kept.
84-
last_checkpoints = set(range(num_steps - num_recent_checkpoints_to_keep + 1, num_steps + 1))
90+
last_checkpoints = set(
91+
range(
92+
num_seen_steps_previous_run + num_steps - num_recent_checkpoints_to_keep + 1,
93+
num_seen_steps_previous_run + num_steps + 1,
94+
)
95+
)
8596
assert ckpt.num_seen_steps_total % k == 0 or ckpt.num_seen_steps_total in last_checkpoints
8697

8798

8899
@pytest.mark.parametrize(
89-
"k, num_recent_checkpoints_to_keep, num_steps",
100+
"k, num_recent_checkpoints_to_keep, num_steps, num_seen_steps_previous_run, num_seen_tokens_previous_run",
90101
[
91-
(3, 2, 11),
92-
(2, 1, 10),
93-
(4, 3, 15),
102+
(3, 2, 11, 0, 0),
103+
(2, 1, 10, 2, 4),
104+
(4, 3, 15, 3, 6),
94105
],
95106
)
96107
def test_keep_every_k_strategy_has_no_unexpected_deletions(
97-
k: int, num_recent_checkpoints_to_keep: int, num_steps: int
108+
k: int,
109+
num_recent_checkpoints_to_keep: int,
110+
num_steps: int,
111+
num_seen_steps_previous_run: int,
112+
num_seen_tokens_previous_run: int,
98113
) -> None:
99114
checkpoint_strategy = KeepEveryKStepsAndMMostRecentCheckpointingStrategy(
100115
k=k, num_recent_checkpoints_to_keep=num_recent_checkpoints_to_keep
@@ -104,6 +119,8 @@ def test_keep_every_k_strategy_has_no_unexpected_deletions(
104119
num_seen_tokens_current_run=0,
105120
num_target_steps=20,
106121
num_target_tokens=40,
122+
num_seen_steps_previous_run=num_seen_steps_previous_run,
123+
num_seen_tokens_previous_run=num_seen_tokens_previous_run,
107124
)
108125

109126
# Simulate training progress and checkpointing
@@ -113,9 +130,9 @@ def test_keep_every_k_strategy_has_no_unexpected_deletions(
113130
checkpoint_instruction = checkpoint_strategy.get_checkpoint_instruction(training_progress=training_progress)
114131
simulator.simulate_training_step(training_progress, checkpoint_instruction)
115132

116-
for i in range(1, num_steps + 1):
133+
for i in range(num_seen_steps_previous_run + 1, num_seen_steps_previous_run + num_steps + 1):
117134
# Check that checkpoints that are divisible by k or the most recent ones are not deleted.
118-
if i % k == 0 or i > num_steps - num_recent_checkpoints_to_keep:
135+
if i % k == 0 or i > num_seen_steps_previous_run + num_steps - num_recent_checkpoints_to_keep:
119136
assert any(ckpt.num_seen_steps_total == i for ckpt in simulator.saved_checkpoints)
120137

121138

0 commit comments

Comments
 (0)