StatefulDataLoader misaligns the data order on the second (and later)
save/load cycles within a single epoch:
- __iter__ sets self._batches_yielded = 0 on every call.
- load_state_dict reads batches_yielded into a local variable and never writes it back onto self.
- The sampler exposes a single-shot set_skip that is consumed the first time iter() advances.
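To make the three behaviors concrete, here is a minimal toy model. ToySampler, BuggyLoader and take are hypothetical stand-ins written for illustration, not StatefulDataLoader's real code. The sampler yields indices in order (no shuffling), and the dataset size is assumed to be a multiple of the batch size.

```python
class ToySampler:
    """Hypothetical sampler over range(n) with a single-shot skip."""

    def __init__(self, n):
        self.n = n
        self._skip = 0

    def set_skip(self, k):
        self._skip = k

    def __iter__(self):
        # Generator body runs on the first next(): the skip is read there
        # and cleared, so it applies to exactly one iteration.
        start, self._skip = self._skip, 0
        yield from range(start, self.n)


class BuggyLoader:
    """Hypothetical loader reproducing the pre-fix counter handling."""

    def __init__(self, sampler, batch_size):
        self.sampler = sampler
        self.batch_size = batch_size
        self._batches_yielded = 0

    def state_dict(self):
        return {"batches_yielded": self._batches_yielded}

    def load_state_dict(self, state):
        batches_yielded = state["batches_yielded"]  # local only, never stored on self
        self.sampler.set_skip(batches_yielded * self.batch_size)

    def __iter__(self):
        self._batches_yielded = 0  # zeroed on every iter() call
        return self._batches()

    def _batches(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                self._batches_yielded += 1
                yield batch
                batch = []


def take(loader, n_batches):
    """Pull n_batches from a fresh iterator, as one run before preemption would."""
    it = iter(loader)
    return [next(it) for _ in range(n_batches)]


N, BS = 100, 2
run1 = BuggyLoader(ToySampler(N), BS)
take(run1, 30)
state = run1.state_dict()  # {"batches_yielded": 30}

run2 = BuggyLoader(ToySampler(N), BS)
run2.load_state_dict(state)
print(take(run2, 10)[0])   # [60, 61]: the first resume lands correctly
state = run2.state_dict()  # {"batches_yielded": 10}, not 40

run3 = BuggyLoader(ToySampler(N), BS)
run3.load_state_dict(state)
print(take(run3, 1)[0])    # [20, 21]: batch 10 instead of batch 40
```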
Sequence that breaks:
- Run 1 yields 30 batches; state_dict() records batches_yielded=30.
- Run 2: load_state_dict applies sampler.set_skip(30*bs). Then iter() is called, which consumes the sampler skip and resets _batches_yielded to 0.
- Run 2 yields 10 more batches. state_dict() records batches_yielded=10, not 40. The sampler state also reflects only the 10-batch delta because it was re-seeded at iter().
- Run 3: load_state_dict applies sampler.set_skip(10*bs), so the loader skips 10 batches instead of 40.
On busy clusters with multiple preemptions per epoch, this silently shifts
the data order on every resume after the first.
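The drift follows from the bookkeeping alone. This sketch tracks only the recorded counter; resume_skips is a hypothetical helper that compares a counter zeroed at iter() with one that accumulates across resumes, and returns the number of batches each resumed run would skip.

```python
def resume_skips(batches_per_run, reset_on_iter):
    """Batches skipped by each resumed run (every run after the first).

    batches_per_run: batches each run yields before it is preempted.
    reset_on_iter: True models the pre-fix counter (zeroed at iter()),
    False models a counter that accumulates across resumes.
    """
    skips = []
    recorded = 0  # batches_yielded as written by the previous state_dict()
    for run, n in enumerate(batches_per_run):
        if run > 0:
            skips.append(recorded)  # load_state_dict skips this many batches
        recorded = n if reset_on_iter else recorded + n
    return skips


print(resume_skips([30, 10, 10], reset_on_iter=True))   # [30, 10]
print(resume_skips([30, 10, 10], reset_on_iter=False))  # [30, 40]
```

Only the first resume agrees; every later one skips by the last run's delta instead of the epoch position.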
Fix
- load_state_dict writes batches_yielded to self._batches_yielded.
- __iter__ re-applies sampler.set_skip(self._batches_yielded * self.batch_size) and no longer zeros the counter.
- _batches_yielded now advances monotonically within the epoch and resets only when StopIteration fires.
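A sketch of the fixed control flow, again on a hypothetical toy sampler and loader rather than the real classes. It persists the counter in load_state_dict, re-applies the single-shot skip in __iter__ without zeroing, and resets the counter only after the sampler is exhausted. The toy drops a trailing partial batch, and the usage lines replay a 30 -> 10 -> 10 resume against one continuous pass.

```python
class ToySampler:
    """Hypothetical sampler over range(n) with a single-shot skip."""

    def __init__(self, n):
        self.n = n
        self._skip = 0

    def set_skip(self, k):
        self._skip = k

    def __iter__(self):
        # Skip is read and cleared on the first next().
        start, self._skip = self._skip, 0
        yield from range(start, self.n)


class FixedLoader:
    """Hypothetical loader with the post-fix counter handling."""

    def __init__(self, sampler, batch_size):
        self.sampler = sampler
        self.batch_size = batch_size
        self._batches_yielded = 0

    def state_dict(self):
        return {"batches_yielded": self._batches_yielded}

    def load_state_dict(self, state):
        # Persist the restored position on self, not in a local variable.
        self._batches_yielded = state["batches_yielded"]
        self.sampler.set_skip(self._batches_yielded * self.batch_size)

    def __iter__(self):
        # Re-apply the skip from the persisted counter; no reset here.
        self.sampler.set_skip(self._batches_yielded * self.batch_size)
        return self._batches()

    def _batches(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                self._batches_yielded += 1
                yield batch
                batch = []
        # Sampler exhausted: end of epoch, reset just before StopIteration.
        self._batches_yielded = 0


def take(loader, n_batches):
    """Pull n_batches from a fresh iterator, as one run before preemption would."""
    it = iter(loader)
    return [next(it) for _ in range(n_batches)]


N, BS = 100, 2
continuous = list(FixedLoader(ToySampler(N), BS))  # one uninterrupted epoch

resumed, state = [], None
for n_batches in (30, 10, 10):  # two preemptions within the epoch
    loader = FixedLoader(ToySampler(N), BS)
    if state is not None:
        loader.load_state_dict(state)
    resumed += take(loader, n_batches)
    state = loader.state_dict()  # 30, then 40, then 50

assert resumed == continuous  # batch-for-batch match
```

Resetting inside the generator, after the loop, ties the reset to the epoch actually ending rather than to whenever iter() happens to be called.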
Coverage
tests/integration/test_data_resumption.py::test_double_resume_within_same_epoch runs a three-way resume (30 -> 40 -> ground-truth batches 40..49) and asserts that the batches yielded after the second resume match a single continuous run batch-for-batch.