@@ -211,22 +211,43 @@ def _auto_preprocess_sample(
211211 )
212212
213213
214- def _load_streaming_dataset (
214+ def get_dataset_samples (
215215 dataset_name : str ,
216+ num_samples : int ,
216217 * ,
217218 apply_chat_template : bool = False ,
218219 tokenizer : "PreTrainedTokenizerBase | None" = None ,
219220 split : str | list [str ] | None = None ,
220- ) -> tuple [ list , Callable [[ dict ], str ] ]:
221- """Resolve dataset config and return streaming splits with a preprocessing function .
221+ ) -> list [ str ]:
222+ """Load a portion of a dataset with the dataset name and a given size .
222223
223- This is a shared helper for :func:`get_dataset_samples`.
224+ Supports both registered datasets (in ``SUPPORTED_DATASET_CONFIG``) and arbitrary
225+ HuggingFace datasets. Unregistered datasets are auto-detected by column names:
226+ ``messages``/``conversations`` (chat), ``prompt``, ``text``, or ``input``.
227+
228+ Args:
229+ dataset_name: Name or HuggingFace path of the dataset to load, a local directory path,
230+ or a path to a ``.jsonl`` file. For local directory paths, the
231+ predefined config from ``SUPPORTED_DATASET_CONFIG`` is matched if the base folder name
232+ matches a registered key (e.g. ``/hf-local/abisee/cnn_dailymail`` matches ``cnn_dailymail`` key).
233+ num_samples: Number of samples to load from the dataset.
234+ apply_chat_template: Whether to apply the chat template to the samples
235+ (if supported by the dataset). For unregistered datasets with a
236+ ``messages`` column, chat template is always applied regardless of
237+ this flag.
238+ tokenizer: Tokenizer to use for applying the chat template to the samples.
239+ No tokenization is done and plain text is still returned.
240+ split: Override the split(s) to load. Accepts a single split name or a list.
241+ If ``None``, uses the splits defined in ``SUPPORTED_DATASET_CONFIG`` for
242+ registered datasets, or ``["train"]`` for unregistered datasets.
224243
225244 Returns:
226- A tuple of ``(dataset_splits, preprocess_fn)`` where *dataset_splits* is a list of
227- HuggingFace ``IterableDataset`` objects and *preprocess_fn* maps a raw sample dict
228- to a plain-text string (empty string signals a sample to skip).
245+ Samples: The list of samples.
229246 """
247+ # Local JSONL file path support (each line is a JSON object with a `text` field).
248+ if dataset_name .endswith (".jsonl" ):
249+ return get_jsonl_text_samples (dataset_name , num_samples , key = "text" )
250+
230251 from datasets import load_dataset
231252
232253 local_dataset_path = None
@@ -280,66 +301,17 @@ def _preprocess(sample: dict) -> str:
280301 print (f"Loading dataset with { config = } and { splits = } " )
281302 dataset_splits = [load_dataset (streaming = True , ** config , split = s ) for s in splits ]
282303
283- return dataset_splits , _preprocess
284-
285-
286- def get_dataset_samples (
287- dataset_name : str ,
288- num_samples : int ,
289- * ,
290- apply_chat_template : bool = False ,
291- tokenizer : "PreTrainedTokenizerBase | None" = None ,
292- split : str | list [str ] | None = None ,
293- ) -> list [str ]:
294- """Load a portion of a dataset with the dataset name and a given size.
295-
296- Supports both registered datasets (in ``SUPPORTED_DATASET_CONFIG``) and arbitrary
297- HuggingFace datasets. Unregistered datasets are auto-detected by column names:
298- ``messages``/``conversations`` (chat), ``prompt``, ``text``, or ``input``.
299-
300- Args:
301- dataset_name: Name or HuggingFace path of the dataset to load, a local directory path,
302- or a path to a ``.jsonl`` file. For local directory paths, the
303- predefined config from ``SUPPORTED_DATASET_CONFIG`` is matched if the base folder name
304- matches a registered key (e.g. ``/hf-local/abisee/cnn_dailymail`` matches ``cnn_dailymail`` key).
305- num_samples: Number of samples to load from the dataset.
306- apply_chat_template: Whether to apply the chat template to the samples
307- (if supported by the dataset). For unregistered datasets with a
308- ``messages`` column, chat template is always applied regardless of
309- this flag.
310- tokenizer: Tokenizer to use for applying the chat template to the samples.
311- No tokenization is done and plain text is still returned.
312- split: Override the split(s) to load. Accepts a single split name or a list.
313- If ``None``, uses the splits defined in ``SUPPORTED_DATASET_CONFIG`` for
314- registered datasets, or ``["train"]`` for unregistered datasets.
315-
316- Returns:
317- Samples: The list of samples.
318- """
319- # Local JSONL file path support (each line is a JSON object with a `text` field).
320- if dataset_name .endswith (".jsonl" ):
321- return get_jsonl_text_samples (dataset_name , num_samples , key = "text" )
322-
323- dataset_splits , _preprocess = _load_streaming_dataset (
324- dataset_name ,
325- apply_chat_template = apply_chat_template ,
326- tokenizer = tokenizer ,
327- split = split ,
328- )
329-
330304 num_per_split = [num_samples // len (dataset_splits )] * len (dataset_splits )
331305 num_per_split [- 1 ] += num_samples - sum (num_per_split )
332306
333307 samples : list [str ] = []
334308 for dataset , n in zip (dataset_splits , num_per_split ):
335- split_samples : list [str ] = []
336- for sample in dataset :
337- if len (split_samples ) >= n :
309+ for i , sample in enumerate (dataset ):
310+ if i >= n :
338311 break
339312 text = _preprocess (sample )
340313 if text :
341- split_samples .append (text )
342- samples .extend (split_samples )
314+ samples .append (text )
343315
344316 return samples
345317
0 commit comments