Skip to content

Commit 1aa9b56

Browse files
BlueCrescentCopilot
andcommitted
feat(checkpointing): New saving strategy that keeps every k steps and additionally the m most recent checkpoints.
Co-authored-by: Copilot <copilot@github.com>
1 parent 4705675 commit 1aa9b56

5 files changed

Lines changed: 134 additions & 1 deletion

File tree

docs/components/components.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ The composed initializer supports seeded weight initialization for reproducibili
8888
| checkpoint_saving | default | [CheckpointSaving](../../src/modalities/checkpointing/checkpoint_saving.py)| [CheckpointSavingConfig](../../src/modalities/config/config.py) | -- | Component for saving checkpoints based on a saving and execution strategy. |
8989
| checkpoint_saving_strategy | save_every_k_steps_checkpointing_strategy | [SaveEveryKStepsCheckpointingStrategy](../../src/modalities/checkpointing/checkpoint_saving_strategies.py)| [SaveEveryKStepsCheckpointingStrategyConfig](../../src/modalities/config/config.py) | [CheckpointSavingStrategyIF](../../src/modalities/checkpointing/checkpoint_saving_strategies.py) | Checkpointing strategy saving a checkpoint every k steps |
9090
| checkpoint_saving_strategy | save_k_most_recent_checkpoints_strategy | [SaveKMostRecentCheckpointsStrategy](../../src/modalities/checkpointing/checkpoint_saving_strategies.py)| [SaveKMostRecentCheckpointsStrategyConfig](../../src/modalities/config/config.py) | [CheckpointSavingStrategyIF](../../src/modalities/checkpointing/checkpoint_saving_strategies.py) | Checkpointing strategy saving only the last k checkpoints and deleting the previous ones |
91+
| checkpoint_saving_strategy | keep_every_k_steps_and_m_most_recent_checkpointing_strategy | [KeepEveryKStepsAndMMostRecentCheckpointingStrategy](../../src/modalities/checkpointing/checkpoint_saving_strategies.py)| [KeepEveryKStepsAndMMostRecentCheckpointingStrategyConfig](../../src/modalities/config/config.py) | [CheckpointSavingStrategyIF](../../src/modalities/checkpointing/checkpoint_saving_strategies.py) | Checkpointing strategy saving a checkpoint every k steps and keeping the m most recent checkpoints |
9192
| checkpoint_saving_execution | fsdp | [FSDPCheckpointSaving](../../src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py)| [FSDPCheckpointSavingConfig](../../src/modalities/config/config.py) | [CheckpointSavingExecutionABC](../../src/modalities/checkpointing/checkpoint_saving_execution.py) | FSDPCheckpointSaving class for saving checkpoints of FSDP models and optimizers. |
9293
| checkpoint_loading | fsdp | [FSDPCheckpointLoading](../../src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py)| [FSDPCheckpointLoadingConfig](../../src/modalities/config/config.py) | [CheckpointLoadingIF](../../src/modalities/checkpointing/checkpoint_loading.py) | Component for loading FSDP checkpoints|
9394
| checkpoint_loading | torch | [TorchCheckpointLoading](../../src/modalities/checkpointing/torch/torch_checkpoint_loading.py)| [TorchCheckpointLoadingConfig](../../src/modalities/config/config.py) | [CheckpointLoadingIF](../../src/modalities/checkpointing/checkpoint_loading.py) | Component for loading PyTorch checkpoints|

src/modalities/checkpointing/checkpoint_saving_strategies.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,60 @@ def get_checkpoint_instruction(
119119
"""
120120
save_current = training_progress.num_seen_steps_total % self.k == 0
121121
return CheckpointingInstruction(save_current=save_current, checkpoints_to_delete=[])
122+
123+
124+
class KeepEveryKStepsAndMMostRecentCheckpointingStrategy(CheckpointSavingStrategyIF):
    """Strategy that checkpoints every step, permanently keeps the checkpoints whose
    step number is divisible by k, and additionally retains the most recent ones."""

    def __init__(self, k: int, num_recent_checkpoints_to_keep: int = 2):
        """
        Initializes the CheckpointSavingStrategy object.

        Args:
            k (int): Step interval of the checkpoints that are kept permanently.
            num_recent_checkpoints_to_keep (int, optional): Size of the sliding window of
                recent checkpoints. Every checkpoint passes through this window; only the
                ones whose step is not divisible by k are deleted when they fall out of it.
                Defaults to 2.

        Returns:
            None
        """
        super().__init__()
        self._k = k
        self._num_recent_checkpoints_to_keep = num_recent_checkpoints_to_keep
        # Sliding window of recently saved checkpoints, oldest first.
        self._saved_recent_checkpoints: list[TrainingProgress] = []
        assert self._k > 0, "k must be greater than 0"
        assert self._num_recent_checkpoints_to_keep >= 1, "num_recent_checkpoints_to_keep must be at least 1"

    def get_checkpoint_instruction(
        self,
        training_progress: TrainingProgress,
        evaluation_result: dict[str, EvaluationResultBatch] | None = None,
        early_stopping_criterion_fulfilled: bool = False,
    ) -> CheckpointingInstruction:
        """
        Returns a CheckpointingInstruction object.

        Args:
            training_progress (TrainingProgress): The training progress.
            evaluation_result (dict[str, EvaluationResultBatch] | None, optional):
                The evaluation result. Defaults to None.
            early_stopping_criterion_fulfilled (bool, optional):
                Whether the early stopping criterion is fulfilled. Defaults to False.

        Returns:
            CheckpointingInstruction: The checkpointing instruction object.
        """
        # Record a snapshot so later mutation of training_progress does not alias into the window.
        self._saved_recent_checkpoints.append(dataclasses.replace(training_progress))

        # Split off everything that has fallen out of the recency window. Negative slicing
        # already covers the short-list case: lst[:-m] is empty whenever len(lst) <= m.
        window = self._num_recent_checkpoints_to_keep
        expired = self._saved_recent_checkpoints[:-window]
        self._saved_recent_checkpoints = self._saved_recent_checkpoints[-window:]

        # Checkpoints at a multiple of k are kept permanently; all other expired ones go.
        checkpoints_to_delete = [cp for cp in expired if cp.num_seen_steps_current_run % self._k != 0]

        return CheckpointingInstruction(save_current=True, checkpoints_to_delete=checkpoints_to_delete)

src/modalities/config/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,11 @@ class SaveKMostRecentCheckpointsStrategyConfig(BaseModel):
9292
k: Annotated[int, Field(strict=True, ge=-1)]
9393

9494

95+
class KeepEveryKStepsAndMMostRecentCheckpointingStrategyConfig(BaseModel):
    # Step interval of checkpoints to keep permanently; must be strictly positive.
    k: Annotated[int, Field(strict=True, gt=0)]
    # Number of most recent checkpoints to additionally retain; at least 1.
    num_recent_checkpoints_to_keep: Annotated[int, Field(strict=True, ge=1)] = 2
98+
99+
95100
class TorchCheckpointLoadingConfig(BaseModel):
96101
device: PydanticPytorchDeviceType
97102
precision: Optional[PrecisionEnum] = None

src/modalities/registry/components.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from modalities.checkpointing.checkpoint_saving import CheckpointSaving
1212
from modalities.checkpointing.checkpoint_saving_strategies import (
13+
KeepEveryKStepsAndMMostRecentCheckpointingStrategy,
1314
SaveEveryKStepsCheckpointingStrategy,
1415
SaveKMostRecentCheckpointsStrategy,
1516
)
@@ -47,6 +48,7 @@
4748
GPT2LLMCollateFnConfig,
4849
GPT2MFUCalculatorConfig,
4950
GPT2ModelTPConfig,
51+
KeepEveryKStepsAndMMostRecentCheckpointingStrategyConfig,
5052
LinearLRSchedulerConfig,
5153
LinearWarmupCosineAnnealingLRSchedulerConfig,
5254
LLMDataLoaderConfig,
@@ -353,6 +355,12 @@ class ComponentEntity:
353355
SaveKMostRecentCheckpointsStrategy,
354356
SaveKMostRecentCheckpointsStrategyConfig,
355357
),
358+
ComponentEntity(
359+
"checkpoint_saving_strategy",
360+
"keep_every_k_steps_and_m_most_recent_checkpointing_strategy",
361+
KeepEveryKStepsAndMMostRecentCheckpointingStrategy,
362+
KeepEveryKStepsAndMMostRecentCheckpointingStrategyConfig,
363+
),
356364
# checkpoint saving execution
357365
ComponentEntity("checkpoint_saving_execution", "fsdp1", FSDP1CheckpointSaving, FSDP1CheckpointSavingConfig),
358366
ComponentEntity("checkpoint_saving_execution", "dcp", DCPCheckpointSaving, DCPCheckpointSavingConfig),

tests/checkpointing/test_checkpoint_strategies.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
1+
import dataclasses
2+
13
import pytest
24

3-
from modalities.checkpointing.checkpoint_saving_strategies import SaveKMostRecentCheckpointsStrategy
5+
from modalities.checkpointing.checkpoint_saving_instruction import CheckpointingInstruction
6+
from modalities.checkpointing.checkpoint_saving_strategies import (
7+
KeepEveryKStepsAndMMostRecentCheckpointingStrategy,
8+
SaveKMostRecentCheckpointsStrategy,
9+
)
410
from modalities.training.training_progress import TrainingProgress
511

612

@@ -43,3 +49,59 @@ def test_checkpoint_strategy_k(
4349
if k != 0 and save_current:
4450
training_progress.num_seen_steps_current_run = 100
4551
assert checkpoint_strategy.saved_step_checkpoints[0].num_seen_steps_current_run == num_seen_steps_current_run
52+
53+
54+
@pytest.mark.parametrize(
    "k, num_recent_checkpoints_to_keep, num_steps",
    [
        (3, 2, 11),
        (2, 1, 10),
        (4, 3, 15),
    ],
)
def test_keep_every_k_steps_keeps_every_k_steps(k: int, num_recent_checkpoints_to_keep: int, num_steps: int) -> None:
    """After num_steps steps, exactly the checkpoints at multiples of k plus the
    num_recent_checkpoints_to_keep most recent ones must remain — no more, no fewer."""
    checkpoint_strategy = KeepEveryKStepsAndMMostRecentCheckpointingStrategy(
        k=k, num_recent_checkpoints_to_keep=num_recent_checkpoints_to_keep
    )
    training_progress = TrainingProgress(
        num_seen_steps_current_run=0,
        num_seen_tokens_current_run=0,
        num_target_steps=20,
        num_target_tokens=40,
    )

    # Simulate training progress and checkpointing
    simulator = _CheckpointSavingSimulator()
    for step in range(1, num_steps + 1):
        training_progress.num_seen_steps_current_run = step
        checkpoint_instruction = checkpoint_strategy.get_checkpoint_instruction(training_progress=training_progress)
        simulator.simulate_training_step(training_progress, checkpoint_instruction)

    # Hoisted out of the per-checkpoint loop: both sets depend only on the parameters.
    most_recent_steps = set(range(num_steps - num_recent_checkpoints_to_keep + 1, num_steps + 1))
    expected_steps = {step for step in range(1, num_steps + 1) if step % k == 0} | most_recent_steps
    actual_steps = {ckpt.num_seen_steps_current_run for ckpt in simulator.saved_checkpoints}
    # Exact set equality: checks both that nothing unexpected survived AND that no
    # expected checkpoint (multiple of k, or within the recency window) was deleted.
    assert actual_steps == expected_steps
84+
85+
86+
def test_keep_every_k_steps_checkpointing_strategy_invalid_arguments() -> None:
    """The constructor must reject non-positive k and retention counts below 1."""
    invalid_argument_sets = [
        {"k": 0, "num_recent_checkpoints_to_keep": 1},
        {"k": -1, "num_recent_checkpoints_to_keep": 1},
        {"k": 2, "num_recent_checkpoints_to_keep": 0},
        {"k": 2, "num_recent_checkpoints_to_keep": -1},
    ]
    for kwargs in invalid_argument_sets:
        with pytest.raises(AssertionError):
            KeepEveryKStepsAndMMostRecentCheckpointingStrategy(**kwargs)
95+
96+
97+
class _CheckpointSavingSimulator:
    """Test double that tracks which checkpoints would currently exist on disk."""

    def __init__(self):
        # Snapshots of TrainingProgress, one per checkpoint still "on disk".
        self.saved_checkpoints: list[TrainingProgress] = []

    def simulate_training_step(
        self, training_progress: TrainingProgress, ckpt_instruction: CheckpointingInstruction
    ) -> None:
        """Apply one checkpointing instruction: save a snapshot, then drop deletions."""
        if ckpt_instruction.save_current:
            # Snapshot via replace() so later mutation of training_progress does not alias.
            self.saved_checkpoints.append(dataclasses.replace(training_progress))
        doomed = ckpt_instruction.checkpoints_to_delete
        # Single pass instead of one list rebuild per deleted checkpoint; membership
        # uses ==, matching the original per-item equality filter.
        self.saved_checkpoints = [cp for cp in self.saved_checkpoints if cp not in doomed]

0 commit comments

Comments
 (0)