1111
def test_creates_output_file(tmp_path: Path, packed_data_path: Path):
    """Check that filtering a packed dataset writes a file at the destination path.

    Args:
        tmp_path: Temporary directory in which the output file is created.
        packed_data_path: Path of the packed source dataset to filter.
    """
    output_path = Path(tmp_path, "output.pbin")
    filter_dataset(src_path=packed_data_path, dst_path=output_path, filter_func=accept_even_indices)
    assert output_path.exists()
1816
1917
2018def test_filtered_data_has_expected_length (tmp_path : Path , packed_data_path : Path ):
2119 output_path = Path (tmp_path , "output.pbin" )
22- filter_dataset (
23- src_path = packed_data_path , dst_path = output_path , filter_func = accept_even_indices , sample_key = "input_ids"
24- )
20+ filter_dataset (src_path = packed_data_path , dst_path = output_path , filter_func = accept_even_indices )
2521 original_data = PackedMemMapDatasetBase (packed_data_path , sample_key = "input_ids" )
2622 filtered_data = PackedMemMapDatasetBase (output_path , sample_key = "input_ids" )
2723 assert (
@@ -31,17 +27,15 @@ def test_filtered_data_has_expected_length(tmp_path: Path, packed_data_path: Pat
3127
def test_filtered_data_has_expected_content(tmp_path: Path, dummy_packed_data_path: Path):
    """Check the token content of the first two samples that survive filtering.

    Args:
        tmp_path: Temporary directory in which the output file is created.
        dummy_packed_data_path: Path of the dummy packed dataset to filter.
    """
    output_path = Path(tmp_path, "output.pbin")
    filter_dataset(src_path=dummy_packed_data_path, dst_path=output_path, filter_func=accept_even_indices)
    filtered_data = PackedMemMapDatasetBase(output_path, sample_key="input_ids")
    # NOTE(review): 24, 64 and 12 look like byte offsets/lengths of the dummy
    # samples and 4 like the bytes-per-token size -- confirm against the
    # fixture that builds dummy_packed_data_path.
    assert filtered_data[0]["input_ids"].tolist() == list(range(24 // 4))
    assert filtered_data[1]["input_ids"].tolist() == list(range(64 // 4, (64 + 12) // 4))
4034
4135
4236def test_always_true_filtered_data_has_identical_file_hash (tmp_path : Path , packed_data_path : Path ):
4337 output_path = Path (tmp_path , "output.pbin" )
44- filter_dataset (src_path = packed_data_path , dst_path = output_path , filter_func = lambda x : True , sample_key = "input_ids" )
38+ filter_dataset (src_path = packed_data_path , dst_path = output_path , filter_func = lambda x : True )
4539 with open (packed_data_path , "rb" ) as f_in , open (output_path , "rb" ) as f_out :
4640 original_hash = hashlib .sha256 (f_in .read ()).hexdigest ()
4741 filtered_hash = hashlib .sha256 (f_out .read ()).hexdigest ()
@@ -52,7 +46,7 @@ def test_always_true_filtered_data_has_identical_file_hash(tmp_path: Path, packe
5246
def test_always_false_filtered_data_produces_valid_file(tmp_path: Path, packed_data_path: Path):
    """Check that rejecting every sample still yields a loadable, non-empty file.

    Args:
        tmp_path: Temporary directory in which the output file is created.
        packed_data_path: Path of the packed source dataset to filter.
    """
    output_path = Path(tmp_path, "output.pbin")
    # The filter rejects every sample, so no sample should reach the output.
    filter_dataset(src_path=packed_data_path, dst_path=output_path, filter_func=lambda x: False)
    filtered_data = PackedMemMapDatasetBase(output_path, sample_key="input_ids")
    assert len(filtered_data) == 0, "Filtered data should be empty when all samples are filtered out."
    assert output_path.stat().st_size > 0, "Output file should not be empty even if no samples are included."
0 commit comments