Add CI tests for dataloader mid-iteration resume #51
Covers the core invariant that `OnlineTokenizedIterableDataset` +
`ParallelAwareDataLoader` must preserve: tokens yielded after resuming
from a checkpoint match those from an uninterrupted run. Parametrized
across num_workers ∈ {0, 2, 4} and rank/world configurations, plus a
multi-rank merge/restore case that mimics the per-rank key scheme used
at DCP save time. CI runs on Python 3.10 and 3.12.
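The invariant described above can be illustrated with a toy model. This is only a sketch under stated assumptions: `ToyStatefulStream`, `take`, and the `rank_{r}` key format are hypothetical stand-ins invented for illustration, not the real `OnlineTokenizedIterableDataset`, `ParallelAwareDataLoader`, or the actual key scheme used at DCP save time.

```python
from typing import Any, Dict, Iterator, List


class ToyStatefulStream:
    """Hypothetical stand-in for a checkpointable token stream."""

    def __init__(self, tokens: List[int]) -> None:
        self._tokens = tokens
        self._pos = 0  # index of the next token to yield

    def __iter__(self) -> Iterator[int]:
        while self._pos < len(self._tokens):
            tok = self._tokens[self._pos]
            self._pos += 1  # advance before yielding so state_dict() is exact
            yield tok

    def state_dict(self) -> Dict[str, Any]:
        return {"pos": self._pos}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self._pos = state["pos"]


def take(it: Iterator[int], n: int) -> List[int]:
    """Pull exactly n items from an iterator."""
    return [next(it) for _ in range(n)]


tokens = list(range(20))

# Uninterrupted reference run.
reference = list(ToyStatefulStream(tokens))

# Interrupted run: consume 7 tokens, checkpoint, resume in a fresh object.
first = ToyStatefulStream(tokens)
head = take(iter(first), 7)
checkpoint = first.state_dict()

resumed = ToyStatefulStream(tokens)
resumed.load_state_dict(checkpoint)
tail = list(resumed)

# The invariant: interrupted + resumed output equals the uninterrupted run.
assert head + tail == reference

# Multi-rank variant: each rank stores its state under its own key
# (the "rank_{r}" format is an assumption made for this sketch).
world_size = 2
shards = [tokens[r::world_size] for r in range(world_size)]
merged: Dict[str, Dict[str, Any]] = {}
for rank, shard in enumerate(shards):
    stream = ToyStatefulStream(shard)
    take(iter(stream), 3 + rank)  # ranks may stop at different offsets
    merged[f"rank_{rank}"] = stream.state_dict()

# Restore: each rank reads back only its own entry from the merged dict.
for rank, shard in enumerate(shards):
    restored = ToyStatefulStream(shard)
    restored.load_state_dict(merged[f"rank_{rank}"])
    assert list(restored) == shard[3 + rank:]
```

The real tests would presumably also vary `num_workers`, which this single-process toy does not model.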
Code Review
This pull request adds a test harness setup to stub torchtitan for CI environments and introduces a comprehensive test suite for OnlineTokenizedIterableDataset and ParallelAwareDataLoader, focusing on mid-iteration resume and state isolation. Feedback suggests catching ImportError specifically in the stubbing logic, parametrizing parallel resume tests for multiple worker counts, and optimizing iterator usage in sequence length preservation tests.
    import torchtitan.tools.logging  # noqa: F401
    import torchtitan.tools.utils  # noqa: F401
    return
except Exception:
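The review summary suggests catching `ImportError` specifically in the stubbing logic. One way this might look is sketched below; `ensure_importable` is a hypothetical helper written for illustration and is not the code in this PR. The point is that only a missing package should trigger stubbing, while any other failure during import still surfaces.

```python
import importlib
import sys
import types


def ensure_importable(module_name: str) -> bool:
    """Import module_name, or register empty stub modules if it is missing.

    Returns True when the real module was imported, False when stubs were used.
    (Hypothetical helper; the PR's actual stubbing code may differ.)
    """
    try:
        importlib.import_module(module_name)
        return True
    except ImportError:
        # ModuleNotFoundError is a subclass of ImportError, so a missing
        # package lands here. Other exceptions raised while importing the
        # real module are not caught and propagate to the caller.
        parts = module_name.split(".")
        parent = None
        for i in range(1, len(parts) + 1):
            name = ".".join(parts[:i])
            module = sys.modules.get(name)
            if module is None:
                module = types.ModuleType(name)
                module.__path__ = []  # mark as a package so submodules resolve
                sys.modules[name] = module
            if parent is not None:
                # Make attribute access like pkg.sub work on the stub.
                setattr(parent, parts[i - 1], module)
            parent = module
        return False
```

With a broad `except Exception`, a genuine bug inside an installed package (for example a `SyntaxError` or `AttributeError` raised at import time) would be silently replaced by a stub; narrowing the handler avoids that.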
    _assert_equal_sequences(before, after)


def test_parallel_aware_resume_per_rank(num_workers=0):
This test is not parametrized for num_workers, unlike other resume tests in this file. Since the PR description mentions testing across num_workers ∈ {0, 2, 4}, and multi-worker state management is a common source of bugs in dataloaders, it would be better to include these cases here as well.
@pytest.mark.parametrize("num_workers", [0, 2, 4])
def test_parallel_aware_resume_per_rank(num_workers):

for _ in range(5):
    batch = next(iter(dl2))
Calling iter(dl2) inside the loop creates a new iterator on every iteration. While StatefulDataLoader and OnlineTokenizedIterableDataset might partially handle this by updating internal state, it is inefficient and unconventional. It is better to create the iterator once and use it throughout the loop to ensure consecutive batches are checked correctly.
- for _ in range(5):
-     batch = next(iter(dl2))
+ it2 = iter(dl2)
+ for _ in range(5):
+     batch = next(it2)
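The general hazard behind this suggestion can be shown with a plain list standing in for the dataloader. This is an illustration only: as the comment notes, a stateful dataloader may track its position internally and behave differently from a list, so the sketch shows the pattern rather than the exact behavior of `dl2`.

```python
# A list is re-iterable: every iter() call starts again from the beginning.
batches = [[1, 2], [3, 4], [5, 6]]  # hypothetical stand-in for a dataloader

# Pattern flagged in the review: a fresh iterator on every step.
fresh_each_step = [next(iter(batches)) for _ in range(3)]

# Suggested pattern: create one iterator and advance it across steps.
it = iter(batches)
single_iterator = [next(it) for _ in range(3)]

# The fresh-iterator loop keeps re-reading the first batch,
# while the single iterator walks through consecutive batches.
assert fresh_each_step == [[1, 2], [1, 2], [1, 2]]
assert single_iterator == batches
```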