@@ -254,3 +254,174 @@ def test_save_waits_for_previous(self, monkeypatch, tmp_path):
254254 ckpt .save ({"model" : {}}, checkpoint_id = str (tmp_path / "step_2" ))
255255 # First future should have been waited on before second save
256256 mock_future1 .result .assert_called_once ()
257+
258+
259+ # ---------------------------------------------------------------------------
260+ # Dataloader state persistence (two-phase apply)
261+ # ---------------------------------------------------------------------------
262+
263+
def _make_mock_mgr(tmp_path, monkeypatch):
    """Return a CheckpointManager rooted at ``tmp_path`` with DCP I/O mocked.

    The manager wraps a tiny ``Linear(4, 4)`` model and an SGD optimizer.
    Its async checkpointer's ``save`` and the module-level ``dcp.load`` are
    replaced with ``MagicMock`` objects, so no distributed setup is needed.
    """
    from unittest.mock import MagicMock

    from kempnerforge.checkpoint.manager import CheckpointManager

    linear = torch.nn.Linear(4, 4)
    sgd = torch.optim.SGD(linear.parameters(), lr=0.1)
    manager = CheckpointManager(
        CheckpointConfig(dir=str(tmp_path), keep_last_n=5),
        linear,
        sgd,
    )

    # Swap both DCP entry points for mocks; monkeypatch undoes this at teardown.
    monkeypatch.setattr(manager._async_ckpt, "save", MagicMock())
    monkeypatch.setattr("kempnerforge.checkpoint.manager.dcp.load", MagicMock())
    return manager
277+
278+
class TestDataloaderStatePersistence:
    """Dataloader state must survive the save -> load -> apply sequence.

    Training loops call load() before they construct the dataloader, because
    building the loader needs the phase/annealing state that load() restores.
    load() therefore stashes the dataloader state on the manager, and
    apply_dataloader_state() later hands it to the freshly built loader.
    """

    @staticmethod
    def _write_train_state(ckpt_dir, dataloader_state=None):
        """Create ``ckpt_dir`` and write a minimal ``train_state.pt`` into it.

        The "dataloader" entry is only written when ``dataloader_state`` is
        given, which lets tests model checkpoints without that key.
        """
        ckpt_dir.mkdir()
        payload = {"step": 1, "tokens_seen": 64, "rng": get_rng_state()}
        if dataloader_state is not None:
            payload["dataloader"] = dataloader_state
        torch.save(payload, ckpt_dir / "train_state.pt")

    def test_apply_no_op_when_nothing_pending(self, tmp_path, monkeypatch):
        """With nothing stashed, apply must leave the loader untouched."""
        from unittest.mock import MagicMock

        manager = _make_mock_mgr(tmp_path, monkeypatch)
        fake_loader = MagicMock(spec=["load_state_dict"])

        manager.apply_dataloader_state(fake_loader)

        fake_loader.load_state_dict.assert_not_called()

    def test_apply_restores_state_to_stateful_loader(self, tmp_path, monkeypatch):
        """A stashed state is handed to the loader and then cleared."""
        from unittest.mock import MagicMock

        manager = _make_mock_mgr(tmp_path, monkeypatch)
        pending = {
            "epoch": 3,
            "batches_yielded": 100,
            "sampler": {"epoch": 3, "skip_samples": 0},
        }
        manager._pending_dataloader_state = pending
        fake_loader = MagicMock(spec=["load_state_dict"])

        manager.apply_dataloader_state(fake_loader)

        fake_loader.load_state_dict.assert_called_once_with(pending)
        assert manager._pending_dataloader_state is None

    def test_apply_clears_state_for_non_stateful_loader(self, tmp_path, monkeypatch):
        """Prevent the stashed state from leaking into a later (stateful) loader."""

        class PlainLoader:
            """Loader stand-in that deliberately lacks load_state_dict."""

        manager = _make_mock_mgr(tmp_path, monkeypatch)
        manager._pending_dataloader_state = {"epoch": 1}

        manager.apply_dataloader_state(PlainLoader())

        assert manager._pending_dataloader_state is None

    def test_apply_clears_state_for_none_loader(self, tmp_path, monkeypatch):
        """Passing None as the loader still drops the stashed state."""
        manager = _make_mock_mgr(tmp_path, monkeypatch)
        manager._pending_dataloader_state = {"epoch": 1}

        manager.apply_dataloader_state(None)

        assert manager._pending_dataloader_state is None

    def test_save_persists_dataloader_state(self, tmp_path, monkeypatch):
        """save() must include dataloader state when a stateful loader is passed."""

        class Loader:
            """Loader stand-in exposing only a fixed state_dict()."""

            def state_dict(self):
                return {"epoch": 4, "batches_yielded": 200}

        manager = _make_mock_mgr(tmp_path, monkeypatch)

        manager.save(step=1, tokens_seen=64, dataloader=Loader())

        train_state_file = tmp_path / "step_1" / "train_state.pt"
        on_disk = torch.load(train_state_file, weights_only=False)
        assert on_disk["dataloader"] == {"epoch": 4, "batches_yielded": 200}

    def test_load_stashes_dataloader_state_when_no_loader_provided(self, tmp_path, monkeypatch):
        """load(dataloader=None) must stash the dataloader state for later apply."""
        manager = _make_mock_mgr(tmp_path, monkeypatch)
        ckpt_dir = tmp_path / "step_1"
        expected_state = {"epoch": 2, "batches_yielded": 50}
        self._write_train_state(ckpt_dir, expected_state)

        step, tokens, _ = manager.load(path=str(ckpt_dir))

        assert step == 1
        assert tokens == 64
        assert manager._pending_dataloader_state == expected_state

    def test_load_restores_directly_when_loader_provided(self, tmp_path, monkeypatch):
        """load(dataloader=X) must restore directly and leave pending state empty."""
        from unittest.mock import MagicMock

        manager = _make_mock_mgr(tmp_path, monkeypatch)
        ckpt_dir = tmp_path / "step_1"
        expected_state = {"epoch": 2, "batches_yielded": 50}
        self._write_train_state(ckpt_dir, expected_state)
        fake_loader = MagicMock(spec=["load_state_dict"])

        manager.load(path=str(ckpt_dir), dataloader=fake_loader)

        fake_loader.load_state_dict.assert_called_once_with(expected_state)
        assert manager._pending_dataloader_state is None

    def test_load_no_stash_when_no_dataloader_key(self, tmp_path, monkeypatch):
        """Missing dataloader key in train_state leaves pending state empty."""
        manager = _make_mock_mgr(tmp_path, monkeypatch)
        ckpt_dir = tmp_path / "step_1"
        self._write_train_state(ckpt_dir)

        manager.load(path=str(ckpt_dir))

        assert manager._pending_dataloader_state is None

    def test_round_trip_save_load_apply(self, tmp_path, monkeypatch):
        """Save with loader, load without loader, apply to new loader — state flows through."""
        received: dict[str, dict] = {}

        class RecorderLoader:
            """Loader stand-in that reports a fixed state and records restores."""

            def __init__(self, initial: dict) -> None:
                self._state = initial

            def state_dict(self) -> dict:
                return self._state

            def load_state_dict(self, state: dict) -> None:
                received["restored"] = state

        writer_mgr = _make_mock_mgr(tmp_path, monkeypatch)
        source_loader = RecorderLoader({"epoch": 7, "batches_yielded": 333})
        writer_mgr.save(step=5, tokens_seen=128, dataloader=source_loader)

        # Simulate a fresh process: build a new manager and load without loader.
        reader_mgr = _make_mock_mgr(tmp_path, monkeypatch)
        step, tokens, _ = reader_mgr.load(path=str(tmp_path / "step_5"))
        assert step == 5
        assert tokens == 128
        assert reader_mgr._pending_dataloader_state == {"epoch": 7, "batches_yielded": 333}

        # The loader is only built after load(); the stashed state is applied to it.
        target_loader = RecorderLoader({"epoch": 0, "batches_yielded": 0})
        reader_mgr.apply_dataloader_state(target_loader)
        assert received["restored"] == {"epoch": 7, "batches_yielded": 333}
        assert reader_mgr._pending_dataloader_state is None
0 commit comments