
Commit 6b3df7f

Merge pull request #52 from KempnerInstitute/dataloader-state-persistence
Persist dataloader state across checkpoints
2 parents 7ba26d6 + 0245adb commit 6b3df7f

3 files changed: 208 additions & 0 deletions


kempnerforge/checkpoint/manager.py

Lines changed: 28 additions & 0 deletions
@@ -60,6 +60,10 @@ def __init__(
         self._async_ckpt = AsyncCheckpointer(mode=config.async_mode)
         self._process_group = process_group
         self._pp_rank = pp_rank
+        # Dataloader state stashed during load() when the caller cannot yet
+        # provide a dataloader object. Applied later via
+        # apply_dataloader_state() once the loader is constructed.
+        self._pending_dataloader_state: dict[str, Any] | None = None

     def _checkpoint_dir(self, step: int) -> Path:
         return self.base_dir / f"step_{step}"
@@ -212,6 +216,11 @@ def load(
         train_state = object_list[0]

         assert train_state is not None, "train_state broadcast failed"
+        # Stash dataloader state if the caller can't yet provide the loader
+        # object. Training loops construct the dataloader after load() so
+        # apply_dataloader_state() can restore it once it exists.
+        if dataloader is None and "dataloader" in train_state:
+            self._pending_dataloader_state = train_state["dataloader"]
         step, tokens_seen, extra = restore_train_state(
             train_state,
             scheduler=scheduler,
@@ -222,6 +231,25 @@ def load(

         return 0, 0, {}

+    def apply_dataloader_state(self, dataloader: Any) -> None:
+        """Apply any dataloader state stashed during load().
+
+        Training loops call load() before constructing the dataloader (since
+        the dataloader depends on phase/annealing state that load() restores).
+        This method applies the stashed state once the loader exists.
+
+        No-op if no state is pending, or if the loader does not support
+        ``load_state_dict`` (e.g., plain torch DataLoader for HF streaming).
+        """
+        if self._pending_dataloader_state is None:
+            return
+        if dataloader is None or not hasattr(dataloader, "load_state_dict"):
+            self._pending_dataloader_state = None
+            return
+        dataloader.load_state_dict(self._pending_dataloader_state)
+        self._pending_dataloader_state = None
+        logger.info("Applied stashed dataloader state")
+
     def _resolve_load_path(self, path: str | None = None) -> Path | None:
         """Resolve the checkpoint path to load from."""
         if path is not None:
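
Taken together, these changes implement a two-phase restore: load() stashes any "dataloader" entry it finds in the train state, and apply_dataloader_state() replays it once the loader exists. A minimal sketch of the resume flow, assuming a hypothetical build_dataloader() helper and resume_path string (CheckpointManager, load(), and apply_dataloader_state() are from the diff above):

    ckpt_mgr = CheckpointManager(config, model, optimizer)

    # Phase 1: load() runs before the dataloader exists; if the checkpoint
    # carries a "dataloader" entry, the manager stashes it internally.
    step, tokens_seen, extra = ckpt_mgr.load(path=resume_path)

    # Phase 2: build the loader from the restored phase/annealing state,
    # then apply the stashed state (a no-op for non-stateful loaders).
    dataloader = build_dataloader(config, step=step)  # hypothetical helper
    ckpt_mgr.apply_dataloader_state(dataloader)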

scripts/train.py

Lines changed: 9 additions & 0 deletions
@@ -376,6 +376,13 @@ def main() -> None:
        f"{config.data.hf_dataset_name} ({config.data.hf_dataset_split})"
    )

+    # Apply any dataloader state stashed during load(). Runs after dataloader
+    # construction because the loader's identity depends on phase scheduling
+    # that load() restores. No-op when resuming without a prior dataloader
+    # state or when the loader is not stateful (plain TorchDataLoader).
+    if dataloader is not None:
+        ckpt_mgr.apply_dataloader_state(dataloader)
+
    # --- Eval data ---
    eval_config = config.eval
    eval_dataloader = None
@@ -763,6 +770,7 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
             step=step,
             tokens_seen=tokens_seen,
             scheduler=scheduler,
+            dataloader=dataloader,
             extra=ckpt_extra,
         )
         hook_runner.on_checkpoint_save(step, config.checkpoint.dir)
@@ -774,6 +782,7 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
         step=step,
         tokens_seen=tokens_seen,
         scheduler=scheduler,
+        dataloader=dataloader,
         extra=ckpt_extra,
     )
     shutdown_handler.finish()
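
The persistence contract is duck-typed: judging by the tests below, save() captures state_dict() from the loader when one is passed, load() and apply_dataloader_state() hand the saved dict to load_state_dict(), and a loader lacking those methods is silently skipped. A sketch of a conforming loader, with illustrative field names mirroring the test fixtures:

    class StatefulLoader:
        """Illustrative loader that participates in checkpointing."""

        def __init__(self) -> None:
            self.epoch = 0
            self.batches_yielded = 0

        def state_dict(self) -> dict:
            # Captured by save() into train_state["dataloader"].
            return {"epoch": self.epoch, "batches_yielded": self.batches_yielded}

        def load_state_dict(self, state: dict) -> None:
            # Called by load() when the loader is passed in, or later by
            # apply_dataloader_state() after a stash.
            self.epoch = state["epoch"]
            self.batches_yielded = state["batches_yielded"]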

tests/unit/test_checkpoint.py

Lines changed: 171 additions & 0 deletions
@@ -254,3 +254,174 @@ def test_save_waits_for_previous(self, monkeypatch, tmp_path):
         ckpt.save({"model": {}}, checkpoint_id=str(tmp_path / "step_2"))
         # First future should have been waited on before second save
         mock_future1.result.assert_called_once()
+
+
+# ---------------------------------------------------------------------------
+# Dataloader state persistence (two-phase apply)
+# ---------------------------------------------------------------------------
+
+
+def _make_mock_mgr(tmp_path, monkeypatch):
+    """Build a CheckpointManager with DCP calls mocked out (no distributed)."""
+    from unittest.mock import MagicMock
+
+    from kempnerforge.checkpoint.manager import CheckpointManager
+
+    model = torch.nn.Linear(4, 4)
+    opt = torch.optim.SGD(model.parameters(), lr=0.1)
+    config = CheckpointConfig(dir=str(tmp_path), keep_last_n=5)
+    mgr = CheckpointManager(config, model, opt)
+    monkeypatch.setattr(mgr._async_ckpt, "save", MagicMock())
+    monkeypatch.setattr("kempnerforge.checkpoint.manager.dcp.load", MagicMock())
+    return mgr
+
+
+class TestDataloaderStatePersistence:
+    """Round-trip coverage for dataloader state across save -> load -> apply.
+
+    Training loops call load() before constructing the dataloader (the loader
+    depends on phase/annealing state that load() restores). Load stashes the
+    dataloader state; apply_dataloader_state() restores it into the freshly
+    built loader.
+    """
+
+    def test_apply_no_op_when_nothing_pending(self, tmp_path, monkeypatch):
+        from unittest.mock import MagicMock
+
+        mgr = _make_mock_mgr(tmp_path, monkeypatch)
+        loader = MagicMock(spec=["load_state_dict"])
+        mgr.apply_dataloader_state(loader)
+        loader.load_state_dict.assert_not_called()
+
+    def test_apply_restores_state_to_stateful_loader(self, tmp_path, monkeypatch):
+        from unittest.mock import MagicMock
+
+        mgr = _make_mock_mgr(tmp_path, monkeypatch)
+        stashed = {"epoch": 3, "batches_yielded": 100, "sampler": {"epoch": 3, "skip_samples": 0}}
+        mgr._pending_dataloader_state = stashed
+
+        loader = MagicMock(spec=["load_state_dict"])
+        mgr.apply_dataloader_state(loader)
+
+        loader.load_state_dict.assert_called_once_with(stashed)
+        assert mgr._pending_dataloader_state is None
+
+    def test_apply_clears_state_for_non_stateful_loader(self, tmp_path, monkeypatch):
+        """Prevent the stashed state from leaking into a later (stateful) loader."""
+        mgr = _make_mock_mgr(tmp_path, monkeypatch)
+        mgr._pending_dataloader_state = {"epoch": 1}
+
+        class PlainLoader:  # no load_state_dict method
+            pass
+
+        mgr.apply_dataloader_state(PlainLoader())
+        assert mgr._pending_dataloader_state is None
+
+    def test_apply_clears_state_for_none_loader(self, tmp_path, monkeypatch):
+        mgr = _make_mock_mgr(tmp_path, monkeypatch)
+        mgr._pending_dataloader_state = {"epoch": 1}
+        mgr.apply_dataloader_state(None)
+        assert mgr._pending_dataloader_state is None
+
+    def test_save_persists_dataloader_state(self, tmp_path, monkeypatch):
+        """save() must include dataloader state when a stateful loader is passed."""
+        mgr = _make_mock_mgr(tmp_path, monkeypatch)
+
+        class Loader:
+            def state_dict(self):
+                return {"epoch": 4, "batches_yielded": 200}
+
+        mgr.save(step=1, tokens_seen=64, dataloader=Loader())
+        saved = torch.load(tmp_path / "step_1" / "train_state.pt", weights_only=False)
+        assert saved["dataloader"] == {"epoch": 4, "batches_yielded": 200}
+
+    def test_load_stashes_dataloader_state_when_no_loader_provided(self, tmp_path, monkeypatch):
+        """load(dataloader=None) must stash the dataloader state for later apply."""
+        mgr = _make_mock_mgr(tmp_path, monkeypatch)
+        ckpt_dir = tmp_path / "step_1"
+        ckpt_dir.mkdir()
+        saved_state = {"epoch": 2, "batches_yielded": 50}
+        torch.save(
+            {
+                "step": 1,
+                "tokens_seen": 64,
+                "rng": get_rng_state(),
+                "dataloader": saved_state,
+            },
+            ckpt_dir / "train_state.pt",
+        )
+
+        step, tokens, _ = mgr.load(path=str(ckpt_dir))
+
+        assert step == 1
+        assert tokens == 64
+        assert mgr._pending_dataloader_state == saved_state
+
+    def test_load_restores_directly_when_loader_provided(self, tmp_path, monkeypatch):
+        """load(dataloader=X) must restore directly and leave pending state empty."""
+        from unittest.mock import MagicMock
+
+        mgr = _make_mock_mgr(tmp_path, monkeypatch)
+        ckpt_dir = tmp_path / "step_1"
+        ckpt_dir.mkdir()
+        saved_state = {"epoch": 2, "batches_yielded": 50}
+        torch.save(
+            {
+                "step": 1,
+                "tokens_seen": 64,
+                "rng": get_rng_state(),
+                "dataloader": saved_state,
+            },
+            ckpt_dir / "train_state.pt",
+        )
+
+        loader = MagicMock(spec=["load_state_dict"])
+        mgr.load(path=str(ckpt_dir), dataloader=loader)
+
+        loader.load_state_dict.assert_called_once_with(saved_state)
+        assert mgr._pending_dataloader_state is None
+
+    def test_load_no_stash_when_no_dataloader_key(self, tmp_path, monkeypatch):
+        """Missing dataloader key in train_state leaves pending state empty."""
+        mgr = _make_mock_mgr(tmp_path, monkeypatch)
+        ckpt_dir = tmp_path / "step_1"
+        ckpt_dir.mkdir()
+        torch.save(
+            {"step": 1, "tokens_seen": 64, "rng": get_rng_state()},
+            ckpt_dir / "train_state.pt",
+        )
+
+        mgr.load(path=str(ckpt_dir))
+        assert mgr._pending_dataloader_state is None
+
+    def test_round_trip_save_load_apply(self, tmp_path, monkeypatch):
+        """Save with loader, load without loader, apply to new loader: state flows through."""
+        mgr = _make_mock_mgr(tmp_path, monkeypatch)
+
+        captured: dict[str, dict] = {}
+
+        class RecorderLoader:
+            def __init__(self, initial: dict) -> None:
+                self._state = initial
+
+            def state_dict(self) -> dict:
+                return self._state
+
+            def load_state_dict(self, state: dict) -> None:
+                captured["restored"] = state
+
+        saver = RecorderLoader({"epoch": 7, "batches_yielded": 333})
+        mgr.save(step=5, tokens_seen=128, dataloader=saver)
+
+        # Simulate a fresh process: build a new manager and load without loader.
+        mgr2 = _make_mock_mgr(tmp_path, monkeypatch)
+        step, tokens, _ = mgr2.load(path=str(tmp_path / "step_5"))
+        assert step == 5
+        assert tokens == 128
+        assert mgr2._pending_dataloader_state == {"epoch": 7, "batches_yielded": 333}
+
+        # Build loader after load() and apply the stashed state.
+        restorer = RecorderLoader({"epoch": 0, "batches_yielded": 0})
+        mgr2.apply_dataloader_state(restorer)
+        assert captured["restored"] == {"epoch": 7, "batches_yielded": 333}
+        assert mgr2._pending_dataloader_state is None
