Skip to content

Commit 6ededde

Browse files
Make dataset_utils logic generic
1 parent 560b503 commit 6ededde

1 file changed

Lines changed: 123 additions & 112 deletions

File tree

modelopt/torch/utils/dataset_utils.py

Lines changed: 123 additions & 112 deletions
Original file line number | Diff line number | Diff line change
@@ -110,67 +110,60 @@
110110
]
111111

112112

113-
def _third_party_get_dataset_samples(
114-
dataset_name: str, num_samples: int, tokenizer: "PreTrainedTokenizerBase | None"
115-
) -> list[str]:
116-
"""Load a third-party dataset with the given name and number of samples.
113+
def _normalize_splits(split: str | list[str]) -> list[str]:
114+
"""Ensure split is always a list."""
115+
return [split] if isinstance(split, str) else list(split)
117116

118-
for messages: apply_chat_template is applied as needed.
119-
for text: no tokenization is done and plain text is still returned.
120-
"""
121-
warn(
122-
f"Loading third-party dataset {dataset_name} with the split `train`, as the dataset is not registered in {get_supported_datasets()}."
123-
)
124-
from datasets import load_dataset
125117

126-
dataset = load_dataset(
127-
dataset_name,
128-
streaming=True,
129-
split="train",
130-
)
131-
dataset = dataset.shuffle(seed=42, buffer_size=10000).take(num_samples)
132-
texts = []
133-
if "messages" in dataset.column_names:
134-
if tokenizer is None:
135-
raise ValueError(
136-
f"Your dataset {dataset_name} has a `messages` column, but no tokenizer was provided. Are you sure you are using a tokenizer that supports chat templates?"
137-
)
138-
if not hasattr(tokenizer, "apply_chat_template"):
118+
def _auto_preprocess_sample(
119+
sample: dict,
120+
dataset_name: str,
121+
tokenizer: "PreTrainedTokenizerBase | None" = None,
122+
) -> str:
123+
"""Auto-detect dataset format and preprocess a single sample based on column conventions.
124+
125+
Column detection order (first match wins):
126+
1. ``messages`` / ``conversations`` -> ``tokenizer.apply_chat_template`` (with ``tools`` if present)
127+
2. ``prompt`` (+ optional ``completion`` / ``response`` / ``output``) -> concatenate
128+
3. ``text`` -> use as-is
129+
4. ``input`` (+ optional ``output``) -> concatenate
130+
131+
Raises:
132+
ValueError: If the tokenizer is missing/incompatible for chat-format datasets,
133+
or if no recognized column is found.
134+
"""
135+
chat_key = next((k for k in ("messages", "conversations") if sample.get(k)), None)
136+
if chat_key is not None:
137+
if tokenizer is None or not hasattr(tokenizer, "apply_chat_template"):
139138
raise ValueError(
140-
f"Your dataset {dataset_name} has a `messages` column, but the tokenizer does not have an `apply_chat_template` method. Are you sure you are using a tokenizer that supports chat templates?"
139+
f"Dataset '{dataset_name}' has a '{chat_key}' column but no tokenizer with "
140+
"apply_chat_template was provided."
141141
)
142-
texts = []
143-
print(
144-
f"Using dataset with columns of {dataset_name}: messages and tools to apply chat template."
145-
)
146-
for i, sample in enumerate(dataset):
147-
messages = sample.get("messages", [])
148-
kwargs = {}
149-
tools = sample.get("tools", [])
150-
if tools:
151-
kwargs["tools"] = tools
152-
if not messages:
153-
raise ValueError(
154-
f"Row {i} in dataset {dataset_name} has no messages, or a empty messages."
155-
)
156-
text: str = tokenizer.apply_chat_template(messages, **kwargs, tokenize=False)
157-
if len(text) == 0:
158-
raise ValueError(
159-
f"Row {i} in dataset {dataset_name} has empty text after applying chat template."
160-
)
161-
texts.append(text)
162-
elif "prompt" in dataset.column_names:
163-
texts = [sample["prompt"] for sample in dataset]
164-
elif "text" in dataset.column_names:
165-
texts = [sample["text"] for sample in dataset]
166-
else:
167-
raise NotImplementedError(
168-
f"Dataset {dataset_name} is not supported. Please use one of the following: {get_supported_datasets()}. "
169-
" For supporting third-party datasets, your dataset must have either a `messages` or `prompt` column, and a `train` split."
170-
" For example the `baseten/quant_calibration_dataset_v1` dataset has a `messages` column and a `train` split."
171-
)
172-
173-
return texts
142+
kwargs: dict[str, Any] = {}
143+
tools = sample.get("tools")
144+
if tools:
145+
kwargs["tools"] = tools
146+
return tokenizer.apply_chat_template(sample[chat_key], tokenize=False, **kwargs)
147+
148+
if "prompt" in sample:
149+
parts = [sample["prompt"]]
150+
parts.extend(sample[k] for k in ("completion", "response", "output") if sample.get(k))
151+
return "\n".join(parts)
152+
153+
if "text" in sample:
154+
return sample["text"]
155+
156+
if "input" in sample:
157+
parts = [sample["input"]]
158+
if sample.get("output"):
159+
parts.append(sample["output"])
160+
return "\n".join(parts)
161+
162+
raise ValueError(
163+
f"Cannot auto-detect format for dataset '{dataset_name}'. "
164+
f"Found columns: {list(sample.keys())}. "
165+
"Expected one of: 'messages', 'conversations', 'prompt', 'text', or 'input'."
166+
)
174167

175168

176169
def get_dataset_samples(
@@ -179,73 +172,86 @@ def get_dataset_samples(
179172
*,
180173
apply_chat_template: bool = False,
181174
tokenizer: "PreTrainedTokenizerBase | None" = None,
175+
split: str | list[str] | None = None,
182176
) -> list[str]:
183-
"""Load a portion of train dataset with the dataset name and a given size.
177+
"""Load a portion of a dataset with the dataset name and a given size.
178+
179+
Supports both registered datasets (in ``SUPPORTED_DATASET_CONFIG``) and arbitrary
180+
HuggingFace datasets. Unregistered datasets are auto-detected by column names:
181+
``messages``/``conversations`` (chat), ``prompt``, ``text``, or ``input``.
184182
185183
Args:
186-
dataset_name: Name of the dataset to load.
184+
dataset_name: Name or HuggingFace path of the dataset to load.
187185
num_samples: Number of samples to load from the dataset.
188-
apply_chat_template: Whether to apply the chat template to the samples (if supported by the dataset).
186+
apply_chat_template: Whether to apply the chat template to the samples
187+
(if supported by the dataset). For unregistered datasets with a
188+
``messages`` or ``conversations`` column, the chat template is always applied regardless of
189+
this flag.
189190
tokenizer: Tokenizer to use for applying the chat template to the samples.
190191
No tokenization is done and plain text is still returned.
192+
split: Override the split(s) to load. Accepts a single split name or a list.
193+
If ``None``, uses the splits defined in ``SUPPORTED_DATASET_CONFIG`` for
194+
registered datasets, or ``["train"]`` for unregistered datasets.
191195
192196
Returns:
193197
Samples: The list of samples.
194198
"""
195-
# Load the dataset
196-
if dataset_name not in SUPPORTED_DATASET_CONFIG:
199+
from datasets import load_dataset
200+
201+
is_registered = dataset_name in SUPPORTED_DATASET_CONFIG
202+
203+
if is_registered:
204+
dataset_config = SUPPORTED_DATASET_CONFIG[dataset_name]
205+
config = dataset_config["config"].copy()
206+
splits = _normalize_splits(split) if split is not None else config.pop("split", [None])
207+
if split is not None:
208+
config.pop("split", None)
209+
210+
if apply_chat_template:
211+
if "chat_key" not in dataset_config:
212+
warn(
213+
f"Dataset {dataset_name} does not support chat template."
214+
" Chat template will not be applied."
215+
)
216+
elif tokenizer is None:
217+
raise ValueError("Tokenizer is required when applying chat template.")
218+
219+
def _preprocess(sample: dict) -> str:
220+
if apply_chat_template and "chat_key" in dataset_config:
221+
kwargs: dict[str, Any] = {}
222+
tools = sample.get("tools")
223+
if tools:
224+
kwargs["tools"] = tools
225+
return tokenizer.apply_chat_template( # type: ignore[union-attr]
226+
sample[dataset_config["chat_key"]], tokenize=False, **kwargs
227+
)
228+
return dataset_config["preprocess"](sample)
229+
230+
else:
197231
warn(
198-
f"dataset {dataset_name} is not supported. Please use one of the following:"
199-
f" {get_supported_datasets()}."
200-
" Trying to set up via third-party datasets."
232+
f"Dataset '{dataset_name}' is not in SUPPORTED_DATASET_CONFIG. "
233+
"Auto-detecting format from column names."
201234
)
202-
return _third_party_get_dataset_samples(dataset_name, num_samples, tokenizer=tokenizer)
235+
config = {"path": dataset_name}
236+
splits = _normalize_splits(split) if split is not None else ["train"]
203237

204-
from datasets import load_dataset
238+
def _preprocess(sample: dict) -> str:
239+
return _auto_preprocess_sample(sample, dataset_name, tokenizer)
205240

206-
dataset_config = SUPPORTED_DATASET_CONFIG[dataset_name]
207-
if apply_chat_template:
208-
if "chat_key" not in dataset_config:
209-
warn(
210-
f"Dataset {dataset_name} does not support chat template. Chat template will not be applied."
211-
)
212-
elif tokenizer is None:
213-
raise ValueError("Tokenizer is required when applying chat template.")
214-
print(f"Applying chat template to dataset {dataset_name}")
215-
216-
# It's unfortunate that the load_dataset function does not support split a list while streaming.
217-
# So we need to load the dataset for each split.
218-
config = dataset_config["config"].copy()
219-
splits = config.pop("split", [None])
220-
dataset_splits = [
221-
load_dataset(
222-
streaming=True,
223-
**config,
224-
split=split,
225-
)
226-
for split in splits
227-
]
228-
229-
# Split the samples evenly across the splits
230-
# For streaming datasets, there is no reliable way to get the number of samples in each split
231-
# other than loading the entire dataset. So, we just use the same number of samples for each split.
232-
num_samples_splits = [num_samples // len(dataset_splits) for _ in dataset_splits]
233-
num_samples_splits[-1] += num_samples - sum(num_samples_splits)
234-
samples = []
235-
for dataset, num_samples_split in zip(dataset_splits, num_samples_splits):
241+
# load_dataset does not support a list of splits while streaming, so load each separately.
242+
dataset_splits = [load_dataset(streaming=True, **config, split=s) for s in splits]
243+
244+
num_per_split = [num_samples // len(dataset_splits)] * len(dataset_splits)
245+
num_per_split[-1] += num_samples - sum(num_per_split)
246+
247+
samples: list[str] = []
248+
for dataset, n in zip(dataset_splits, num_per_split):
236249
for i, sample in enumerate(dataset):
237-
if i >= num_samples_split:
250+
if i >= n:
238251
break
239-
240-
# Apply preprocess function to the sample
241-
if apply_chat_template and "chat_key" in dataset_config:
242-
sample = tokenizer.apply_chat_template( # type: ignore[union-attr]
243-
sample[dataset_config["chat_key"]], tokenize=False
244-
)
245-
else:
246-
sample = dataset_config["preprocess"](sample)
247-
if sample != "": # wikitext has some empty samples
248-
samples.append(sample)
252+
text = _preprocess(sample)
253+
if text:
254+
samples.append(text)
249255

250256
return samples
251257

@@ -273,20 +279,23 @@ def get_dataset_dataloader(
273279
max_sample_length: int = 512,
274280
device: torch.device | None = None,
275281
include_labels: bool = False,
282+
apply_chat_template: bool = False,
276283
) -> DataLoader:
277-
"""Get a dataloader with the dataset name and toknizer of the target model.
284+
"""Get a dataloader with the dataset name and tokenizer of the target model.
278285
279286
Args:
280287
dataset_name: Name of the dataset to load.
281-
tokenizer: Instancne of Hugginface tokenizer.
288+
tokenizer: Instance of HuggingFace tokenizer.
282289
batch_size: Batch size of the returned dataloader.
283290
num_samples: Number of samples from the dataset.
284291
max_sample_length: Maximum length of a sample.
285292
device: Target device for the returned dataloader.
286293
include_labels: Whether to include labels in the dataloader.
294+
apply_chat_template: Whether to apply the chat template to the samples
295+
(if supported by the dataset).
287296
288297
Returns:
289-
A instance of dataloader.
298+
An instance of dataloader.
290299
"""
291300
assert tokenizer is not None, "Please provide a tokenizer."
292301
# batch_encode_plus will modify the tokenizer in place, so we need to clone it.
@@ -309,7 +318,9 @@ def get_dataset_dataloader(
309318

310319
all_samples = []
311320
for ds_name, num_sample in zip(dataset_name, num_samples):
312-
samples = get_dataset_samples(ds_name, num_sample, tokenizer=tokenizer)
321+
samples = get_dataset_samples(
322+
ds_name, num_sample, apply_chat_template=apply_chat_template, tokenizer=tokenizer
323+
)
313324
all_samples.extend(samples)
314325

315326
batch_encoded = tokenizer.batch_encode_plus(

0 commit comments

Comments
 (0)