99from modalities .dataloader .filter_packed_data import filter_dataset
1010
1111
12- def test_creates_output_file (tmp_path : Path , packed_data_paths : Path ):
12+ def test_creates_output_file (tmp_path : Path , packed_data_path : Path ):
1313 output_path = Path (tmp_path , "output.pbin" )
1414 filter_dataset (
15- src_path = packed_data_paths , dst_path = output_path , filter_func = accept_even_indices , sample_key = "input_ids"
15+ src_path = packed_data_path , dst_path = output_path , filter_func = accept_even_indices , sample_key = "input_ids"
1616 )
1717 assert output_path .exists ()
1818
1919
20- def test_filtered_data_has_expected_length (tmp_path : Path , packed_data_paths : Path ):
20+ def test_filtered_data_has_expected_length (tmp_path : Path , packed_data_path : Path ):
2121 output_path = Path (tmp_path , "output.pbin" )
2222 filter_dataset (
23- src_path = packed_data_paths , dst_path = output_path , filter_func = accept_even_indices , sample_key = "input_ids"
23+ src_path = packed_data_path , dst_path = output_path , filter_func = accept_even_indices , sample_key = "input_ids"
2424 )
25- original_data = PackedMemMapDatasetBase (packed_data_paths , sample_key = "input_ids" )
25+ original_data = PackedMemMapDatasetBase (packed_data_path , sample_key = "input_ids" )
2626 filtered_data = PackedMemMapDatasetBase (output_path , sample_key = "input_ids" )
2727 assert (
2828 len (filtered_data ) == len (original_data ) // 2 + len (original_data ) % 2
@@ -39,22 +39,20 @@ def test_filtered_data_has_expected_content(tmp_path: Path, dummy_packed_data_pa
3939 assert filtered_data [1 ]["input_ids" ].tolist () == list (range (64 // 4 , (64 + 12 ) // 4 ))
4040
4141
42- def test_always_true_filtered_data_has_identical_file_hash (tmp_path : Path , packed_data_paths : Path ):
42+ def test_always_true_filtered_data_has_identical_file_hash (tmp_path : Path , packed_data_path : Path ):
4343 output_path = Path (tmp_path , "output.pbin" )
44- filter_dataset (src_path = packed_data_paths , dst_path = output_path , filter_func = lambda x : True , sample_key = "input_ids" )
45- with open (packed_data_paths , "rb" ) as f_in , open (output_path , "rb" ) as f_out :
44+ filter_dataset (src_path = packed_data_path , dst_path = output_path , filter_func = lambda x : True , sample_key = "input_ids" )
45+ with open (packed_data_path , "rb" ) as f_in , open (output_path , "rb" ) as f_out :
4646 original_hash = hashlib .sha256 (f_in .read ()).hexdigest ()
4747 filtered_hash = hashlib .sha256 (f_out .read ()).hexdigest ()
4848 assert (
4949 original_hash == filtered_hash
5050 ), "Filtered data should have the same hash as the original data when no filtering is applied."
5151
5252
53- def test_always_false_filtered_data_produces_valid_file (tmp_path : Path , packed_data_paths : Path ):
53+ def test_always_false_filtered_data_produces_valid_file (tmp_path : Path , packed_data_path : Path ):
5454 output_path = Path (tmp_path , "output.pbin" )
55- filter_dataset (
56- src_path = packed_data_paths , dst_path = output_path , filter_func = lambda x : False , sample_key = "input_ids"
57- )
55+ filter_dataset (src_path = packed_data_path , dst_path = output_path , filter_func = lambda x : False , sample_key = "input_ids" )
5856 filtered_data = PackedMemMapDatasetBase (output_path , sample_key = "input_ids" )
5957 assert len (filtered_data ) == 0 , "Filtered data should be empty when all samples are filtered out."
6058 assert output_path .stat ().st_size > 0 , "Output file should not be empty even if no samples are included."
@@ -74,6 +72,6 @@ def accept_even_indices(idx_content: tuple[int, dict[str, NDArray[np.int_]]]) ->
7472
7573
7674@pytest .fixture (params = [0 , 1 ])
77- def packed_data_paths (dummy_packed_data_path : Path , request : pytest .FixtureRequest ) -> Path :
78- path_options = [dummy_packed_data_path , Path ("tests" , "data" , "datasets" , "lorem_ipsum_long.pbin" )]
75+ def packed_data_path (dummy_packed_data_path : Path , request : pytest .FixtureRequest ) -> Path :
76+ path_options = [dummy_packed_data_path , Path ("tests/data/datasets/lorem_ipsum_long.pbin" )]
7977 return path_options [request .param ]
0 commit comments