Skip to content

Commit a5eb313

Browse files
committed
third-party-dataset support, for shuffle
Signed-off-by: michaelfeil <63565275+michaelfeil@users.noreply.github.com>
1 parent e024097 commit a5eb313

1 file changed

Lines changed: 67 additions & 2 deletions

File tree

modelopt/torch/utils/dataset_utils.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,69 @@
110110
]
111111

112112

113+
def _third_party_get_dataset_samples(
114+
dataset_name: str, num_samples: int, tokenizer: "PreTrainedTokenizerBase | None"
115+
) -> list[str]:
116+
"""Load a third-party dataset with the given name and number of samples.
117+
118+
for messages: apply_chat_template is applied as needed.
119+
for text: no tokenization is done and plain text is still returned.
120+
"""
121+
warn(
122+
f"Loading third-party datset {dataset_name} with the split `train`, as the dataset is not registered in {get_supported_datasets()}."
123+
)
124+
from datasets import load_dataset
125+
126+
dataset = load_dataset(
127+
dataset_name,
128+
streaming=True,
129+
split="train",
130+
)
131+
dataset = dataset.shuffle(seed=42, buffer_size=10000).take(num_samples)
132+
texts = []
133+
if "messages" in dataset.column_names:
134+
if tokenizer is None:
135+
raise ValueError(
136+
f"Your dataset {dataset_name} has a `messages` column, but no tokenizer was provided. Are you sure you are using a tokenizer that supports chat templates?"
137+
)
138+
if not hasattr(tokenizer, "apply_chat_template"):
139+
raise ValueError(
140+
f"Your dataset {dataset_name} has a `messages` column, but the tokenizer does not have an `apply_chat_template` method. Are you sure you are using a tokenizer that supports chat templates?"
141+
)
142+
texts = []
143+
print(
144+
f"Using dataset with columns of {dataset_name}: messages and tools to apply chat template."
145+
)
146+
for i, sample in enumerate(dataset):
147+
messages = sample.get("messages", [])
148+
kwargs = {}
149+
tools = sample.get("tools", [])
150+
if tools:
151+
kwargs["tools"] = tools
152+
if not messages:
153+
raise ValueError(
154+
f"Column {i} in dataset {dataset_name} has no messages, or a empty messages."
155+
)
156+
text: str = tokenizer.apply_chat_template(messages, **kwargs, tokenize=False)
157+
if len(text) == 0:
158+
raise ValueError(
159+
f"Column {i} in dataset {dataset_name} has empty text after applying chat template."
160+
)
161+
texts.append(text)
162+
elif "prompt" in dataset.column_names:
163+
texts = [sample["prompt"] for sample in dataset]
164+
elif "text" in dataset.column_names:
165+
texts = [sample["text"] for sample in dataset]
166+
else:
167+
raise NotImplementedError(
168+
f"Dataset {dataset_name} is not supported. Please use one of the following: {get_supported_datasets()}. "
169+
" For supporting thrid-party datasets, your dataset must have either a `messages` or `prompt` column, and a `train` split."
170+
" For example the `baseten/quant_calibration_dataset_v1` dataset has a `messages` column and a `train` split."
171+
)
172+
173+
return texts
174+
175+
113176
def get_dataset_samples(
114177
dataset_name: str,
115178
num_samples: int,
@@ -131,10 +194,12 @@ def get_dataset_samples(
131194
"""
132195
# Load the dataset
133196
if dataset_name not in SUPPORTED_DATASET_CONFIG:
134-
raise NotImplementedError(
197+
warn(
135198
f"dataset {dataset_name} is not supported. Please use one of the following:"
136199
f" {get_supported_datasets()}."
200+
" Trying to set up via third-party datasets."
137201
)
202+
return _third_party_get_dataset_samples(dataset_name, num_samples, tokenizer=tokenizer)
138203

139204
from datasets import load_dataset
140205

@@ -244,7 +309,7 @@ def get_dataset_dataloader(
244309

245310
all_samples = []
246311
for ds_name, num_sample in zip(dataset_name, num_samples):
247-
samples = get_dataset_samples(ds_name, num_sample)
312+
samples = get_dataset_samples(ds_name, num_sample, tokenizer=tokenizer)
248313
all_samples.extend(samples)
249314

250315
batch_encoded = tokenizer.batch_encode_plus(

0 commit comments

Comments
 (0)