|
12 | 12 | # k value is 0. No deletion of checkpoints. |
13 | 13 | (0, [], [], False), |
14 | 14 | # k value is 2, but there are currently only one checkpoint. Hence, no deletion. |
15 | | - (2, [1], [], True), |
| 15 | + (2, [TrainingProgress(1, 1, 20, 20)], [], True), |
16 | 16 | # k value is -1, therefore we want to keep all checkpoints without any deletion |
17 | 17 | ( |
18 | 18 | -1, |
|
25 | 25 | def test_checkpoint_strategy_k( |
26 | 26 | k: int, saved_instances: list[TrainingProgress], checkpoints_to_delete: list[int], save_current: bool |
27 | 27 | ) -> None: |
| 28 | + num_seen_steps_current_run = 10 |
28 | 29 | training_progress = TrainingProgress( |
29 | | - num_seen_steps_current_run=10, num_seen_tokens_current_run=10, num_target_steps=20, num_target_tokens=40 |
| 30 | + num_seen_steps_current_run=num_seen_steps_current_run, |
| 31 | + num_seen_tokens_current_run=10, |
| 32 | + num_target_steps=20, |
| 33 | + num_target_tokens=40, |
30 | 34 | ) |
31 | 35 | checkpoint_strategy = SaveKMostRecentCheckpointsStrategy(k=k) |
32 | 36 | checkpoint_strategy.saved_step_checkpoints = saved_instances |
33 | 37 | checkpoint_instruction = checkpoint_strategy.get_checkpoint_instruction(training_progress=training_progress) |
34 | 38 |
|
35 | 39 | assert checkpoint_instruction.checkpoints_to_delete == checkpoints_to_delete |
36 | 40 | assert checkpoint_instruction.save_current == save_current |
| 41 | + |
| 42 | + # make sure that modifying the training progress externally does not affect saved_step_checkpoints |
| 43 | + if k != 0 and save_current: |
| 44 | + training_progress.num_seen_steps_current_run = 100 |
| 45 | + assert checkpoint_strategy.saved_step_checkpoints[0].num_seen_steps_current_run == num_seen_steps_current_run |
0 commit comments