@@ -59,7 +59,9 @@ def test_checkpoint_strategy_k(
5959 (4 , 3 , 15 ),
6060 ],
6161)
62- def test_keep_every_k_steps_keeps_every_k_steps (k : int , num_recent_checkpoints_to_keep : int , num_steps : int ) -> None :
62+ def test_keep_every_k_strategy_has_no_unexpected_checkpoints (
63+ k : int , num_recent_checkpoints_to_keep : int , num_steps : int
64+ ) -> None :
6365 checkpoint_strategy = KeepEveryKStepsAndMMostRecentCheckpointingStrategy (
6466 k = k , num_recent_checkpoints_to_keep = num_recent_checkpoints_to_keep
6567 )
@@ -83,6 +85,40 @@ def test_keep_every_k_steps_keeps_every_k_steps(k: int, num_recent_checkpoints_t
8385 assert ckpt .num_seen_steps_total % k == 0 or ckpt .num_seen_steps_total in last_checkpoints
8486
8587
88+ @pytest .mark .parametrize (
89+ "k, num_recent_checkpoints_to_keep, num_steps" ,
90+ [
91+ (3 , 2 , 11 ),
92+ (2 , 1 , 10 ),
93+ (4 , 3 , 15 ),
94+ ],
95+ )
96+ def test_keep_every_k_strategy_has_no_unexpected_deletions (
97+ k : int , num_recent_checkpoints_to_keep : int , num_steps : int
98+ ) -> None :
99+ checkpoint_strategy = KeepEveryKStepsAndMMostRecentCheckpointingStrategy (
100+ k = k , num_recent_checkpoints_to_keep = num_recent_checkpoints_to_keep
101+ )
102+ training_progress = TrainingProgress (
103+ num_seen_steps_current_run = 0 ,
104+ num_seen_tokens_current_run = 0 ,
105+ num_target_steps = 20 ,
106+ num_target_tokens = 40 ,
107+ )
108+
109+ # Simulate training progress and checkpointing
110+ simulator = _CheckpointSavingSimulator ()
111+ for step in range (1 , num_steps + 1 ):
112+ training_progress .num_seen_steps_current_run = step
113+ checkpoint_instruction = checkpoint_strategy .get_checkpoint_instruction (training_progress = training_progress )
114+ simulator .simulate_training_step (training_progress , checkpoint_instruction )
115+
116+ for i in range (1 , num_steps + 1 ):
117+ # Check that checkpoints that are divisible by k or the most recent ones are not deleted.
118+ if i % k == 0 or i > num_steps - num_recent_checkpoints_to_keep :
119+ assert any (ckpt .num_seen_steps_total == i for ckpt in simulator .saved_checkpoints )
120+
121+
86122def test_keep_every_k_steps_checkpointing_strategy_invalid_arguments () -> None :
87123 with pytest .raises (AssertionError ):
88124 KeepEveryKStepsAndMMostRecentCheckpointingStrategy (k = 0 , num_recent_checkpoints_to_keep = 1 )
0 commit comments