
StatefulDataLoader loses alignment on second resume within an epoch #57

@mmshad

Description


StatefulDataLoader misaligns the data order on the second and later
save/load cycles within a single epoch. Three behaviors combine to cause this:

  • __iter__ sets self._batches_yielded = 0 on every call.
  • load_state_dict reads batches_yielded into a local variable and
    never writes it back onto self.
  • The sampler exposes a single-shot set_skip that is consumed the first
    time iter() advances.
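The single-shot skip in the third bullet can be illustrated with a toy model. `SkippableSampler` below is a hypothetical stand-in for the project's sampler that iterates over sample indices `range(n)`. Only the `set_skip` name comes from this issue, and the real implementation may differ.

```python
class SkippableSampler:
    """Hypothetical sampler over range(n) whose skip applies to one pass only."""

    def __init__(self, n):
        self.n = n
        self._skip = 0

    def set_skip(self, num_samples):
        self._skip = num_samples

    def __iter__(self):
        # The skip is read and cleared the first time the iterator advances,
        # so any later pass starts from sample 0 again.
        start, self._skip = self._skip, 0
        yield from range(start, self.n)


sampler = SkippableSampler(10)
sampler.set_skip(4)
first_pass = list(sampler)   # [4, 5, 6, 7, 8, 9]
second_pass = list(sampler)  # [0, 1, ..., 9]: the skip is already gone
```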

Sequence that breaks:

  1. Run 1 yields 30 batches; state_dict() records batches_yielded=30.
  2. In run 2, load_state_dict applies sampler.set_skip(30*bs). The
    following iter() call consumes the sampler skip and resets
    _batches_yielded to 0.
  3. Run 2 yields 10 more batches. state_dict() records
    batches_yielded=10, not 40. The sampler state also only reflects the
    10-batch delta because it was re-seeded at iter().
  4. In run 3, load_state_dict applies sampler.set_skip(10*bs), so the
    loader skips 10 batches instead of 40.
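The four steps can be reproduced with a small simulation. Everything here is a hypothetical toy model, not the project's code: `BuggyLoader` mimics the two loader defects listed above, `run` stands in for one training job that builds a fresh loader and optionally resumes from a saved state, and the sizes (100 samples, batch size 2) are arbitrary.

```python
import itertools


class SkippableSampler:
    """Hypothetical sampler over range(n) with a single-shot skip."""

    def __init__(self, n):
        self.n, self._skip = n, 0

    def set_skip(self, num_samples):
        self._skip = num_samples

    def __iter__(self):
        start, self._skip = self._skip, 0  # consumed on first advance
        yield from range(start, self.n)


class BuggyLoader:
    """Hypothetical loader reproducing the two loader defects from the issue."""

    def __init__(self, n_samples, batch_size):
        self.batch_size = batch_size
        self.sampler = SkippableSampler(n_samples)
        self._batches_yielded = 0

    def __iter__(self):
        self._batches_yielded = 0  # defect: counter zeroed on every pass
        samples = iter(self.sampler)
        while batch := list(itertools.islice(samples, self.batch_size)):
            self._batches_yielded += 1
            yield batch

    def state_dict(self):
        return {"batches_yielded": self._batches_yielded}

    def load_state_dict(self, state):
        batches_yielded = state["batches_yielded"]  # defect: never stored on self
        self.sampler.set_skip(batches_yielded * self.batch_size)


def run(state, num_batches, n_samples=100, batch_size=2):
    """One hypothetical job: fresh loader, optional resume, consume, save."""
    loader = BuggyLoader(n_samples, batch_size)
    if state is not None:
        loader.load_state_dict(state)
    it = iter(loader)
    batches = [next(it) for _ in range(num_batches)]
    return batches, loader.state_dict()


_, state1 = run(None, 30)    # {'batches_yielded': 30}
_, state2 = run(state1, 10)  # {'batches_yielded': 10}, should be 40
batches3, _ = run(state2, 1)
print(batches3[0])           # [20, 21]: batch 10 again, not batch 40 ([80, 81])
```

Run 2 itself is correct (it starts at sample 60); the damage only shows up in run 3, which is why a single-resume test does not catch it.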

On busy clusters with multiple preemptions per epoch, this silently shifts
the data order on every resume after the first.

Fix

  • load_state_dict writes batches_yielded to self._batches_yielded.
  • __iter__ re-applies sampler.set_skip(self._batches_yielded * self.batch_size)
    and no longer zeros the counter. _batches_yielded now advances
    monotonically within the epoch and resets only when StopIteration fires.
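The two changes can be sketched on the same kind of toy model. `SkippableSampler`, `FixedLoader` and `run` are hypothetical illustrations of the described fix, not the actual patch; in this generator-based sketch the skip is re-applied when the iterator first advances.

```python
import itertools


class SkippableSampler:
    """Hypothetical sampler over range(n) with a single-shot skip."""

    def __init__(self, n):
        self.n, self._skip = n, 0

    def set_skip(self, num_samples):
        self._skip = num_samples

    def __iter__(self):
        start, self._skip = self._skip, 0  # consumed on first advance
        yield from range(start, self.n)


class FixedLoader:
    """Hypothetical loader with the two changes from the Fix section."""

    def __init__(self, n_samples, batch_size):
        self.batch_size = batch_size
        self.sampler = SkippableSampler(n_samples)
        self._batches_yielded = 0

    def __iter__(self):
        # Re-derive the skip from the counter instead of zeroing the counter.
        self.sampler.set_skip(self._batches_yielded * self.batch_size)
        samples = iter(self.sampler)
        while batch := list(itertools.islice(samples, self.batch_size)):
            self._batches_yielded += 1
            yield batch
        self._batches_yielded = 0  # epoch exhausted: the only reset point

    def state_dict(self):
        return {"batches_yielded": self._batches_yielded}

    def load_state_dict(self, state):
        self._batches_yielded = state["batches_yielded"]  # now kept on self
        self.sampler.set_skip(self._batches_yielded * self.batch_size)


def run(state, num_batches, n_samples=100, batch_size=2):
    """One hypothetical job: fresh loader, optional resume, consume, save."""
    loader = FixedLoader(n_samples, batch_size)
    if state is not None:
        loader.load_state_dict(state)
    it = iter(loader)
    batches = [next(it) for _ in range(num_batches)]
    return batches, loader.state_dict()


_, state1 = run(None, 30)
_, state2 = run(state1, 10)
batches3, _ = run(state2, 1)
print(state2, batches3[0])  # {'batches_yielded': 40} [80, 81]

# A completed epoch resets the counter, so the next epoch starts from batch 0.
loader = FixedLoader(100, 2)
num_batches = len(list(loader))
epoch_state = loader.state_dict()  # {'batches_yielded': 0}
```

In this sketch the counter is the single source of truth: the sampler's skip stays single-shot, but it is recomputed from `_batches_yielded` on every pass, so consuming it early no longer loses information.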

Coverage

tests/integration/test_data_resumption.py::test_double_resume_within_same_epoch
runs a three-way resume (stop after 30 batches, resume and stop at 40, resume
again) and asserts that the batches from the second resume (40..49) match a
single continuous run batch-for-batch.
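A self-contained sketch of what such a test might check, using hypothetical toy classes rather than the real loader; the actual contents of test_data_resumption.py are not shown in this issue. The reference is one uninterrupted pass over the epoch, and the resumed path stops at 30, then at 40, and compares the next 10 batches.

```python
import itertools


class SkippableSampler:
    """Hypothetical sampler over range(n) with a single-shot skip."""

    def __init__(self, n):
        self.n, self._skip = n, 0

    def set_skip(self, num_samples):
        self._skip = num_samples

    def __iter__(self):
        start, self._skip = self._skip, 0
        yield from range(start, self.n)


class FixedLoader:
    """Hypothetical loader with the counter persisted and the skip re-applied."""

    def __init__(self, n_samples, batch_size):
        self.batch_size = batch_size
        self.sampler = SkippableSampler(n_samples)
        self._batches_yielded = 0

    def __iter__(self):
        self.sampler.set_skip(self._batches_yielded * self.batch_size)
        samples = iter(self.sampler)
        while batch := list(itertools.islice(samples, self.batch_size)):
            self._batches_yielded += 1
            yield batch
        self._batches_yielded = 0

    def state_dict(self):
        return {"batches_yielded": self._batches_yielded}

    def load_state_dict(self, state):
        self._batches_yielded = state["batches_yielded"]
        self.sampler.set_skip(self._batches_yielded * self.batch_size)


def test_double_resume_within_same_epoch(n_samples=100, batch_size=2):
    # Ground truth: one uninterrupted pass over the epoch (50 batches).
    reference = list(FixedLoader(n_samples, batch_size))

    def resume(state, num_batches):
        loader = FixedLoader(n_samples, batch_size)
        if state is not None:
            loader.load_state_dict(state)
        it = iter(loader)
        batches = [next(it) for _ in range(num_batches)]
        return batches, loader.state_dict()

    _, state = resume(None, 30)     # run 1: batches 0..29
    _, state = resume(state, 10)    # run 2: batches 30..39
    resumed, _ = resume(state, 10)  # run 3: must be batches 40..49
    assert resumed == reference[40:50]


if __name__ == "__main__":
    test_double_resume_within_same_epoch()
```

Swapping in a loader with the original two defects makes the assertion fail, because run 3 then restarts at batch 10.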
