3131 DataHandler ,
3232 DataHandlerType ,
3333)
34- from tuning .utils .utils import (
34+ from tuning .data .utils import (
3535 get_loader_for_filepath ,
36+ maybe_align_datasets ,
3637 resolve_iterable_dataset_features ,
37- validate_mergeable_datasets ,
3838)
3939
4040logger = logging .getLogger (__name__ )
@@ -223,31 +223,30 @@ def _try_load_dataset(dataset_path, dataset_builder, streaming):
223223
224224 for data_path in data_paths :
225225 dataset = _try_load_dataset (data_path , builder , streaming )
226- if isinstance (dataset , IterableDataset ):
227- dataset = resolve_iterable_dataset_features (dataset )
228226 all_datasets .append (dataset )
229227
230- # Logs warning if datasets have different columns
231- validate_mergeable_datasets (all_datasets )
232-
233228 # Concatenate all datasets
234229 try :
235230 if len (all_datasets ) == 1 :
236231 return all_datasets [0 ]
237-
232+ maybe_align_datasets ( all_datasets )
238233 raw_datasets = datasets .concatenate_datasets (all_datasets )
239234 logger .info (
240- "Datasets concatenated from %s .Concatenated dataset columns : %s" ,
235+ "Datasets %s concatenated. Final column features : %s" ,
241236 datasetconfig .name ,
242- list (raw_datasets .features . keys ( )),
237+ str ( list (raw_datasets .features )),
243238 )
244- return raw_datasets
245-
246239 except Exception as e :
247240 raise ValueError (
248241 f"An error occurred while concatenating datasets from { datasetconfig .name } : { e } "
249242 ) from e
250243
244+ # Need to resolve dataset features because data handlers use columns.
245+ if isinstance (raw_datasets , IterableDataset ):
246+ raw_datasets = resolve_iterable_dataset_features (raw_datasets )
247+
248+ return raw_datasets
249+
251250 def __execute_rename_data_handler (self , raw_datasets , handler , ** kwargs ):
252251 """
253252 Rename columns in the dataset using the provided column mapping.
@@ -456,9 +455,6 @@ def _process_dataset_configs(
456455 raw_dataset = self .load_dataset (
457456 d , self .processor_config .streaming , splitName
458457 )
459- if isinstance (raw_dataset , IterableDataset ):
460- raw_dataset = resolve_iterable_dataset_features (raw_dataset )
461-
462458 logger .info ("Loaded raw dataset : %s" , str (raw_dataset ))
463459
464460 if isinstance (raw_dataset , IterableDataset ):
@@ -493,6 +489,9 @@ def _process_dataset_configs(
493489 else :
494490 final_datasets [k ].append (v )
495491
492+ # Ensure again datasets are aligned before interleaving or concatenating
493+ maybe_align_datasets (final_datasets )
494+
496495 if sample_datasets :
497496 strategy = self .processor_config .sampling_stopping_strategy
498497 seed = self .processor_config .sampling_seed
@@ -517,6 +516,8 @@ def _process_dataset_configs(
517516 )
518517
519518 train_dataset = final_datasets .get ("train" , None )
519+
520+ # Just a failsafe in case this is required later.
520521 if isinstance (train_dataset , IterableDataset ):
521522 train_dataset = resolve_iterable_dataset_features (train_dataset )
522523
0 commit comments