Skip to content

Bug Fix: Resuming Twice Resets the Dataloader#8295

Open
francesco-bertolotti wants to merge 2 commits into
huggingface:mainfrom
francesco-bertolotti:f14-resume-fix
Open

Bug Fix: Resuming Twice Resets the Dataloader#8295
francesco-bertolotti wants to merge 2 commits into
huggingface:mainfrom
francesco-bertolotti:f14-resume-fix

Conversation

@francesco-bertolotti

Copy link
Copy Markdown

Summary

For a streaming IterableDataset, calling load_state_dict() and then continuing to iterate leaves state_dict() at the initial position (shard_idx=0, shard_example_idx=0). Reading resumes from the correct place, but the the state never advances again.

The practical consequence shows up on the second consecutive resume of a training run:

  1. An original run tracks state correctly and writes a healthy checkpoint.
  2. Resuming from it reads correctly but every checkpoint writes a stale zero-initialized state_dict
  3. Resuming from that checkpoint restarts the data stream from the very beginning.

In our case, a torchtitan resumed-twice run over FineWeb produced a different training loss compared to the resume-once run.

Root cause

The resume call sites in IterableDataset (_iter_pytorch for num_workers>=1 and _prepare_ex_iterable_for_iteration for num_workers=0) build self._state_dict so that self._state_dict["examples_iterable"] references the example-iterable's own state dict, and then load the resume state into that example-iterable. _BaseExamplesIterable.load_state_dict calls self._init_state_dict(), which rebinds self._state_dict to a brand-new object. After the load, dataset._state_dict["examples_iterable"] still points at the stale zero-initialized state dict. During iteration the example-iterable mutates its state dict, while dataset.state_dict() keeps deep-copying the stale one, so the reported position is frozen at the zero.

On a fresh run the load_state_dict branch is skipped, the reference stays linked, and tracking works, which is why the bug only surfaces after a resume.

The fix

Re-point self._state_dict["examples_iterable"] at the live example-iterable state immediately after the resume load, at both IterableDataset resume sites:

if self._starting_state_dict and self.epoch == self._starting_state_dict["epoch"]:
    ex_iterable.load_state_dict(self._starting_state_dict["examples_iterable"])
    self._state_dict["examples_iterable"] = ex_iterable._state_dict   # re-link after the rebind

Reproduction

import datasets
def gen():
    for i in range(10): yield {"id": i}
ds = datasets.IterableDataset.from_generator(gen)
full = [x["id"] for x in ds]
it = iter(ds); [next(it) for _ in range(3)]; sd1 = ds.state_dict()
ds.load_state_dict(sd1); it = iter(ds); seen = [next(it)["id"] for _ in range(3)]; sd2 = ds.state_dict()
ds.load_state_dict(sd2); after = [x["id"] for x in ds]
print("reference       :", full)
print("after 1st resume:", seen, "(reads correctly)")
print("after 2nd resume:", after)
print(">>> BUG: 2nd resume restarted from the beginning" if after == full else ">>> OK: continued")

Tests

  • Added tests/test_iterable_dataset.py::test_resume_dataloader_twice, a test that takes a checkpoint, resumes from it, takes a second checkpoint from the resumed loader, and asserts the second checkpoint continues iteration from the right place instead of restarting. It fails on main and passes with this fix.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant