Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions apps/sft/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from forge.data.tokenizer import HuggingFaceModelTokenizer
from forge.data.utils import StopAfterOneEpoch
from forge.observability import get_or_create_metric_logger, record_metric, Reduce
from forge.util.checkpoint import warn_if_resuming_from_existing_folder
from forge.util.config import parse
from monarch.actor import current_rank, current_size, endpoint
from omegaconf import DictConfig, OmegaConf
Expand Down Expand Up @@ -139,6 +140,15 @@ async def setup(self):

# TODO: confirm that this is working properly
# Should also use load, not dcp_load
ckpt_cfg = self.job_config.checkpoint
warn_if_resuming_from_existing_folder(
folder=ckpt_cfg.get("folder") if hasattr(ckpt_cfg, "get") else getattr(ckpt_cfg, "folder", None),
initial_load_path=(
ckpt_cfg.get("initial_load_path")
if hasattr(ckpt_cfg, "get")
else getattr(ckpt_cfg, "initial_load_path", None)
),
)
self.checkpointer.load(step=self.current_step)

# self.profiler = self.setup_profiler(self.train_config.profiler_config)
Expand Down
5 changes: 5 additions & 0 deletions src/forge/actors/trainer/titan.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from forge.observability.perf_tracker import Tracer
from forge.rl.loss import create_shifted_targets
from forge.types import TrainBatch
from forge.util.checkpoint import warn_if_resuming_from_existing_folder
from monarch.actor import endpoint
from torch import Tensor
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
Expand Down Expand Up @@ -115,6 +116,10 @@ async def setup(self):
}:
engine_config.pop(key) # Not part of job config
self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
warn_if_resuming_from_existing_folder(
folder=getattr(self.checkpoint, "folder", None),
initial_load_path=getattr(self.checkpoint, "initial_load_path", None),
)
self.engine.checkpointer.load(step=self.step)
self.engine.optimizers.zero_grad()

Expand Down
64 changes: 64 additions & 0 deletions src/forge/util/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import time

import torchstore as ts
from forge.actors._torchstore_utils import get_param_prefix

logger = logging.getLogger(__name__)


async def drop_weights(version: int):
print(f"Dropping weights @ version {version}")
Expand All @@ -21,3 +25,63 @@ async def drop_weights(version: int):
await ts.delete(key)
elapsed = time.perf_counter() - start_time
print(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds")


def warn_if_resuming_from_existing_folder(
folder: str | None, initial_load_path: str | None = None
) -> bool:
"""Logs a loud WARNING when the checkpointer is about to silently resume
from an existing ``checkpoint.folder``.

Torchtitan's checkpointer treats ``folder`` as the source of truth: if it
already contains saved step directories (``step-N``), it loads from there
and ignores ``initial_load_path``. Users running back-to-back experiments
without clearing the folder hit this footgun (see #631) — training
silently picks up where the prior run left off instead of starting from
the configured base model.

This helper logs once before the load happens so the resume is visible
in the standard training logs. Returns ``True`` when a warning was
emitted, so callers can also surface it through other channels (e.g. an
extra console banner) if they want.
"""
if not folder or not os.path.isdir(folder):
return False

try:
entries = os.listdir(folder)
except OSError as exc:
logger.debug("could not list checkpoint folder %s: %s", folder, exc)
return False

def _step_number(entry: str) -> int:
try:
return int(entry.removeprefix("step-").split("-", 1)[0])
except ValueError:
return -1

step_dirs = [
entry
for entry in entries
if entry.startswith("step-")
and os.path.isdir(os.path.join(folder, entry))
]
step_dirs.sort(key=_step_number)
if not step_dirs:
return False

extra = ""
if initial_load_path:
extra = (
f" Configured initial_load_path={initial_load_path!r} will be ignored "
"until the folder is cleared or renamed."
)
logger.warning(
"Resuming training from existing checkpoint folder %r (found %d saved "
"step dir(s); latest: %s).%s",
folder,
len(step_dirs),
step_dirs[-1],
extra,
)
return True
96 changes: 96 additions & 0 deletions tests/unit_tests/util/test_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Tests for forge.util.checkpoint.warn_if_resuming_from_existing_folder
(regression for issue #631).

#631: torchtitan's checkpointer silently resumes from ``checkpoint.folder``
when it already contains saved step directories, ignoring
``initial_load_path``. Users running back-to-back experiments hit this
footgun without realizing the second run isn't starting from the base
model. The helper logs a loud WARNING right before the load so the resume
shows up in the standard training logs.
"""

import logging
import os

import pytest

from forge.util.checkpoint import warn_if_resuming_from_existing_folder


class TestWarnIfResumingFromExistingFolder:
    """Behavioural tests for ``warn_if_resuming_from_existing_folder``."""

    def test_returns_false_when_folder_is_none(self):
        """No folder configured: nothing to inspect, nothing to warn about."""
        assert warn_if_resuming_from_existing_folder(None) is False

    def test_returns_false_when_folder_is_empty_string(self):
        """An empty folder string is treated the same as no folder."""
        assert warn_if_resuming_from_existing_folder("") is False

    def test_returns_false_when_folder_does_not_exist(self, tmp_path):
        """A path that is not an existing directory cannot be resumed from."""
        missing = tmp_path / "does_not_exist"
        assert warn_if_resuming_from_existing_folder(str(missing)) is False

    def test_returns_false_when_folder_has_no_step_dirs(self, tmp_path):
        """Unrelated files and directories alone must not trigger the warning."""
        (tmp_path / "random_file.txt").write_text("noise")
        (tmp_path / "logs").mkdir()
        assert warn_if_resuming_from_existing_folder(str(tmp_path)) is False

    def test_warns_when_step_dirs_exist(self, tmp_path, caplog):
        """The warning names the folder, the count and the numerically latest step."""
        # step-50 sorts last lexicographically, so this also checks numeric ordering.
        (tmp_path / "step-100").mkdir()
        (tmp_path / "step-200").mkdir()
        (tmp_path / "step-50").mkdir()

        with caplog.at_level(logging.WARNING, logger="forge.util.checkpoint"):
            warned = warn_if_resuming_from_existing_folder(str(tmp_path))

        assert warned is True
        warning_records = [
            r for r in caplog.records if r.levelno >= logging.WARNING
        ]
        assert warning_records, "expected at least one WARNING-level log"
        msg = warning_records[0].getMessage()
        assert str(tmp_path) in msg
        assert "step-200" in msg, "should report the latest step directory"
        assert "3 saved step dir" in msg

    def test_warning_mentions_ignored_initial_load_path(self, tmp_path, caplog):
        """A configured ``initial_load_path`` is called out as being ignored."""
        (tmp_path / "step-1").mkdir()
        initial = "hf://meta-llama/Meta-Llama-3.1-8B-Instruct"

        with caplog.at_level(logging.WARNING, logger="forge.util.checkpoint"):
            warn_if_resuming_from_existing_folder(
                str(tmp_path), initial_load_path=initial
            )

        msg = caplog.records[-1].getMessage()
        assert initial in msg
        assert "will be ignored" in msg

    def test_ignores_non_step_subdirs(self, tmp_path, caplog):
        """Only ``step-`` prefixed directories count towards the warning."""
        (tmp_path / "tensorboard").mkdir()
        (tmp_path / "wandb").mkdir()
        # With nothing but unrelated sub-directories there is nothing to resume.
        assert warn_if_resuming_from_existing_folder(str(tmp_path)) is False

        # The helper intentionally treats every directory whose name starts
        # with "step-" as a step dir, even when a suffix follows the number.
        (tmp_path / "step-1-backup").mkdir()
        with caplog.at_level(logging.WARNING, logger="forge.util.checkpoint"):
            warned = warn_if_resuming_from_existing_folder(str(tmp_path))

        assert warned is True
        msg = caplog.records[-1].getMessage()
        # The unrelated directories must not be included in the count.
        assert "found 1 saved step dir" in msg
        assert "step-1-backup" in msg

    def test_handles_oserror_gracefully(self, tmp_path, monkeypatch, caplog):
        """If we can't list the folder (perms etc), don't crash, don't warn."""
        (tmp_path / "step-1").mkdir()

        def boom(_):
            raise PermissionError("denied")

        # PermissionError is an OSError subclass, so the helper should absorb it.
        monkeypatch.setattr(os, "listdir", boom)
        with caplog.at_level(logging.WARNING, logger="forge.util.checkpoint"):
            warned = warn_if_resuming_from_existing_folder(str(tmp_path))

        assert warned is False
        warning_records = [
            r for r in caplog.records if r.levelno >= logging.WARNING
        ]
        assert not warning_records