110110]
111111
112112
113- def _third_party_get_dataset_samples (
114- dataset_name : str , num_samples : int , tokenizer : "PreTrainedTokenizerBase | None"
115- ) -> list [str ]:
116- """Load a third-party dataset with the given name and number of samples.
113+ def _normalize_splits (split : str | list [str ]) -> list [str ]:
114+ """Ensure split is always a list."""
115+ return [split ] if isinstance (split , str ) else list (split )
117116
118- for messages: apply_chat_template is applied as needed.
119- for text: no tokenization is done and plain text is still returned.
120- """
121- warn (
122- f"Loading third-party dataset { dataset_name } with the split `train`, as the dataset is not registered in { get_supported_datasets ()} ."
123- )
124- from datasets import load_dataset
125117
126- dataset = load_dataset (
127- dataset_name ,
128- streaming = True ,
129- split = "train" ,
130- )
131- dataset = dataset .shuffle (seed = 42 , buffer_size = 10000 ).take (num_samples )
132- texts = []
133- if "messages" in dataset .column_names :
134- if tokenizer is None :
135- raise ValueError (
136- f"Your dataset { dataset_name } has a `messages` column, but no tokenizer was provided. Are you sure you are using a tokenizer that supports chat templates?"
137- )
138- if not hasattr (tokenizer , "apply_chat_template" ):
118+ def _auto_preprocess_sample (
119+ sample : dict ,
120+ dataset_name : str ,
121+ tokenizer : "PreTrainedTokenizerBase | None" = None ,
122+ ) -> str :
123+ """Auto-detect dataset format and preprocess a single sample based on column conventions.
124+
125+ Column detection order (first match wins):
126+ 1. ``messages`` / ``conversations`` -> ``tokenizer.apply_chat_template`` (with ``tools`` if present)
127+ 2. ``prompt`` (+ optional ``completion`` / ``response`` / ``output``) -> concatenate
128+ 3. ``text`` -> use as-is
129+ 4. ``input`` (+ optional ``output``) -> concatenate
130+
131+ Raises:
132+ ValueError: If the tokenizer is missing/incompatible for chat-format datasets,
133+ or if no recognized column is found.
134+ """
135+ chat_key = next ((k for k in ("messages" , "conversations" ) if sample .get (k )), None )
136+ if chat_key is not None :
137+ if tokenizer is None or not hasattr (tokenizer , "apply_chat_template" ):
139138 raise ValueError (
140- f"Your dataset { dataset_name } has a `messages` column, but the tokenizer does not have an `apply_chat_template` method. Are you sure you are using a tokenizer that supports chat templates?"
139+ f"Dataset '{ dataset_name } ' has a '{ chat_key } ' column but no tokenizer with "
140+ "apply_chat_template was provided."
141141 )
142- texts = []
143- print (
144- f"Using dataset with columns of { dataset_name } : messages and tools to apply chat template."
145- )
146- for i , sample in enumerate (dataset ):
147- messages = sample .get ("messages" , [])
148- kwargs = {}
149- tools = sample .get ("tools" , [])
150- if tools :
151- kwargs ["tools" ] = tools
152- if not messages :
153- raise ValueError (
154- f"Row { i } in dataset { dataset_name } has no messages, or a empty messages."
155- )
156- text : str = tokenizer .apply_chat_template (messages , ** kwargs , tokenize = False )
157- if len (text ) == 0 :
158- raise ValueError (
159- f"Row { i } in dataset { dataset_name } has empty text after applying chat template."
160- )
161- texts .append (text )
162- elif "prompt" in dataset .column_names :
163- texts = [sample ["prompt" ] for sample in dataset ]
164- elif "text" in dataset .column_names :
165- texts = [sample ["text" ] for sample in dataset ]
166- else :
167- raise NotImplementedError (
168- f"Dataset { dataset_name } is not supported. Please use one of the following: { get_supported_datasets ()} . "
169- " For supporting third-party datasets, your dataset must have either a `messages` or `prompt` column, and a `train` split."
170- " For example the `baseten/quant_calibration_dataset_v1` dataset has a `messages` column and a `train` split."
171- )
172-
173- return texts
142+ kwargs : dict [str , Any ] = {}
143+ tools = sample .get ("tools" )
144+ if tools :
145+ kwargs ["tools" ] = tools
146+ return tokenizer .apply_chat_template (sample [chat_key ], tokenize = False , ** kwargs )
147+
148+ if "prompt" in sample :
149+ parts = [sample ["prompt" ]]
150+ parts .extend (sample [k ] for k in ("completion" , "response" , "output" ) if sample .get (k ))
151+ return "\n " .join (parts )
152+
153+ if "text" in sample :
154+ return sample ["text" ]
155+
156+ if "input" in sample :
157+ parts = [sample ["input" ]]
158+ if sample .get ("output" ):
159+ parts .append (sample ["output" ])
160+ return "\n " .join (parts )
161+
162+ raise ValueError (
163+ f"Cannot auto-detect format for dataset '{ dataset_name } '. "
164+ f"Found columns: { list (sample .keys ())} . "
165+ "Expected one of: 'messages', 'conversations', 'prompt', 'text', or 'input'."
166+ )
174167
175168
176169def get_dataset_samples (
@@ -179,73 +172,86 @@ def get_dataset_samples(
179172 * ,
180173 apply_chat_template : bool = False ,
181174 tokenizer : "PreTrainedTokenizerBase | None" = None ,
175+ split : str | list [str ] | None = None ,
182176) -> list [str ]:
183- """Load a portion of train dataset with the dataset name and a given size.
177+ """Load a portion of a dataset with the dataset name and a given size.
178+
179+ Supports both registered datasets (in ``SUPPORTED_DATASET_CONFIG``) and arbitrary
180+ HuggingFace datasets. Unregistered datasets are auto-detected by column names:
181+ ``messages``/``conversations`` (chat), ``prompt``, ``text``, or ``input``.
184182
185183 Args:
186- dataset_name: Name of the dataset to load.
184+ dataset_name: Name or HuggingFace path of the dataset to load.
187185 num_samples: Number of samples to load from the dataset.
188- apply_chat_template: Whether to apply the chat template to the samples (if supported by the dataset).
186+ apply_chat_template: Whether to apply the chat template to the samples
187+ (if supported by the dataset). For unregistered datasets with a
188+ ``messages`` column, chat template is always applied regardless of
189+ this flag.
189190 tokenizer: Tokenizer to use for applying the chat template to the samples.
190191 No tokenization is done and plain text is still returned.
192+ split: Override the split(s) to load. Accepts a single split name or a list.
193+ If ``None``, uses the splits defined in ``SUPPORTED_DATASET_CONFIG`` for
194+ registered datasets, or ``["train"]`` for unregistered datasets.
191195
192196 Returns:
193197 Samples: The list of samples.
194198 """
195- # Load the dataset
196- if dataset_name not in SUPPORTED_DATASET_CONFIG :
199+ from datasets import load_dataset
200+
201+ is_registered = dataset_name in SUPPORTED_DATASET_CONFIG
202+
203+ if is_registered :
204+ dataset_config = SUPPORTED_DATASET_CONFIG [dataset_name ]
205+ config = dataset_config ["config" ].copy ()
206+ splits = _normalize_splits (split ) if split is not None else config .pop ("split" , [None ])
207+ if split is not None :
208+ config .pop ("split" , None )
209+
210+ if apply_chat_template :
211+ if "chat_key" not in dataset_config :
212+ warn (
213+ f"Dataset { dataset_name } does not support chat template."
214+ " Chat template will not be applied."
215+ )
216+ elif tokenizer is None :
217+ raise ValueError ("Tokenizer is required when applying chat template." )
218+
219+ def _preprocess (sample : dict ) -> str :
220+ if apply_chat_template and "chat_key" in dataset_config :
221+ kwargs : dict [str , Any ] = {}
222+ tools = sample .get ("tools" )
223+ if tools :
224+ kwargs ["tools" ] = tools
225+ return tokenizer .apply_chat_template ( # type: ignore[union-attr]
226+ sample [dataset_config ["chat_key" ]], tokenize = False , ** kwargs
227+ )
228+ return dataset_config ["preprocess" ](sample )
229+
230+ else :
197231 warn (
198- f"dataset { dataset_name } is not supported. Please use one of the following:"
199- f" { get_supported_datasets ()} ."
200- " Trying to set up via third-party datasets."
232+ f"Dataset '{ dataset_name } ' is not in SUPPORTED_DATASET_CONFIG. "
233+ "Auto-detecting format from column names."
201234 )
202- return _third_party_get_dataset_samples (dataset_name , num_samples , tokenizer = tokenizer )
235+ config = {"path" : dataset_name }
236+ splits = _normalize_splits (split ) if split is not None else ["train" ]
203237
204- from datasets import load_dataset
238+ def _preprocess (sample : dict ) -> str :
239+ return _auto_preprocess_sample (sample , dataset_name , tokenizer )
205240
206- dataset_config = SUPPORTED_DATASET_CONFIG [dataset_name ]
207- if apply_chat_template :
208- if "chat_key" not in dataset_config :
209- warn (
210- f"Dataset { dataset_name } does not support chat template. Chat template will not be applied."
211- )
212- elif tokenizer is None :
213- raise ValueError ("Tokenizer is required when applying chat template." )
214- print (f"Applying chat template to dataset { dataset_name } " )
215-
216- # It's unfortunate that the load_dataset function does not support split a list while streaming.
217- # So we need to load the dataset for each split.
218- config = dataset_config ["config" ].copy ()
219- splits = config .pop ("split" , [None ])
220- dataset_splits = [
221- load_dataset (
222- streaming = True ,
223- ** config ,
224- split = split ,
225- )
226- for split in splits
227- ]
228-
229- # Split the samples evenly across the splits
230- # For streaming datasets, there is no reliable way to get the number of samples in each split
231- # other than loading the entire dataset. So, we just use the same number of samples for each split.
232- num_samples_splits = [num_samples // len (dataset_splits ) for _ in dataset_splits ]
233- num_samples_splits [- 1 ] += num_samples - sum (num_samples_splits )
234- samples = []
235- for dataset , num_samples_split in zip (dataset_splits , num_samples_splits ):
241+ # load_dataset does not support a list of splits while streaming, so load each separately.
242+ dataset_splits = [load_dataset (streaming = True , ** config , split = s ) for s in splits ]
243+
244+ num_per_split = [num_samples // len (dataset_splits )] * len (dataset_splits )
245+ num_per_split [- 1 ] += num_samples - sum (num_per_split )
246+
247+ samples : list [str ] = []
248+ for dataset , n in zip (dataset_splits , num_per_split ):
236249 for i , sample in enumerate (dataset ):
237- if i >= num_samples_split :
250+ if i >= n :
238251 break
239-
240- # Apply preprocess function to the sample
241- if apply_chat_template and "chat_key" in dataset_config :
242- sample = tokenizer .apply_chat_template ( # type: ignore[union-attr]
243- sample [dataset_config ["chat_key" ]], tokenize = False
244- )
245- else :
246- sample = dataset_config ["preprocess" ](sample )
247- if sample != "" : # wikitext has some empty samples
248- samples .append (sample )
252+ text = _preprocess (sample )
253+ if text :
254+ samples .append (text )
249255
250256 return samples
251257
@@ -273,20 +279,23 @@ def get_dataset_dataloader(
273279 max_sample_length : int = 512 ,
274280 device : torch .device | None = None ,
275281 include_labels : bool = False ,
282+ apply_chat_template : bool = False ,
276283) -> DataLoader :
277- """Get a dataloader with the dataset name and toknizer of the target model.
284+ """Get a dataloader with the dataset name and tokenizer of the target model.
278285
279286 Args:
280287 dataset_name: Name of the dataset to load.
281- tokenizer: Instancne of Hugginface tokenizer.
288+ tokenizer: Instance of HuggingFace tokenizer.
282289 batch_size: Batch size of the returned dataloader.
283290 num_samples: Number of samples from the dataset.
284291 max_sample_length: Maximum length of a sample.
285292 device: Target device for the returned dataloader.
286293 include_labels: Whether to include labels in the dataloader.
294+ apply_chat_template: Whether to apply the chat template to the samples
295+ (if supported by the dataset).
287296
288297 Returns:
289- A instance of dataloader.
298+ An instance of dataloader.
290299 """
291300 assert tokenizer is not None , "Please provide a tokenizer."
292301 # batch_encode_plus will modify the tokenizer in place, so we need to clone it.
@@ -309,7 +318,9 @@ def get_dataset_dataloader(
309318
310319 all_samples = []
311320 for ds_name , num_sample in zip (dataset_name , num_samples ):
312- samples = get_dataset_samples (ds_name , num_sample , tokenizer = tokenizer )
321+ samples = get_dataset_samples (
322+ ds_name , num_sample , apply_chat_template = apply_chat_template , tokenizer = tokenizer
323+ )
313324 all_samples .extend (samples )
314325
315326 batch_encoded = tokenizer .batch_encode_plus (
0 commit comments