diff --git a/CHANGELOG.md b/CHANGELOG.md index 28c7284..d0fff79 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -87,6 +87,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `docs/checkpointing/dcp-model.md`: updated the save/load snippets and the "shape to fill" explanation to the DCP-aware helpers. - Tests (fail on the pre-fix code, pass after): `tests/integration/test_checkpoint_roundtrip.py::test_manager_restores_optimizer_moments_single_gpu` and `tests/distributed/test_checkpoint.py::test_resume_restores_optimizer_moments` assert `exp_avg` / `exp_avg_sq` are restored bit-exactly into a *fresh* optimizer (single-GPU + distributed); `tests/e2e/test_training_e2e.py::test_resume_determinism_single_gpu` / `test_resume_determinism_2gpu_fsdp` assert end-to-end bit-exact loss across an interrupt-and-resume on a learnable dataset. - **On-disk format note:** optimizer state is now keyed by parameter fully-qualified name rather than positional index. Checkpoints written before this fix will not restore optimizer state on resume (training continues with a fresh optimizer); model state is unaffected. +- **Always save a final checkpoint on clean completion.** A run that finished cleanly previously wrote a final checkpoint only when `max_steps` happened to land on the save schedule; otherwise the fully-trained model (including the entire WSD decay tail) was never persisted and `latest` resolved to the last scheduled step. The training-loop teardown now writes one unconditional checkpoint at `max_steps` on normal completion when that step is off-schedule, matching the emergency-checkpoint guarantee on the preemption path. Purely additive: the schedule/retention knobs (`interval`, `keep_last_n`, `dyn_ckpt_window`) are untouched. + - `scripts/train.py`: a `completed_normally` marker set on the final iteration, plus an off-schedule final `ckpt_mgr.save(...)` (and `on_checkpoint_save` hook) before `ckpt_mgr.wait()`. + - Tests: `tests/e2e/test_training_e2e.py` — a clean off-schedule completion writes `step_{max_steps}` and points `latest` at it. ## [0.1.0] — 2026-04-16 diff --git a/docs/training/training-loop.md b/docs/training/training-loop.md index 641880d..e7eea1f 100644 --- a/docs/training/training-loop.md +++ b/docs/training/training-loop.md @@ -185,12 +185,22 @@ After the loop: ```python prof.stop() +# Clean off-schedule finish: persist the fully-trained final step +if completed_normally and not config.checkpoint.should_save(step): + ckpt_mgr.save(step, ...) # final checkpoint; `latest` committed after wait ckpt_mgr.wait() # drain last async save hook_runner.on_train_end(step, tokens_seen) tracker.close() destroy_distributed() ``` +On a clean finish, an unconditional checkpoint is written at `max_steps` +when that step is not already on the save schedule — so a completed run's +fully-trained model (including the WSD decay tail) is always recoverable +and `latest` points at it. This mirrors the emergency checkpoint the +preemption path writes on shutdown. The `should_save` guard avoids a +duplicate when `max_steps` already coincided with the schedule. + `ckpt_mgr.wait()` is load-bearing — without it, a rank can exit before its async DCP write completes, corrupting the checkpoint for everyone else on the same save. See diff --git a/scripts/train.py b/scripts/train.py index 54e6f3c..40e8358 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -672,6 +672,8 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: ) hook_runner.on_checkpoint_save(0, config.checkpoint.dir) + completed_normally = False + while step < tc.max_steps: # Refresh data iterator at start / epoch boundary if dataloader is not None and data_iter is None: @@ -1039,11 +1041,29 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: shutdown_handler.finish() break + # Clean-completion marker for the unconditional final save after the + # loop. Only reached when training completes without any errors, e.g., + # no NaN/NCCL/shutdown breaks. If a run encounters a NaN, the last step + # is intentionally *not* saved because the actual model state would be + # `max_steps - 1`, not `max_steps`. + if step >= tc.max_steps: + completed_normally = True + if prof is not None: prof.stop() if rank == 0: print_profiler_summary(prof, trace_dir=config.profiling.trace_dir) + if completed_normally and not config.checkpoint.should_save(step): + ckpt_mgr.save( + step=step, + tokens_seen=tokens_seen, + scheduler=scheduler, + dataloader=dataloader, + extra=ckpt_extra, + ) + hook_runner.on_checkpoint_save(step, config.checkpoint.dir) + # Flush any pending async checkpoint before tearing down process group ckpt_mgr.wait() diff --git a/tests/e2e/test_training_e2e.py b/tests/e2e/test_training_e2e.py index 86983bc..875108f 100644 --- a/tests/e2e/test_training_e2e.py +++ b/tests/e2e/test_training_e2e.py @@ -606,6 +606,37 @@ def test_sigterm_triggers_emergency_checkpoint(tmp_path): assert "Checkpoint saved" in output, f"Emergency checkpoint was not saved:\n{output[-2000:]}" +@pytest.mark.e2e +def test_final_checkpoint_on_clean_completion(tmp_path): + """A clean finish at an off-schedule step must still persist a final + checkpoint, with `latest` pointing at it.""" + ckpt_dir = tmp_path / "final_ckpt" + + # max_steps=10 with interval=1000 => no scheduled save lands (debug.toml + # configures no dyn_ckpt_window), so the only checkpoint must be the + # unconditional final save at step 10. + result = _run_training( + [ + DEBUG_CONFIG, + "--train.max_steps=10", + "--metrics.log_interval=5", + f"--checkpoint.dir={ckpt_dir}", + "--checkpoint.interval=1000", # no scheduled checkpoint + ], + nproc=1, + ) + _assert_training_complete(result, expected_steps=10) + + output = result.stdout + result.stderr + final = ckpt_dir / "step_10" + latest = ckpt_dir / "latest" + assert final.is_dir(), f"final checkpoint step_10 was not written:\n{output[-2000:]}" + assert latest.exists(), f"`latest` pointer is missing:\n{output[-2000:]}" + assert latest.resolve().name == "step_10", ( + f"`latest` should resolve to step_10, got {latest.resolve().name!r}" + ) + + # ============================================================================ # MoE Training # ============================================================================