Skip to content

Commit 56cc2a6

Browse files
committed
test: Adapted filtering tests to recent changes.
1 parent b38bda7 commit 56cc2a6

1 file changed

Lines changed: 5 additions & 11 deletions

File tree

tests/dataloader/test_filter_packed_data.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,13 @@
1111

1212
def test_creates_output_file(tmp_path: Path, packed_data_path: Path):
1313
output_path = Path(tmp_path, "output.pbin")
14-
filter_dataset(
15-
src_path=packed_data_path, dst_path=output_path, filter_func=accept_even_indices, sample_key="input_ids"
16-
)
14+
filter_dataset(src_path=packed_data_path, dst_path=output_path, filter_func=accept_even_indices)
1715
assert output_path.exists()
1816

1917

2018
def test_filtered_data_has_expected_length(tmp_path: Path, packed_data_path: Path):
2119
output_path = Path(tmp_path, "output.pbin")
22-
filter_dataset(
23-
src_path=packed_data_path, dst_path=output_path, filter_func=accept_even_indices, sample_key="input_ids"
24-
)
20+
filter_dataset(src_path=packed_data_path, dst_path=output_path, filter_func=accept_even_indices)
2521
original_data = PackedMemMapDatasetBase(packed_data_path, sample_key="input_ids")
2622
filtered_data = PackedMemMapDatasetBase(output_path, sample_key="input_ids")
2723
assert (
@@ -31,17 +27,15 @@ def test_filtered_data_has_expected_length(tmp_path: Path, packed_data_path: Pat
3127

3228
def test_filtered_data_has_expected_content(tmp_path: Path, dummy_packed_data_path: Path):
3329
output_path = Path(tmp_path, "output.pbin")
34-
filter_dataset(
35-
src_path=dummy_packed_data_path, dst_path=output_path, filter_func=accept_even_indices, sample_key="input_ids"
36-
)
30+
filter_dataset(src_path=dummy_packed_data_path, dst_path=output_path, filter_func=accept_even_indices)
3731
filtered_data = PackedMemMapDatasetBase(output_path, sample_key="input_ids")
3832
assert filtered_data[0]["input_ids"].tolist() == list(range(24 // 4))
3933
assert filtered_data[1]["input_ids"].tolist() == list(range(64 // 4, (64 + 12) // 4))
4034

4135

4236
def test_always_true_filtered_data_has_identical_file_hash(tmp_path: Path, packed_data_path: Path):
4337
output_path = Path(tmp_path, "output.pbin")
44-
filter_dataset(src_path=packed_data_path, dst_path=output_path, filter_func=lambda x: True, sample_key="input_ids")
38+
filter_dataset(src_path=packed_data_path, dst_path=output_path, filter_func=lambda x: True)
4539
with open(packed_data_path, "rb") as f_in, open(output_path, "rb") as f_out:
4640
original_hash = hashlib.sha256(f_in.read()).hexdigest()
4741
filtered_hash = hashlib.sha256(f_out.read()).hexdigest()
@@ -52,7 +46,7 @@ def test_always_true_filtered_data_has_identical_file_hash(tmp_path: Path, packe
5246

5347
def test_always_false_filtered_data_produces_valid_file(tmp_path: Path, packed_data_path: Path):
5448
output_path = Path(tmp_path, "output.pbin")
55-
filter_dataset(src_path=packed_data_path, dst_path=output_path, filter_func=lambda x: False, sample_key="input_ids")
49+
filter_dataset(src_path=packed_data_path, dst_path=output_path, filter_func=lambda x: False)
5650
filtered_data = PackedMemMapDatasetBase(output_path, sample_key="input_ids")
5751
assert len(filtered_data) == 0, "Filtered data should be empty when all samples are filtered out."
5852
assert output_path.stat().st_size > 0, "Output file should not be empty even if no samples are included."

0 commit comments

Comments
 (0)