Wire max_checkpoints through SFT, DPO, and GRPO paths #1701
Pass the existing keep_last_n_checkpoints config (default=3) to CheckpointerCallback's new max_checkpoints parameter across SFT, DPO, and GRPO training paths. Also add keep_last_n_checkpoints to GRPOExperimentConfig (it was missing, unlike the SFT/DPO configs).

Depends on allenai/OLMo-core#timd/add-max-checkpoints, which adds the max_checkpoints parameter to CheckpointerCallback.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Code Review
This pull request introduces support for limiting the number of saved checkpoints via a new keep_last_n_checkpoints configuration option, which is integrated across DPO, GRPO, and SFT training pipelines. The feedback suggests removing a redundant field definition in GRPOExperimentConfig since it is already inherited, and centralizing the mapping logic that converts negative values (representing unlimited checkpoints) to None inside build_checkpointer_callback to simplify and clean up the caller sites.
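The redundancy the review points out can be illustrated with a minimal sketch. It assumes both config classes are dataclasses and that GRPOExperimentConfig subclasses CheckpointConfig; the classes below are stand-ins rather than the real definitions, and the default for checkpoint_state_freq is invented for illustration.

```python
from dataclasses import dataclass, fields


@dataclass
class CheckpointConfig:
    # Stand-in for the real config class; only the field under discussion is modeled.
    keep_last_n_checkpoints: int = 3
    """Maximum number of permanent checkpoints to keep. -1 for unlimited."""


@dataclass
class GRPOExperimentConfig(CheckpointConfig):
    # Assumed default, present only so the subclass has a field of its own.
    checkpoint_state_freq: int = 100


config = GRPOExperimentConfig()
# Field names visible on the subclass, base-class fields first.
inherited = [f.name for f in fields(config)]
print(inherited)
print(config.keep_last_n_checkpoints)
```

Under these assumptions the subclass already exposes keep_last_n_checkpoints with its default, so declaring it again in GRPOExperimentConfig would add nothing.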
    keep_last_n_checkpoints: int = 3
    """Maximum number of permanent checkpoints to keep. -1 for unlimited."""
     checkpointing_steps: int,
     ephemeral_save_interval: int | None,
     save_async: bool = True,
+    max_checkpoints: int | None = 3,
 ) -> CheckpointerCallback:
     """Construct a CheckpointerCallback with shared Open Instruct defaults."""
     return CheckpointerCallback(
-        save_interval=checkpointing_steps, ephemeral_save_interval=ephemeral_save_interval, save_async=save_async
+        save_interval=checkpointing_steps,
+        ephemeral_save_interval=ephemeral_save_interval,
+        save_async=save_async,
+        max_checkpoints=max_checkpoints,
     )
Instead of repeating the ternary operator args.keep_last_n_checkpoints if args.keep_last_n_checkpoints >= 0 else None across multiple caller files (dpo.py, grpo_olmo_core_actor.py, olmo_core_finetune.py), we can centralize this mapping logic inside build_checkpointer_callback. This keeps the callers clean and ensures consistent handling of the -1 (unlimited) convention.
def build_checkpointer_callback(
    checkpointing_steps: int,
    ephemeral_save_interval: int | None,
    save_async: bool = True,
    max_checkpoints: int | None = 3,
) -> CheckpointerCallback:
    """Construct a CheckpointerCallback with shared Open Instruct defaults."""
    return CheckpointerCallback(
        save_interval=checkpointing_steps,
        ephemeral_save_interval=ephemeral_save_interval,
        save_async=save_async,
        max_checkpoints=max_checkpoints if (max_checkpoints is not None and max_checkpoints >= 0) else None,
    )

     wandb_project=args.wandb_project,
     wandb_entity=args.wandb_entity,
     save_async=False,
+    max_checkpoints=args.keep_last_n_checkpoints if args.keep_last_n_checkpoints >= 0 else None,
With the -1 to None mapping logic centralized inside build_checkpointer_callback, we can simplify this call by passing args.keep_last_n_checkpoints directly.
-    max_checkpoints=args.keep_last_n_checkpoints if args.keep_last_n_checkpoints >= 0 else None,
+    max_checkpoints=args.keep_last_n_checkpoints,
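To see the effect of centralizing the mapping, here is a self-contained sketch. The CheckpointerCallback below is a stub that only records its arguments, not OLMo-core's class, and Optional is used in place of the `int | None` syntax purely so the snippet runs on older interpreters.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class CheckpointerCallback:
    # Stub standing in for OLMo-core's callback; it only stores what it is given.
    save_interval: int
    ephemeral_save_interval: Optional[int]
    save_async: bool
    max_checkpoints: Optional[int]


def build_checkpointer_callback(
    checkpointing_steps: int,
    ephemeral_save_interval: Optional[int],
    save_async: bool = True,
    max_checkpoints: Optional[int] = 3,
) -> CheckpointerCallback:
    """Build the callback, treating negative limits as 'unlimited' (None)."""
    if max_checkpoints is not None and max_checkpoints < 0:
        max_checkpoints = None
    return CheckpointerCallback(
        save_interval=checkpointing_steps,
        ephemeral_save_interval=ephemeral_save_interval,
        save_async=save_async,
        max_checkpoints=max_checkpoints,
    )


# Call sites pass the raw config value; no ternary is needed.
limited = build_checkpointer_callback(500, None, max_checkpoints=3)
unlimited = build_checkpointer_callback(500, None, max_checkpoints=-1)
print(limited.max_checkpoints, unlimited.max_checkpoints)
```

With this arrangement a value of 0 is passed through unchanged; how the real callback interprets 0 is not covered by this sketch.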
-    checkpointing_steps=self.grpo_config.checkpoint_state_freq, ephemeral_save_interval=None
+    checkpointing_steps=self.grpo_config.checkpoint_state_freq,
+    ephemeral_save_interval=None,
+    max_checkpoints=self.grpo_config.keep_last_n_checkpoints if self.grpo_config.keep_last_n_checkpoints >= 0 else None,
With the -1 to None mapping logic centralized inside build_checkpointer_callback, we can simplify this call by passing self.grpo_config.keep_last_n_checkpoints directly.
-    max_checkpoints=self.grpo_config.keep_last_n_checkpoints if self.grpo_config.keep_last_n_checkpoints >= 0 else None,
+    max_checkpoints=self.grpo_config.keep_last_n_checkpoints,
     with_tracking=args.logging.with_tracking,
     wandb_project=args.logging.wandb_project,
     wandb_entity=args.logging.wandb_entity or "ai2-llm",
+    max_checkpoints=args.checkpoint.keep_last_n_checkpoints if args.checkpoint.keep_last_n_checkpoints >= 0 else None,
With the -1 to None mapping logic centralized inside build_checkpointer_callback, we can simplify this call by passing args.checkpoint.keep_last_n_checkpoints directly.
-    max_checkpoints=args.checkpoint.keep_last_n_checkpoints if args.checkpoint.keep_last_n_checkpoints >= 0 else None,
+    max_checkpoints=args.checkpoint.keep_last_n_checkpoints,
… field
- Move -1 to None conversion into build_checkpointer_callback
- Remove redundant keep_last_n_checkpoints from GRPOExperimentConfig (inherited from CheckpointConfig)
- Simplify all call sites to pass value directly
- Add CHANGELOG entry

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Looks good to me. I guess wait until allenai/OLMo-core#694 has been merged and update olmo-core version in Also consider keeping the name
Summary
- Pass keep_last_n_checkpoints (default: 3, from CheckpointConfig) to CheckpointerCallback.max_checkpoints across all OLMo-core training paths (SFT, DPO, GRPO)
- Centralize the -1 → None (unlimited) convention inside build_checkpointer_callback so call sites pass the value directly
- GRPOExperimentConfig already inherits keep_last_n_checkpoints from CheckpointConfig, so no new config fields are needed

Depends on
- … max_checkpoints to CheckpointerCallback)

Files changed
- olmo_core_utils.py: build_checkpointer_callback and build_base_callbacks accept and pass through max_checkpoints; -1 mapped to None inside the builder
- dpo.py: passes args.keep_last_n_checkpoints through
- olmo_core_finetune.py: passes args.checkpoint.keep_last_n_checkpoints through
- grpo_olmo_core_actor.py: passes self.grpo_config.keep_last_n_checkpoints through

Test plan
- keep_last_n_checkpoints=3 trims checkpoints
- --keep_last_n_checkpoints -1 keeps all checkpoints
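
As a rough model of what the two test-plan items expect, the sketch below applies a keep-last-N policy to a list of saved step numbers. The function retained_steps is hypothetical and is not how OLMo-core prunes checkpoints; it only states the expected outcome, assuming the newest checkpoints are the ones kept.

```python
from typing import Optional


def retained_steps(saved_steps: list[int], max_checkpoints: Optional[int]) -> list[int]:
    """Return the checkpoint steps that survive a keep-last-N policy."""
    ordered = sorted(saved_steps)
    if max_checkpoints is None:
        # None is the "unlimited" case that -1 is mapped to.
        return ordered
    keep = min(max_checkpoints, len(ordered))
    # Slice from an explicit start index so keep == 0 yields an empty list.
    return ordered[len(ordered) - keep:]


steps = [100, 200, 300, 400, 500]
trimmed = retained_steps(steps, 3)
everything = retained_steps(steps, None)
print(trimmed)
print(everything)
```

A real check would inspect the checkpoint directory after a training run; this only pins down which steps should remain in each case.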