From 2b16d087d327f210b8685943c554ef2873493524 Mon Sep 17 00:00:00 2001 From: Francesco Bertolotti Date: Tue, 30 Jun 2026 15:12:24 +0200 Subject: [PATCH 1/2] re-sync state_dict after resume --- src/datasets/iterable_dataset.py | 4 ++++ tests/test_iterable_dataset.py | 36 ++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 92fdea2ad4d..17b9a2020fc 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -2712,6 +2712,8 @@ def _iter_pytorch(self): } if self._starting_state_dict and self.epoch == self._starting_state_dict["epoch"]: ex_iterable.load_state_dict(self._starting_state_dict["examples_iterable"]) + # re-point at the live ex_iterable state so progress tracking + self._state_dict["examples_iterable"] = ex_iterable._state_dict if self._formatting and (ex_iterable.iter_arrow or self._formatting.is_table): formatter = get_formatter(self._formatting.format_type, features=self.features) @@ -2796,6 +2798,8 @@ def _prepare_ex_iterable_for_iteration( } if self._starting_state_dict and self.epoch == self._starting_state_dict["epoch"]: ex_iterable.load_state_dict(self._starting_state_dict["examples_iterable"]) + # re-point at the live ex_iterable state so progress tracking + self._state_dict["examples_iterable"] = ex_iterable._state_dict return ex_iterable def __iter__(self): diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 3fd22d63809..925a3436c86 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -2833,6 +2833,42 @@ def test_resume_dataloader(dataset: IterableDataset): assert remaining == list(dl) +@require_torchdata_stateful_dataloader +def test_resume_dataloader_twice(dataset: IterableDataset): + # Regression test: a checkpoint taken from an already-resumed dataloader must itself + # resume correctly. Previously, once load_state_dict() had been called the dataset's + # state_dict froze at the initial position (the example-iterable rebound its _state_dict + # and detached from IterableDataset._state_dict["examples_iterable"]), so the second + # checkpoint restarted iteration from the beginning. + from torchdata.stateful_dataloader import StatefulDataLoader + + all_examples = list(StatefulDataLoader(dataset)) + assert len(all_examples) >= 6 + + # checkpoint #1 after consuming 2 examples + dl = StatefulDataLoader(dataset) + for i, _ in enumerate(dl): + if i == 1: + state_1 = dl.state_dict() + break + + # resume from #1, consume 2 more, then checkpoint #2 (taken from a resumed loader) + dl = StatefulDataLoader(dataset) + dl.load_state_dict(state_1) + resumed_after_1 = [] + for i, x in enumerate(dl): + resumed_after_1.append(x) + if i == 1: + state_2 = dl.state_dict() + break + assert resumed_after_1 == all_examples[2:4] + + # resume from #2: must continue from example 4, not restart from the beginning + dl = StatefulDataLoader(dataset) + dl.load_state_dict(state_2) + assert list(dl) == all_examples[4:] + + @pytest.mark.parametrize("num_shards", [1, 2, 3, 7]) def test_iterable_dataset_batch(num_shards: int): # Create a simple IterableDataset From 1e70aa0a0d8e8d4483ee9d8b6ea12879f3b2d746 Mon Sep 17 00:00:00 2001 From: Francesco Bertolotti Date: Tue, 30 Jun 2026 17:18:19 +0200 Subject: [PATCH 2/2] num workers test --- tests/test_iterable_dataset.py | 47 ++++++++++++++++------------------ 1 file changed, 22 insertions(+), 25 deletions(-) diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 925a3436c86..3c9a0ee32a6 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -2834,39 +2834,36 @@ def test_resume_dataloader(dataset: IterableDataset): @require_torchdata_stateful_dataloader -def test_resume_dataloader_twice(dataset: IterableDataset): - # Regression test: a checkpoint taken from an already-resumed dataloader must itself - # resume correctly. Previously, once load_state_dict() had been called the dataset's - # state_dict froze at the initial position (the example-iterable rebound its _state_dict - # and detached from IterableDataset._state_dict["examples_iterable"]), so the second - # checkpoint restarted iteration from the beginning. +@pytest.mark.parametrize("num_workers", [0, 1, 2]) +def test_resume_dataloader_twice(num_workers): from torchdata.stateful_dataloader import StatefulDataLoader - all_examples = list(StatefulDataLoader(dataset)) - assert len(all_examples) >= 6 + ex_iterable = ExamplesIterable(generate_examples_fn, {"filepaths": [f"file{i}.txt" for i in range(4)]}) + dataset = IterableDataset(ex_iterable) - # checkpoint #1 after consuming 2 examples - dl = StatefulDataLoader(dataset) - for i, _ in enumerate(dl): - if i == 1: - state_1 = dl.state_dict() - break + def make_dataloader(): + return StatefulDataLoader(dataset, batch_size=None, num_workers=num_workers) + + all_examples = list(make_dataloader()) + + # consume 2 examples, then checkpoint #1 + dl = make_dataloader() + it = iter(dl) + consumed = [next(it) for _ in range(2)] + state_1 = dl.state_dict() # resume from #1, consume 2 more, then checkpoint #2 (taken from a resumed loader) - dl = StatefulDataLoader(dataset) + dl = make_dataloader() dl.load_state_dict(state_1) - resumed_after_1 = [] - for i, x in enumerate(dl): - resumed_after_1.append(x) - if i == 1: - state_2 = dl.state_dict() - break - assert resumed_after_1 == all_examples[2:4] + it = iter(dl) + consumed += [next(it) for _ in range(2)] + state_2 = dl.state_dict() - # resume from #2: must continue from example 4, not restart from the beginning - dl = StatefulDataLoader(dataset) + # resuming from #2 must continue from where it left off, not restart from the beginning + dl = make_dataloader() dl.load_state_dict(state_2) - assert list(dl) == all_examples[4:] + remainder = list(dl) + assert consumed + remainder == all_examples @pytest.mark.parametrize("num_shards", [1, 2, 3, 7])