Skip to content

Commit 7a29dc0

Browse files
committed
reverted dataset_utils
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
1 parent 1c01e02 commit 7a29dc0

File tree

1 file changed

+31
-59
lines changed

1 file changed

+31
-59
lines changed

modelopt/torch/utils/dataset_utils.py

Lines changed: 31 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -211,22 +211,43 @@ def _auto_preprocess_sample(
211211
)
212212

213213

214-
def _load_streaming_dataset(
214+
def get_dataset_samples(
215215
dataset_name: str,
216+
num_samples: int,
216217
*,
217218
apply_chat_template: bool = False,
218219
tokenizer: "PreTrainedTokenizerBase | None" = None,
219220
split: str | list[str] | None = None,
220-
) -> tuple[list, Callable[[dict], str]]:
221-
"""Resolve dataset config and return streaming splits with a preprocessing function.
221+
) -> list[str]:
222+
"""Load a portion of a dataset with the dataset name and a given size.
222223
223-
This is a shared helper for :func:`get_dataset_samples`.
224+
Supports both registered datasets (in ``SUPPORTED_DATASET_CONFIG``) and arbitrary
225+
HuggingFace datasets. Unregistered datasets are auto-detected by column names:
226+
``messages``/``conversations`` (chat), ``prompt``, ``text``, or ``input``.
227+
228+
Args:
229+
dataset_name: Name or HuggingFace path of the dataset to load, a local directory path,
230+
or a path to a ``.jsonl`` file. For local directory paths, the
231+
predefined config from ``SUPPORTED_DATASET_CONFIG`` is matched if the base folder name
232+
matches a registered key (e.g. ``/hf-local/abisee/cnn_dailymail`` matches ``cnn_dailymail`` key).
233+
num_samples: Number of samples to load from the dataset.
234+
apply_chat_template: Whether to apply the chat template to the samples
235+
(if supported by the dataset). For unregistered datasets with a
236+
``messages`` column, chat template is always applied regardless of
237+
this flag.
238+
tokenizer: Tokenizer to use for applying the chat template to the samples.
239+
No tokenization is done and plain text is still returned.
240+
split: Override the split(s) to load. Accepts a single split name or a list.
241+
If ``None``, uses the splits defined in ``SUPPORTED_DATASET_CONFIG`` for
242+
registered datasets, or ``["train"]`` for unregistered datasets.
224243
225244
Returns:
226-
A tuple of ``(dataset_splits, preprocess_fn)`` where *dataset_splits* is a list of
227-
HuggingFace ``IterableDataset`` objects and *preprocess_fn* maps a raw sample dict
228-
to a plain-text string (empty string signals a sample to skip).
245+
Samples: The list of samples.
229246
"""
247+
# Local JSONL file path support (each line is a JSON object with a `text` field).
248+
if dataset_name.endswith(".jsonl"):
249+
return get_jsonl_text_samples(dataset_name, num_samples, key="text")
250+
230251
from datasets import load_dataset
231252

232253
local_dataset_path = None
@@ -280,66 +301,17 @@ def _preprocess(sample: dict) -> str:
280301
print(f"Loading dataset with {config=} and {splits=}")
281302
dataset_splits = [load_dataset(streaming=True, **config, split=s) for s in splits]
282303

283-
return dataset_splits, _preprocess
284-
285-
286-
def get_dataset_samples(
287-
dataset_name: str,
288-
num_samples: int,
289-
*,
290-
apply_chat_template: bool = False,
291-
tokenizer: "PreTrainedTokenizerBase | None" = None,
292-
split: str | list[str] | None = None,
293-
) -> list[str]:
294-
"""Load a portion of a dataset with the dataset name and a given size.
295-
296-
Supports both registered datasets (in ``SUPPORTED_DATASET_CONFIG``) and arbitrary
297-
HuggingFace datasets. Unregistered datasets are auto-detected by column names:
298-
``messages``/``conversations`` (chat), ``prompt``, ``text``, or ``input``.
299-
300-
Args:
301-
dataset_name: Name or HuggingFace path of the dataset to load, a local directory path,
302-
or a path to a ``.jsonl`` file. For local directory paths, the
303-
predefined config from ``SUPPORTED_DATASET_CONFIG`` is matched if the base folder name
304-
matches a registered key (e.g. ``/hf-local/abisee/cnn_dailymail`` matches ``cnn_dailymail`` key).
305-
num_samples: Number of samples to load from the dataset.
306-
apply_chat_template: Whether to apply the chat template to the samples
307-
(if supported by the dataset). For unregistered datasets with a
308-
``messages`` column, chat template is always applied regardless of
309-
this flag.
310-
tokenizer: Tokenizer to use for applying the chat template to the samples.
311-
No tokenization is done and plain text is still returned.
312-
split: Override the split(s) to load. Accepts a single split name or a list.
313-
If ``None``, uses the splits defined in ``SUPPORTED_DATASET_CONFIG`` for
314-
registered datasets, or ``["train"]`` for unregistered datasets.
315-
316-
Returns:
317-
Samples: The list of samples.
318-
"""
319-
# Local JSONL file path support (each line is a JSON object with a `text` field).
320-
if dataset_name.endswith(".jsonl"):
321-
return get_jsonl_text_samples(dataset_name, num_samples, key="text")
322-
323-
dataset_splits, _preprocess = _load_streaming_dataset(
324-
dataset_name,
325-
apply_chat_template=apply_chat_template,
326-
tokenizer=tokenizer,
327-
split=split,
328-
)
329-
330304
num_per_split = [num_samples // len(dataset_splits)] * len(dataset_splits)
331305
num_per_split[-1] += num_samples - sum(num_per_split)
332306

333307
samples: list[str] = []
334308
for dataset, n in zip(dataset_splits, num_per_split):
335-
split_samples: list[str] = []
336-
for sample in dataset:
337-
if len(split_samples) >= n:
309+
for i, sample in enumerate(dataset):
310+
if i >= n:
338311
break
339312
text = _preprocess(sample)
340313
if text:
341-
split_samples.append(text)
342-
samples.extend(split_samples)
314+
samples.append(text)
343315

344316
return samples
345317

0 commit comments

Comments
 (0)