Skip to content

Commit 9eff5cc

Browse files
authored
feat: add shuffle parameter to train_test_split (Lightning-AI#675)
* feat: add shuffle parameter to train_test_split
* fixup! feat: add shuffle parameter to train_test_split
1 parent f739826 commit 9eff5cc

2 files changed

Lines changed: 37 additions & 4 deletions

File tree

src/litdata/utilities/train_test_split.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212

1313

1414
def train_test_split(
15-
streaming_dataset: StreamingDataset, splits: list[float], seed: int = 42
15+
streaming_dataset: StreamingDataset, splits: list[float], seed: int = 42, shuffle: bool = True
1616
) -> list[StreamingDataset]:
1717
"""Splits a StreamingDataset into multiple subsets for training, testing, and validation.
1818
@@ -24,6 +24,7 @@ def train_test_split(
2424
splits: A list of floats representing the proportion of data to be allocated to each split
2525
(e.g., [0.8, 0.1, 0.1] for 80% training, 10% testing, and 10% validation).
2626
seed: An integer used to seed the random number generator for reproducibility.
27+
shuffle: A boolean indicating whether to shuffle the data before splitting.
2728
2829
Returns:
2930
List[StreamingDataset]: A list of StreamingDataset instances, where each element represents a split of the
@@ -71,9 +72,10 @@ def train_test_split(
7172

7273
dataset_length = sum([my_roi[1] - my_roi[0] for my_roi in dummy_subsampled_roi])
7374

74-
subsampled_chunks, dummy_subsampled_roi = shuffle_lists_together(
75-
subsampled_chunks, dummy_subsampled_roi, np.random.RandomState([seed])
76-
)
75+
if shuffle:
76+
subsampled_chunks, dummy_subsampled_roi = shuffle_lists_together(
77+
subsampled_chunks, dummy_subsampled_roi, np.random.RandomState([seed])
78+
)
7779

7880
item_count_list = [int(dataset_length * split) for split in splits]
7981

tests/utilities/test_train_test_split.py

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -108,3 +108,34 @@ def test_train_test_split_with_streaming_dataloader(tmpdir, compression):
108108
for curr_idx in _dl:
109109
assert curr_idx not in visited_indices
110110
visited_indices.add(curr_idx)
111+
112+
113+
@pytest.mark.parametrize(
114+
"compression",
115+
[
116+
pytest.param(None),
117+
pytest.param("zstd", marks=pytest.mark.skipif(condition=not _ZSTD_AVAILABLE, reason="Requires: ['zstd']")),
118+
],
119+
)
120+
def test_train_test_split_with_shuffle_parameter(tmpdir, compression):
121+
cache = Cache(str(tmpdir), chunk_size=10, compression=compression)
122+
for i in range(100):
123+
cache[i] = i
124+
cache.done()
125+
cache.merge()
126+
127+
my_streaming_dataset = StreamingDataset(input_dir=str(tmpdir))
128+
129+
train_shuffled, test_shuffled = train_test_split(my_streaming_dataset, splits=[0.8, 0.2], shuffle=True)
130+
train_no_shuffle, test_no_shuffle = train_test_split(my_streaming_dataset, splits=[0.8, 0.2], shuffle=False)
131+
132+
assert len(train_shuffled) == 80
133+
assert len(train_no_shuffle) == 80
134+
assert len(test_shuffled) == 20
135+
assert len(test_no_shuffle) == 20
136+
137+
shuffled_combined = train_shuffled.subsampled_files + test_shuffled.subsampled_files
138+
no_shuffle_combined = train_no_shuffle.subsampled_files + test_no_shuffle.subsampled_files
139+
assert shuffled_combined != no_shuffle_combined
140+
141+
assert no_shuffle_combined == my_streaming_dataset.subsampled_files

0 commit comments

Comments
 (0)