|
| 1 | +import dataclasses |
| 2 | + |
1 | 3 | import pytest |
2 | 4 |
|
3 | | -from modalities.checkpointing.checkpoint_saving_strategies import SaveKMostRecentCheckpointsStrategy |
| 5 | +from modalities.checkpointing.checkpoint_saving_instruction import CheckpointingInstruction |
| 6 | +from modalities.checkpointing.checkpoint_saving_strategies import ( |
| 7 | + KeepEveryKStepsAndMMostRecentCheckpointingStrategy, |
| 8 | + SaveKMostRecentCheckpointsStrategy, |
| 9 | +) |
4 | 10 | from modalities.training.training_progress import TrainingProgress |
5 | 11 |
|
6 | 12 |
|
@@ -43,3 +49,59 @@ def test_checkpoint_strategy_k( |
43 | 49 | if k != 0 and save_current: |
44 | 50 | training_progress.num_seen_steps_current_run = 100 |
45 | 51 | assert checkpoint_strategy.saved_step_checkpoints[0].num_seen_steps_current_run == num_seen_steps_current_run |
| 52 | + |
| 53 | + |
@pytest.mark.parametrize(
    "k, num_recent_checkpoints_to_keep, num_steps",
    [
        (3, 2, 11),
        (2, 1, 10),
        (4, 3, 15),
    ],
)
def test_keep_every_k_steps_keeps_every_k_steps(k: int, num_recent_checkpoints_to_keep: int, num_steps: int) -> None:
    """Check which checkpoints survive after stepping the strategy ``num_steps`` times.

    The strategy is queried once per step (steps ``1..num_steps``) and every returned
    instruction is applied to a ``_CheckpointSavingSimulator``. Afterwards each checkpoint
    still held by the simulator must either sit on a multiple of ``k`` or belong to the
    last ``num_recent_checkpoints_to_keep`` steps.

    Args:
        k: Step interval whose multiples are allowed to remain.
        num_recent_checkpoints_to_keep: Size of the trailing window of steps allowed to remain.
        num_steps: Number of training steps to simulate.
    """
    checkpoint_strategy = KeepEveryKStepsAndMMostRecentCheckpointingStrategy(
        k=k, num_recent_checkpoints_to_keep=num_recent_checkpoints_to_keep
    )
    training_progress = TrainingProgress(
        num_seen_steps_current_run=0,
        num_seen_tokens_current_run=0,
        num_target_steps=20,
        num_target_tokens=40,
    )

    # Simulate training progress and checkpointing.
    simulator = _CheckpointSavingSimulator()
    for step in range(1, num_steps + 1):
        # The same progress object is mutated in place; the simulator stores copies of it.
        training_progress.num_seen_steps_current_run = step
        checkpoint_instruction = checkpoint_strategy.get_checkpoint_instruction(training_progress=training_progress)
        simulator.simulate_training_step(training_progress, checkpoint_instruction)

    # The trailing window is independent of the checkpoint under inspection, so build it once.
    most_recent_steps = range(num_steps - num_recent_checkpoints_to_keep + 1, num_steps + 1)
    for ckpt in simulator.saved_checkpoints:
        kept_step = ckpt.num_seen_steps_current_run
        # Only checkpoints on a multiple of k or inside the trailing window may remain.
        assert (
            kept_step % k == 0 or kept_step in most_recent_steps
        ), f"Checkpoint at step {kept_step} is neither a multiple of {k} nor among the most recent steps."
| 84 | + |
| 85 | + |
def test_keep_every_k_steps_checkpointing_strategy_invalid_arguments() -> None:
    """Check that zero or negative constructor arguments are rejected with ``AssertionError``."""
    # Each entry makes exactly one of the two arguments invalid (zero or negative).
    invalid_argument_sets = (
        {"k": 0, "num_recent_checkpoints_to_keep": 1},
        {"k": -1, "num_recent_checkpoints_to_keep": 1},
        {"k": 2, "num_recent_checkpoints_to_keep": 0},
        {"k": 2, "num_recent_checkpoints_to_keep": -1},
    )
    for invalid_kwargs in invalid_argument_sets:
        with pytest.raises(AssertionError):
            KeepEveryKStepsAndMMostRecentCheckpointingStrategy(**invalid_kwargs)
| 95 | + |
| 96 | + |
class _CheckpointSavingSimulator:
    """Test helper that applies checkpointing instructions to an in-memory list of checkpoints."""

    def __init__(self):
        # Copies of the training progress for every checkpoint saved and not yet deleted.
        self.saved_checkpoints: list[TrainingProgress] = []

    def simulate_training_step(
        self, training_progress: TrainingProgress, ckpt_instruction: CheckpointingInstruction
    ) -> None:
        """Apply one checkpointing instruction to the tracked checkpoints.

        Args:
            training_progress: Progress at the current step; a copy is stored when the
                instruction asks to save.
            ckpt_instruction: Instruction whose ``save_current`` flag and
                ``checkpoints_to_delete`` entries are applied, in that order.
        """
        if ckpt_instruction.save_current:
            # Store a copy so later in-place changes to the caller's object do not alter it.
            snapshot = dataclasses.replace(training_progress)
            self.saved_checkpoints.append(snapshot)
        for obsolete in ckpt_instruction.checkpoints_to_delete:
            # Drop every tracked checkpoint that compares equal to the one marked for deletion.
            remaining = [kept for kept in self.saved_checkpoints if kept != obsolete]
            self.saved_checkpoints = remaining
0 commit comments