@@ -52,15 +52,19 @@ def test_checkpoint_strategy_k(
5252
5353
5454@pytest .mark .parametrize (
55- "k, num_recent_checkpoints_to_keep, num_steps" ,
55+ "k, num_recent_checkpoints_to_keep, num_steps, num_seen_steps_previous_run, num_seen_tokens_previous_run " ,
5656 [
57- (3 , 2 , 11 ),
58- (2 , 1 , 10 ),
59- (4 , 3 , 15 ),
57+ (3 , 2 , 11 , 0 , 0 ),
58+ (2 , 1 , 10 , 2 , 4 ),
59+ (4 , 3 , 15 , 3 , 6 ),
6060 ],
6161)
6262def test_keep_every_k_strategy_has_no_unexpected_checkpoints (
63- k : int , num_recent_checkpoints_to_keep : int , num_steps : int
63+ k : int ,
64+ num_recent_checkpoints_to_keep : int ,
65+ num_steps : int ,
66+ num_seen_steps_previous_run : int ,
67+ num_seen_tokens_previous_run : int ,
6468) -> None :
6569 checkpoint_strategy = KeepEveryKStepsAndMMostRecentCheckpointingStrategy (
6670 k = k , num_recent_checkpoints_to_keep = num_recent_checkpoints_to_keep
@@ -70,31 +74,42 @@ def test_keep_every_k_strategy_has_no_unexpected_checkpoints(
7074 num_seen_tokens_current_run = 0 ,
7175 num_target_steps = 20 ,
7276 num_target_tokens = 40 ,
77+ num_seen_steps_previous_run = num_seen_steps_previous_run ,
78+ num_seen_tokens_previous_run = num_seen_tokens_previous_run ,
7379 )
7480
7581 # Simulate training progress and checkpointing
7682 simulator = _CheckpointSavingSimulator ()
77- for step in range (1 , num_steps + 1 ):
83+ for step in range (num_steps + 1 ):
7884 training_progress .num_seen_steps_current_run = step
7985 checkpoint_instruction = checkpoint_strategy .get_checkpoint_instruction (training_progress = training_progress )
8086 simulator .simulate_training_step (training_progress , checkpoint_instruction )
8187
8288 for ckpt in simulator .saved_checkpoints :
8389 # Check that only checkpoints that are divisible by k or the most recent ones are kept.
84- last_checkpoints = set (range (num_steps - num_recent_checkpoints_to_keep + 1 , num_steps + 1 ))
90+ last_checkpoints = set (
91+ range (
92+ num_seen_steps_previous_run + num_steps - num_recent_checkpoints_to_keep + 1 ,
93+ num_seen_steps_previous_run + num_steps + 1 ,
94+ )
95+ )
8596 assert ckpt .num_seen_steps_total % k == 0 or ckpt .num_seen_steps_total in last_checkpoints
8697
8798
8899@pytest .mark .parametrize (
89- "k, num_recent_checkpoints_to_keep, num_steps" ,
100+ "k, num_recent_checkpoints_to_keep, num_steps, num_seen_steps_previous_run, num_seen_tokens_previous_run " ,
90101 [
91- (3 , 2 , 11 ),
92- (2 , 1 , 10 ),
93- (4 , 3 , 15 ),
102+ (3 , 2 , 11 , 0 , 0 ),
103+ (2 , 1 , 10 , 2 , 4 ),
104+ (4 , 3 , 15 , 3 , 6 ),
94105 ],
95106)
96107def test_keep_every_k_strategy_has_no_unexpected_deletions (
97- k : int , num_recent_checkpoints_to_keep : int , num_steps : int
108+ k : int ,
109+ num_recent_checkpoints_to_keep : int ,
110+ num_steps : int ,
111+ num_seen_steps_previous_run : int ,
112+ num_seen_tokens_previous_run : int ,
98113) -> None :
99114 checkpoint_strategy = KeepEveryKStepsAndMMostRecentCheckpointingStrategy (
100115 k = k , num_recent_checkpoints_to_keep = num_recent_checkpoints_to_keep
@@ -104,6 +119,8 @@ def test_keep_every_k_strategy_has_no_unexpected_deletions(
104119 num_seen_tokens_current_run = 0 ,
105120 num_target_steps = 20 ,
106121 num_target_tokens = 40 ,
122+ num_seen_steps_previous_run = num_seen_steps_previous_run ,
123+ num_seen_tokens_previous_run = num_seen_tokens_previous_run ,
107124 )
108125
109126 # Simulate training progress and checkpointing
@@ -113,9 +130,9 @@ def test_keep_every_k_strategy_has_no_unexpected_deletions(
113130 checkpoint_instruction = checkpoint_strategy .get_checkpoint_instruction (training_progress = training_progress )
114131 simulator .simulate_training_step (training_progress , checkpoint_instruction )
115132
116- for i in range (1 , num_steps + 1 ):
133+ for i in range (num_seen_steps_previous_run + 1 , num_seen_steps_previous_run + num_steps + 1 ):
117134 # Check that checkpoints that are divisible by k or the most recent ones are not deleted.
118- if i % k == 0 or i > num_steps - num_recent_checkpoints_to_keep :
135+ if i % k == 0 or i > num_seen_steps_previous_run + num_steps - num_recent_checkpoints_to_keep :
119136 assert any (ckpt .num_seen_steps_total == i for ckpt in simulator .saved_checkpoints )
120137
121138
0 commit comments