diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 92fdea2ad4d..17b9a2020fc 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -2712,6 +2712,8 @@ def _iter_pytorch(self): } if self._starting_state_dict and self.epoch == self._starting_state_dict["epoch"]: ex_iterable.load_state_dict(self._starting_state_dict["examples_iterable"]) + # re-point at the live ex_iterable state so progress made after resuming is tracked in self._state_dict + self._state_dict["examples_iterable"] = ex_iterable._state_dict if self._formatting and (ex_iterable.iter_arrow or self._formatting.is_table): formatter = get_formatter(self._formatting.format_type, features=self.features) @@ -2796,6 +2798,8 @@ def _prepare_ex_iterable_for_iteration( } if self._starting_state_dict and self.epoch == self._starting_state_dict["epoch"]: ex_iterable.load_state_dict(self._starting_state_dict["examples_iterable"]) + # re-point at the live ex_iterable state so progress made after resuming is tracked in self._state_dict + self._state_dict["examples_iterable"] = ex_iterable._state_dict return ex_iterable def __iter__(self): diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 3fd22d63809..3c9a0ee32a6 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -2833,6 +2833,39 @@ def test_resume_dataloader(dataset: IterableDataset): assert remaining == list(dl) +@require_torchdata_stateful_dataloader +@pytest.mark.parametrize("num_workers", [0, 1, 2]) +def test_resume_dataloader_twice(num_workers): + from torchdata.stateful_dataloader import StatefulDataLoader + + ex_iterable = ExamplesIterable(generate_examples_fn, {"filepaths": [f"file{i}.txt" for i in range(4)]}) + dataset = IterableDataset(ex_iterable) + + def make_dataloader(): + return StatefulDataLoader(dataset, batch_size=None, num_workers=num_workers) + + all_examples = list(make_dataloader()) + + # consume 2 examples, then checkpoint #1 + dl = make_dataloader() + it = iter(dl) + consumed = [next(it) for _ in 
range(2)] + state_1 = dl.state_dict() + + # resume from #1, consume 2 more, then checkpoint #2 (taken from a resumed loader) + dl = make_dataloader() + dl.load_state_dict(state_1) + it = iter(dl) + consumed += [next(it) for _ in range(2)] + state_2 = dl.state_dict() + + # resuming from #2 must continue from where it left off, not restart from the beginning + dl = make_dataloader() + dl.load_state_dict(state_2) + remainder = list(dl) + assert consumed + remainder == all_examples + + @pytest.mark.parametrize("num_shards", [1, 2, 3, 7]) def test_iterable_dataset_batch(num_shards: int): # Create a simple IterableDataset