Fix clean runs not always saving on last step #129
Codecov Report: ✅ All modified and coverable lines are covered by tests.
Pull request overview
This PR ensures a training run that finishes cleanly always persists a final checkpoint at max_steps even when max_steps does not land on the configured checkpoint schedule, aligning clean completion behavior with the existing preemption/emergency-checkpoint guarantee.
Changes:
- Add an unconditional final checkpoint save in the training teardown path when completion is clean and `max_steps` is off-schedule.
- Add an end-to-end test that runs an off-schedule clean completion and asserts `step_{max_steps}` exists and `latest` points to it.
- Update training-loop documentation and the changelog to describe the new final-checkpoint guarantee.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| scripts/train.py | Adds a “clean completion” marker and a guarded unconditional final checkpoint save in teardown. |
| tests/e2e/test_training_e2e.py | Adds an e2e regression test asserting an off-schedule clean completion still writes step_{max_steps} and updates latest. |
| docs/training/training-loop.md | Documents the final-checkpoint-on-clean-completion behavior in the shutdown section. |
| CHANGELOG.md | Adds an Unreleased “Fixed” entry describing the new final checkpoint guarantee. |
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
amazloumi left a comment
Approving.
Deferring to a follow-up (not blocking this PR): the fix's decision lives inside the ~1000-line main() in scripts/train.py, so today it's only reachable via the --e2e tier (which doesn't run in default CI). Extracting the training loop into its own function/class (run_training_loop(...) or a small Trainer) would create a seam to inject a fake CheckpointManager and unit-test the save decision directly — fast, and in default CI. Since that's a refactor of the central loop, it deserves its own PR + review rather than expanding this one. I'll open an issue.
Potential tests for the follow-up above:
- Clean finish, off-schedule step → save once at max_steps (on_checkpoint_save hook)
- Clean finish, on-schedule step → no second save (dedup)
- NaN on the last step → no final save (intentional — skipped optimizer step, weights are step-(N−1))
- NCCL health failure on the last step → no final save
- Graceful shutdown on the last step → no final save (emergency save already covers it)
- Resume from a checkpoint already at max_steps → no save and no crash (the ckpt_extra unbound case)
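To make the proposed seam concrete, here is a minimal sketch of what such unit tests could look like. Everything in it is an assumption for illustration: `run_training_loop`, `FakeCheckpointManager`, and the `fail_at` parameter are hypothetical names, the schedule is reduced to a plain interval check, and the emergency save on graceful shutdown is not modeled. It mirrors only the guard described in the PR summary, not the actual code in scripts/train.py.

```python
from dataclasses import dataclass, field


@dataclass
class FakeCheckpointManager:
    """Hypothetical stand-in that records save calls instead of writing files."""
    saved_steps: list = field(default_factory=list)

    def save(self, step):
        self.saved_steps.append(step)

    def wait(self):
        pass  # nothing is ever in flight in this fake


def run_training_loop(max_steps, interval, ckpt_mgr, start_step=0, fail_at=None):
    """Toy loop (assumed shape) with the clean-completion guard.

    fail_at simulates any early exit (NaN, NCCL failure, shutdown) at that step.
    """
    def should_save(step):
        # Simplified schedule: interval only, no dyn_ckpt_window.
        return step % interval == 0

    step = start_step
    completed_normally = False
    while step < max_steps:
        step += 1
        if fail_at is not None and step == fail_at:
            break  # early exit never reaches the flag below
        if should_save(step):
            ckpt_mgr.save(step)
        if step >= max_steps:
            completed_normally = True  # last statement of the loop body

    # Final save only on a clean, off-schedule finish.
    if completed_normally and not should_save(step):
        ckpt_mgr.save(step)
    ckpt_mgr.wait()
    return completed_normally


# Clean finish, off-schedule step: exactly one save at max_steps.
mgr = FakeCheckpointManager()
run_training_loop(max_steps=10, interval=1000, ckpt_mgr=mgr)
assert mgr.saved_steps == [10]

# Early exit on the last step: no final save.
mgr = FakeCheckpointManager()
run_training_loop(max_steps=10, interval=1000, ckpt_mgr=mgr, fail_at=10)
assert mgr.saved_steps == []
```

With a seam like this, each case in the list above becomes a few lines that run without a GPU, which is the main argument for doing the extraction in its own PR.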
Minor:
- Copilot's NaN comment: right about the mechanism, but the behavior is intentional, so I'm resolving it as intended. Worth expanding the code comment to state why (skipped optimizer step → step-(N−1) weights) so it isn't "fixed" later.
- should_shutdown() isn't all-reduced across ranks, so a preemption on the final step could desync the emergency save. Pre-existing — flagging for a separate issue.
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
amazloumi left a comment
The changelog entry saying "no existing checkpoint is … pruned" isn't accurate. The final save participates in keep_last_n rotation like any scheduled save, so it can cause the oldest checkpoint to be cleaned up to respect the keep_last_n rule.
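A small sketch of the rotation effect, assuming a plain keep-the-newest-N policy. The `rotate` helper is hypothetical and the real pruning logic in CheckpointManager is not shown in this thread, so treat the step numbers as an illustration only.

```python
def rotate(existing_steps, new_step, keep_last_n):
    """Return (kept, pruned) after adding new_step under a keep-last-N policy.

    Assumes keep_last_n >= 1 and that pruning always drops the oldest steps.
    """
    steps = sorted(set(existing_steps) | {new_step})
    kept = steps[-keep_last_n:]
    pruned = steps[:-keep_last_n]
    return kept, pruned


# Three scheduled checkpoints already on disk, keep_last_n = 3.
kept, pruned = rotate([1000, 2000, 3000], new_step=3500, keep_last_n=3)
assert kept == [2000, 3000, 3500]
assert pruned == [1000]  # the final save pushed out the oldest checkpoint
```

Under that assumption, the final off-schedule save does displace a checkpoint whenever the directory is already at the keep_last_n limit, which is why the changelog wording needs to change.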
a88270e
refs: KEM-542
Summary
- Before this change, a clean run saved a checkpoint at its final step only when `max_steps` happened to land on the save schedule. The per-step gate `config.checkpoint.should_save(step)` runs inside the loop on the post-increment step; after the loop only `ckpt_mgr.wait()` runs, which flushes an in-flight async save but never starts a new one. So when `max_steps` is neither a multiple of `checkpoint.interval` nor a `dyn_ckpt_window` milestone, the fully trained model, including the entire WSD learning-rate decay tail, was never persisted, and `latest` resolved to the highest scheduled step ≤ `max_steps`. This was inconsistent with the preemption path, which already saves an emergency checkpoint at the current step on shutdown.
- In `scripts/train.py`, add one unconditional checkpoint at `max_steps` in the training-loop teardown, guarded so it fires only on a clean, off-schedule finish. This gives clean completion the same guarantee as the preemption path. It reuses the existing `CheckpointManager.save` and the `on_checkpoint_save` hook; the `save` → `wait` order is preserved (correct for both sync and async modes).
- A `completed_normally` flag is initialized `False` before the loop and set `True` only as the last statement of the loop body when `step >= tc.max_steps`. That line is reachable only on the iteration that reaches `max_steps` with no `break`, so it excludes all early exits (NaN rollback, NCCL-health-check failure, and graceful shutdown). It is also never reached on a zero-iteration resume, which guarantees `ckpt_extra` is bound when the final save runs. The save is additionally gated on `not config.checkpoint.should_save(step)` to avoid a duplicate when `max_steps` already coincides with the schedule.
- The existing schedule settings (`interval`, `keep_last_n`, `dyn_ckpt_window`) are untouched. The only behavioral change is that an off-schedule final step now also produces a checkpoint.
- In `docs/training/training-loop.md`, the `## Shutdown` section now documents the final-checkpoint guarantee and shows it in the abridged teardown snippet.
- In `CHANGELOG.md`, an entry is added under `### Fixed` in `## [Unreleased]`.
- In `tests/e2e/test_training_e2e.py`, `test_final_checkpoint_on_clean_completion` runs a single-process training run to an off-schedule completion (`--train.max_steps=10 --checkpoint.interval=1000`, and `debug.toml` configures no `dyn_ckpt_window`, so no scheduled save lands) and asserts on the filesystem that `step_10/` exists and `latest` resolves to it.

Testing
- `uv run ruff check kempnerforge/ tests/` passes
- `uv run ruff format --check kempnerforge/ tests/ scripts/` passes
- `uv run pyright kempnerforge/` passes (0 errors)
- `uv run pytest tests/unit/ -v --timeout=60` passes. NOTE: I can currently only access 1 GPU. 20 distributed tests were skipped; the rest all pass.
- `uv run pytest tests/e2e/ --e2e -v` applies; new test: `tests/e2e/test_training_e2e.py::test_final_checkpoint_on_clean_completion`

Closes #125