Commit b38bda7

refactor: Minor adjustments for review.
1 parent 002a00c commit b38bda7

1 file changed

Lines changed: 2 additions & 3 deletions

src/modalities/dataloader/filter_packed_data.py

@@ -14,20 +14,19 @@ def filter_dataset(
     src_path: Path,
     dst_path: Path,
     filter_func: Callable[[tuple[int, dict[str, NDArray[np.int_]]]], bool],
-    sample_key: str = "input_ids",
 ) -> None:
     """
     Filters the dataset based on a given filter function and writes the filtered data to the destination path.
     Args:
-        dst_path (Path): The path where the filtered dataset will be written.
         src_path (Path): The path to the source dataset to filter.
+        dst_path (Path): The path where the filtered dataset will be written.
         filter_func (Callable[[tuple[int, dict[str, NDArray[np.int_]]]], bool]):
             A function that takes a sample index and its content and returns
             True if the sample should be included, False otherwise.
-        sample_key (str): The key in the dataset samples to filter on, default is "input_ids".
     Returns:
         None
     """
+    sample_key: str = "input_ids"
     index_list: list[tuple[int, int]] = []
     source_data = PackedMemMapDatasetBase(src_path, sample_key=sample_key, load_index=True)
     with dst_path.open("wb") as f_out:
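
For reviewers, here is a minimal sketch of a callable that fits the filter_func contract described in the docstring above: it receives an (index, sample) tuple and returns True to keep the sample. The names make_min_length_filter, keep_long, Sample, and Item are hypothetical and do not appear in the repository. Plain Python lists stand in for the NumPy integer arrays so the sketch runs without extra dependencies, and it does not call filter_dataset itself.

```python
from typing import Callable, Sequence

# Hypothetical stand-ins for the (index, sample) items passed to filter_func.
# In the real code the dict values are NumPy integer arrays; lists are used
# here only so the sketch has no third-party dependencies.
Sample = dict[str, Sequence[int]]
Item = tuple[int, Sample]


def make_min_length_filter(min_tokens: int, sample_key: str = "input_ids") -> Callable[[Item], bool]:
    """Build a filter that keeps samples with at least `min_tokens` tokens."""

    def filter_func(item: Item) -> bool:
        _index, sample = item  # the index is unused by this particular filter
        return len(sample[sample_key]) >= min_tokens

    return filter_func


keep_long = make_min_length_filter(min_tokens=3)

# Three toy samples with 2, 3, and 4 tokens respectively.
items: list[Item] = [
    (0, {"input_ids": [1, 2]}),
    (1, {"input_ids": [3, 4, 5]}),
    (2, {"input_ids": [6, 7, 8, 9]}),
]

kept_indices = [index for index, sample in items if keep_long((index, sample))]
print(kept_indices)
```

Because this commit hardcodes sample_key to "input_ids" inside filter_dataset, a filter written this way would presumably read that same key from each sample.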
