From 933bca7231f2b2e10e20c57cd508e062fa495aba Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Wed, 8 Oct 2025 11:44:21 +0530 Subject: [PATCH 1/9] feat: resume functionality Signed-off-by: Mehant Kammakomati --- tuning/config/acceleration_configs/odm.py | 1 + tuning/sft_trainer.py | 35 ++++++++++++----------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/tuning/config/acceleration_configs/odm.py b/tuning/config/acceleration_configs/odm.py index f5c2d9f8ce..849c860f9a 100644 --- a/tuning/config/acceleration_configs/odm.py +++ b/tuning/config/acceleration_configs/odm.py @@ -27,6 +27,7 @@ class ODM: reward_type: str = None gamma: float = 0.1 eta: float = 0.1 + resume_from_checkpoint: Union[bool, str] = False @dataclass diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 04b920c8aa..7887c69c10 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -129,11 +129,29 @@ def train( logger_name="sft_trainer_train", level=train_args.log_level ) + resume_from_checkpoint = None + # Check if resume flag is not passed (None), or if flag is true and + # output_dir has checkpoints then get last checkpoint from output_dir + if ( + train_args.resume_from_checkpoint is None + or train_args.resume_from_checkpoint.lower() == "true" + ): + resume_from_checkpoint = get_last_checkpoint(train_args.output_dir) + else: + # `train_args.resume_from_checkpoint` gives string values + # Check if flag is false OR flag has checkpoint value for resuming tuning + resume_from_checkpoint = ( + train_args.resume_from_checkpoint + if train_args.resume_from_checkpoint.lower() != "false" + else False + ) + # TODO: use of load_and_validate_data_config here is not clean way # rather we should move this logic to process_dataargs odm_config = None if data_args.data_config_path: _dataconfig = load_and_validate_data_config(data_args.data_config_path) + _dataconfig.dataprocessor.odm["resume_from_checkpoint"] = resume_from_checkpoint if _dataconfig.dataprocessor.type == "odm": odm_config = ODMConfig(odm=ODM(**_dataconfig.dataprocessor.odm)) @@ -504,23 +522,6 @@ def train( ): trainer.add_callback(clb) - resume_from_checkpoint = None - # Check if resume flag is not passed (None), or if flag is true and - # output_dir has checkpoints then get last checkpoint from output_dir - if ( - training_args.resume_from_checkpoint is None - or training_args.resume_from_checkpoint.lower() == "true" - ): - resume_from_checkpoint = get_last_checkpoint(training_args.output_dir) - else: - # `training_args.resume_from_checkpoint` gives string values - # Check if flag is false OR flag has checkpoint value for resuming tuning - resume_from_checkpoint = ( - training_args.resume_from_checkpoint - if training_args.resume_from_checkpoint.lower() != "false" - else False - ) - trainer.train(resume_from_checkpoint) additional_metadata = {} additional_metadata["added_tokens_info"] = added_tokens_dict From 3f09f096b8d624348ace73fbffec66bccb0afd4b Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Wed, 8 Oct 2025 11:50:35 +0530 Subject: [PATCH 2/9] feat: resume functionality Signed-off-by: Mehant Kammakomati --- tuning/config/acceleration_configs/odm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tuning/config/acceleration_configs/odm.py b/tuning/config/acceleration_configs/odm.py index 849c860f9a..497fc60481 100644 --- a/tuning/config/acceleration_configs/odm.py +++ b/tuning/config/acceleration_configs/odm.py @@ -14,6 +14,7 @@ # Standard from dataclasses import dataclass +from typing import Union # Local from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass From bbd5436575bb545f2105ea95e82dbee78bba2003 Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Wed, 8 Oct 2025 12:03:19 +0530 Subject: [PATCH 3/9] feat: resume functionality Signed-off-by: Mehant Kammakomati --- tuning/sft_trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 7887c69c10..f154af5018 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -151,8 +151,10 @@ def train( odm_config = None if data_args.data_config_path: _dataconfig = load_and_validate_data_config(data_args.data_config_path) - _dataconfig.dataprocessor.odm["resume_from_checkpoint"] = resume_from_checkpoint if _dataconfig.dataprocessor.type == "odm": + _dataconfig.dataprocessor.odm[ + "resume_from_checkpoint" + ] = resume_from_checkpoint odm_config = ODMConfig(odm=ODM(**_dataconfig.dataprocessor.odm)) USE_ALORA = False From 2f404c6b0fd855b425255f4043e38d263085261e Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Wed, 8 Oct 2025 16:07:03 +0530 Subject: [PATCH 4/9] feat: resume functionality Signed-off-by: Mehant Kammakomati --- tuning/sft_trainer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index f154af5018..edafadb9ca 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -130,6 +130,10 @@ def train( ) resume_from_checkpoint = None + if training_args.output_dir: + os.makedirs(training_args.output_dir, exist_ok=True) + logger.info("using the output directory at %s", training_args.output_dir) + # Check if resume flag is not passed (None), or if flag is true and # output_dir has checkpoints then get last checkpoint from output_dir if ( @@ -797,9 +801,6 @@ def main(): "failed while parsing extra metadata. pass a valid json %s", repr(e) ) - if training_args.output_dir: - os.makedirs(training_args.output_dir, exist_ok=True) - logger.info("using the output directory at %s", training_args.output_dir) try: trainer, additional_train_info, tc_callback = train( model_args=model_args, From 82495d332f5e489267f9961bd574514175d76ad8 Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Wed, 8 Oct 2025 16:31:43 +0530 Subject: [PATCH 5/9] feat: resume functionality Signed-off-by: Mehant Kammakomati --- tuning/data/data_processors.py | 141 +++++++++++++++++++---------- tuning/data/setup_dataprocessor.py | 8 +- tuning/sft_trainer.py | 6 +- 3 files changed, 99 insertions(+), 56 deletions(-) diff --git a/tuning/data/data_processors.py b/tuning/data/data_processors.py index f5cc4b672a..d7784e1c31 100644 --- a/tuning/data/data_processors.py +++ b/tuning/data/data_processors.py @@ -452,35 +452,9 @@ def split_dataset( ) return split_datasets - def _process_datasets_for_odm( - self, - processed_datasets: List[ - Tuple[DataSetConfig, Union[DatasetDict, IterableDatasetDict]] - ], - ) -> Tuple[ - Dict[str, Union[Dataset, IterableDataset]], - Dict[str, Union[Dataset, IterableDataset]], - ]: - train_split = "train" - eval_split = "test" - train_datasets_dict = {} - eval_datasets_dict = {} - for d, raw in processed_datasets: - if train_split in raw: - train_datasets_dict[d.name] = raw[train_split] - if eval_split in raw: - eval_datasets_dict[d.name] = raw[eval_split] - return train_datasets_dict, eval_datasets_dict - def _process_dataset_configs( - self, dataset_configs: List[DataSetConfig], odm_config=None - ) -> Union[ - Tuple[Union[Dataset, IterableDataset], Union[Dataset, IterableDataset]], - Tuple[ - Dict[str, Union[Dataset, IterableDataset]], - Dict[str, Union[Dataset, IterableDataset]], - ], - ]: + self, dataset_configs: List[DataSetConfig] + ) -> Tuple[Union[Dataset, IterableDataset], Union[Dataset, IterableDataset]]: if not dataset_configs: raise ValueError( @@ -530,13 +504,7 @@ def _process_dataset_configs( # Append the processed datasets to the final dict processed_datasets.append((d, raw_datasets)) - if odm_config: - logger.info( - "Sampling probabilities are ignored if provided" - "and are not used for concatenation. Instead" - "online data mixing plugin handles it." - ) - return self._process_datasets_for_odm(processed_datasets) + train_datasets = [] train_sampling_probabilities = [] validation_datasets = [] @@ -623,14 +591,8 @@ def _process_dataset_configs( return train_dataset, eval_dataset def process_dataset_configs( - self, dataset_configs: List[DataSetConfig], odm_config=None - ) -> Union[ - Tuple[Union[Dataset, IterableDataset], Union[Dataset, IterableDataset]], - Tuple[ - Dict[str, Union[Dataset, IterableDataset]], - Dict[str, Union[Dataset, IterableDataset]], - ], - ]: + self, dataset_configs: List[DataSetConfig] + ) -> Tuple[Union[Dataset, IterableDataset], Union[Dataset, IterableDataset]]: train_dataset = eval_dataset = None # Use partial state as recommended by HF documentation for process control @@ -643,9 +605,7 @@ def process_dataset_configs( # as we want to reuse HF cache and not redo computation on all nodes # For rationale see https://github.com/huggingface/trl/pull/3106 with state.main_process_first(): - train_dataset, eval_dataset = self._process_dataset_configs( - dataset_configs, odm_config - ) + train_dataset, eval_dataset = self._process_dataset_configs(dataset_configs) logger.info("Processed train dataset {}".format(train_dataset)) logger.info("Processed eval dataset {}".format(eval_dataset)) @@ -653,13 +613,100 @@ def process_dataset_configs( return train_dataset, eval_dataset +class ODMDataPreProcessor(DataPreProcessor): + def _process_datasets_for_odm( + self, + processed_datasets: List[ + Tuple[DataSetConfig, Union[DatasetDict, IterableDatasetDict]] + ], + ) -> Tuple[ + Dict[str, Union[Dataset, IterableDataset]], + Dict[str, Union[Dataset, IterableDataset]], + ]: + train_split = "train" + eval_split = "test" + train_datasets_dict = {} + eval_datasets_dict = {} + for d, raw in processed_datasets: + if train_split in raw: + train_datasets_dict[d.name] = raw[train_split] + if eval_split in raw: + eval_datasets_dict[d.name] = raw[eval_split] + return train_datasets_dict, eval_datasets_dict + + def _process_dataset_configs( + self, dataset_configs: List[DataSetConfig] + ) -> Tuple[ + Dict[str, Union[Dataset, IterableDataset]], + Dict[str, Union[Dataset, IterableDataset]], + ]: + + if not dataset_configs: + raise ValueError( + "No dataset configs provided. Provided Dataset configs is None." + ) + + train_split = "train" # default + eval_split = "test" + + processed_datasets = [] + + logger.info("Starting DataPreProcessor...") + # Now Iterate over the multiple datasets provided to us to process + for d in dataset_configs: + logger.info("Loading the dataset - %s", d.name) + + # In future the streaming etc go as kwargs of this function + loaded_dataset = self.load_dataset(d, self.processor_config.streaming) + logger.info("Loaded raw dataset : %s", str(loaded_dataset)) + + if d.split is not None: + loaded_dataset = self.split_dataset(d, loaded_dataset) + + # Create a raw dataset which is a Dict container to house all Datasets + raw_datasets = ( + IterableDatasetDict() + if isinstance(loaded_dataset, (IterableDataset, IterableDatasetDict)) + else DatasetDict() + ) + + splits_to_keep = [train_split, eval_split] + if isinstance(loaded_dataset, (Dataset, IterableDataset)): + # Assume all is train split + raw_datasets[train_split] = loaded_dataset + else: + for k, v in loaded_dataset.items(): + if k in splits_to_keep: + raw_datasets[k] = v + + if d.data_handlers: # Execute the datahandlers + for data_handler_config in d.data_handlers: + raw_datasets = self._execute_data_handlers( + raw_datasets=raw_datasets, + data_handler_config=data_handler_config, + datasetName=d.name, + ) + + # Append the processed datasets to the final dict + processed_datasets.append((d, raw_datasets)) + logger.info( + "Sampling probabilities are ignored if provided" + "and are not used for concatenation. Instead" + "online data mixing plugin handles it." + ) + return self._process_datasets_for_odm(processed_datasets) + + def get_datapreprocessor( processor_config: DataPreProcessorConfig, tokenizer: AutoTokenizer, processor: AutoProcessor = None, additional_data_handlers: Dict[str, DataHandler] = None, ) -> DataPreProcessor: - data_processor = DataPreProcessor( + data_processor_cls = DataPreProcessor + if processor_config.type == "odm": + data_processor_cls = ODMDataPreProcessor + data_processor = data_processor_cls( processor_config=processor_config, tokenizer=tokenizer, processor=processor, diff --git a/tuning/data/setup_dataprocessor.py b/tuning/data/setup_dataprocessor.py index 589c7a7916..34169b4b71 100644 --- a/tuning/data/setup_dataprocessor.py +++ b/tuning/data/setup_dataprocessor.py @@ -73,7 +73,6 @@ def process_dataconfig_file( processor: AutoProcessor = None, is_multipack: bool = False, is_padding_free: bool = False, - odm_config: ODMConfig = None, ): """ Args: @@ -155,7 +154,7 @@ def process_dataconfig_file( tokenizer.chat_template = data_processor.processor_config.chat_template train_dataset, eval_dataset = data_processor.process_dataset_configs( - data_config.datasets, odm_config=odm_config + data_config.datasets ) return (train_dataset, eval_dataset, data_args.dataset_text_field) @@ -348,7 +347,6 @@ def _process_raw_data_args( additional_data_handlers: Dict[str, DataHandler] = None, is_padding_free: bool = False, processor: AutoProcessor = None, - odm_config: ODMConfig = None, ): if data_args.data_config_path is not None: @@ -448,7 +446,7 @@ def _process_raw_data_args( dataset_configs.append(eval_dataset_config) train_dataset, eval_dataset = data_processor.process_dataset_configs( - dataset_configs, odm_config=odm_config + dataset_configs ) return (train_dataset, eval_dataset, dataset_text_field) @@ -635,7 +633,6 @@ def process_dataargs( processor, is_multipack, is_padding_free, - odm_config=odm_config, ) else: train_dataset, eval_dataset, dataset_text_field = _process_raw_data_args( @@ -646,7 +643,6 @@ def process_dataargs( additional_data_handlers, is_padding_free, processor, - odm_config=odm_config, ) if train_args.eval_strategy != "no" and eval_dataset is None: diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index edafadb9ca..330606946f 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -130,9 +130,9 @@ def train( ) resume_from_checkpoint = None - if training_args.output_dir: - os.makedirs(training_args.output_dir, exist_ok=True) - logger.info("using the output directory at %s", training_args.output_dir) + if train_args.output_dir: + os.makedirs(train_args.output_dir, exist_ok=True) + logger.info("using the output directory at %s", train_args.output_dir) # Check if resume flag is not passed (None), or if flag is true and # output_dir has checkpoints then get last checkpoint from output_dir From 96c3c96445997fba2e9c48a7cd4b40ec1211da8f Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Thu, 9 Oct 2025 18:05:27 +0530 Subject: [PATCH 6/9] fix: refactor code Signed-off-by: Mehant Kammakomati --- tests/data/test_data_preprocessing.py | 8 +- tuning/data/data_processors.py | 148 ++++++++++---------------- tuning/data/setup_dataprocessor.py | 43 +++++--- 3 files changed, 91 insertions(+), 108 deletions(-) diff --git a/tests/data/test_data_preprocessing.py b/tests/data/test_data_preprocessing.py index 1ddfaaf896..22ff407482 100644 --- a/tests/data/test_data_preprocessing.py +++ b/tests/data/test_data_preprocessing.py @@ -1831,7 +1831,9 @@ def test_process_dataset_configs(datafile, column_names, datasetconfigname): tokenizer=tokenizer, ) datasetconfig = [DataSetConfig(name=datasetconfigname, data_paths=[datafile])] - train_dataset, _ = processor.process_dataset_configs(dataset_configs=datasetconfig) + train_dataset, _, _ = processor.process_dataset_configs( + dataset_configs=datasetconfig + ) assert isinstance(train_dataset, Dataset) assert set(train_dataset.column_names) == column_names @@ -1953,7 +1955,9 @@ def test_rename_and_select_dataset_columns( name=datasetconfigname, data_paths=data_paths, data_handlers=handlers ) ] - train_dataset, _ = processor.process_dataset_configs(dataset_configs=datasetconfig) + train_dataset, _, _ = processor.process_dataset_configs( + dataset_configs=datasetconfig + ) assert isinstance(train_dataset, Dataset) assert set(train_dataset.column_names) == set(final) diff --git a/tuning/data/data_processors.py b/tuning/data/data_processors.py index d7784e1c31..20936f387b 100644 --- a/tuning/data/data_processors.py +++ b/tuning/data/data_processors.py @@ -452,10 +452,9 @@ def split_dataset( ) return split_datasets - def _process_dataset_configs( + def _prepare_processed_datasets( self, dataset_configs: List[DataSetConfig] - ) -> Tuple[Union[Dataset, IterableDataset], Union[Dataset, IterableDataset]]: - + ) -> List[Tuple[DataSetConfig, Union[IterableDataset, Dataset]]]: if not dataset_configs: raise ValueError( "No dataset configs provided. Provided Dataset configs is None." @@ -504,6 +503,34 @@ def _process_dataset_configs( # Append the processed datasets to the final dict processed_datasets.append((d, raw_datasets)) + return processed_datasets + + def _validate_sampling_ratios(self, sampling_ratios: List[float], train_datasets): + if len(sampling_ratios) > 0: + if len(sampling_ratios) < len(train_datasets): + raise ValueError( + "Sampling probability should be specified for all datasets with train split" + ) + if len(sampling_ratios) > len(train_datasets): + raise ValueError( + "Sampling probability should only be specified for datasets with train split" + ) + if sum(p for p in sampling_ratios) != 1: + raise ValueError( + "Sampling probabilities for train datasets don't sum to 1" + ) + return True + + def _process_dataset_configs( + self, dataset_configs: List[DataSetConfig] + ) -> Tuple[ + Union[Dataset, IterableDataset], + Union[Dataset, IterableDataset], + Dict[str, float], + ]: + train_split = "train" # default + eval_split = "test" + processed_datasets = self._prepare_processed_datasets(dataset_configs) train_datasets = [] train_sampling_probabilities = [] @@ -525,25 +552,9 @@ def _process_dataset_configs( ) # quick check to see if we are sampling and if we need to throw error. - if len(train_sampling_probabilities) > 0: - if len(train_sampling_probabilities) < len(train_datasets): - raise ValueError( - "Sampling probability should be specified for all datasets with train split" - ) - if len(train_sampling_probabilities) > len(train_datasets): - raise ValueError( - "Sampling probability should only be specified for datasets with train split" - ) - if sum(p for p in train_sampling_probabilities) != 1: - raise ValueError( - "Sampling probabilities for train datasets don't sum to 1" - ) - sample_datasets = True - logger.info( - "Sampling ratios are specified; only train datasets will be interleaved." - ) - else: - sample_datasets = False + sample_datasets = self._validate_sampling_ratios( + train_sampling_probabilities, train_datasets + ) # Ensure again datasets are aligned before interleaving or concatenating maybe_align_datasets(train_datasets) @@ -588,11 +599,15 @@ def _process_dataset_configs( if eval_dataset and isinstance(eval_dataset, IterableDataset): eval_dataset = resolve_iterable_dataset_features(eval_dataset) - return train_dataset, eval_dataset + return train_dataset, eval_dataset, None def process_dataset_configs( self, dataset_configs: List[DataSetConfig] - ) -> Tuple[Union[Dataset, IterableDataset], Union[Dataset, IterableDataset]]: + ) -> Tuple[ + Union[Dataset, IterableDataset], + Union[Dataset, IterableDataset], + Dict[str, float], + ]: train_dataset = eval_dataset = None # Use partial state as recommended by HF documentation for process control @@ -605,96 +620,43 @@ def process_dataset_configs( # as we want to reuse HF cache and not redo computation on all nodes # For rationale see https://github.com/huggingface/trl/pull/3106 with state.main_process_first(): - train_dataset, eval_dataset = self._process_dataset_configs(dataset_configs) + ( + train_dataset, + eval_dataset, + sampling_weights, + ) = self._process_dataset_configs(dataset_configs) logger.info("Processed train dataset {}".format(train_dataset)) logger.info("Processed eval dataset {}".format(eval_dataset)) - return train_dataset, eval_dataset + return train_dataset, eval_dataset, sampling_weights class ODMDataPreProcessor(DataPreProcessor): - def _process_datasets_for_odm( - self, - processed_datasets: List[ - Tuple[DataSetConfig, Union[DatasetDict, IterableDatasetDict]] - ], + def _process_dataset_configs( + self, dataset_configs: List[DataSetConfig] ) -> Tuple[ Dict[str, Union[Dataset, IterableDataset]], Dict[str, Union[Dataset, IterableDataset]], + Dict[str, float], ]: + processed_datasets = self._prepare_processed_datasets(dataset_configs) train_split = "train" eval_split = "test" train_datasets_dict = {} eval_datasets_dict = {} + sampling_weights_dict = {} for d, raw in processed_datasets: + if d.sampling is not None and d.sampling > 0.0: + sampling_weights_dict[d.name] = d.sampling if train_split in raw: train_datasets_dict[d.name] = raw[train_split] if eval_split in raw: eval_datasets_dict[d.name] = raw[eval_split] - return train_datasets_dict, eval_datasets_dict - - def _process_dataset_configs( - self, dataset_configs: List[DataSetConfig] - ) -> Tuple[ - Dict[str, Union[Dataset, IterableDataset]], - Dict[str, Union[Dataset, IterableDataset]], - ]: - - if not dataset_configs: - raise ValueError( - "No dataset configs provided. Provided Dataset configs is None." - ) - - train_split = "train" # default - eval_split = "test" - - processed_datasets = [] - - logger.info("Starting DataPreProcessor...") - # Now Iterate over the multiple datasets provided to us to process - for d in dataset_configs: - logger.info("Loading the dataset - %s", d.name) - - # In future the streaming etc go as kwargs of this function - loaded_dataset = self.load_dataset(d, self.processor_config.streaming) - logger.info("Loaded raw dataset : %s", str(loaded_dataset)) - - if d.split is not None: - loaded_dataset = self.split_dataset(d, loaded_dataset) - - # Create a raw dataset which is a Dict container to house all Datasets - raw_datasets = ( - IterableDatasetDict() - if isinstance(loaded_dataset, (IterableDataset, IterableDatasetDict)) - else DatasetDict() - ) - - splits_to_keep = [train_split, eval_split] - if isinstance(loaded_dataset, (Dataset, IterableDataset)): - # Assume all is train split - raw_datasets[train_split] = loaded_dataset - else: - for k, v in loaded_dataset.items(): - if k in splits_to_keep: - raw_datasets[k] = v - - if d.data_handlers: # Execute the datahandlers - for data_handler_config in d.data_handlers: - raw_datasets = self._execute_data_handlers( - raw_datasets=raw_datasets, - data_handler_config=data_handler_config, - datasetName=d.name, - ) - - # Append the processed datasets to the final dict - processed_datasets.append((d, raw_datasets)) - logger.info( - "Sampling probabilities are ignored if provided" - "and are not used for concatenation. Instead" - "online data mixing plugin handles it." + self._validate_sampling_ratios( + sampling_weights_dict.values(), train_datasets_dict.values() ) - return self._process_datasets_for_odm(processed_datasets) + return train_datasets_dict, eval_datasets_dict, sampling_weights_dict def get_datapreprocessor( diff --git a/tuning/data/setup_dataprocessor.py b/tuning/data/setup_dataprocessor.py index 34169b4b71..bc2f843154 100644 --- a/tuning/data/setup_dataprocessor.py +++ b/tuning/data/setup_dataprocessor.py @@ -14,7 +14,7 @@ # Standard from pathlib import Path -from typing import Dict, Union +from typing import Dict, List, Union import logging # Third Party @@ -89,11 +89,12 @@ def process_dataconfig_file( is_multipack: A bool representing is Multipack plugin is enabled. Defauts to False. Returns: - Tuple(Dataset, Dataset, str) + Tuple(Dataset, Dataset, str, Dict[str, float]) tuple containing train_dataset (Dataset/IterableDataset), eval_dataset (Dataset/IterableDataset), dataset_text_field (str), + sampling weights """ data_config = load_and_validate_data_config(data_args.data_config_path) @@ -153,11 +154,13 @@ def process_dataconfig_file( ) tokenizer.chat_template = data_processor.processor_config.chat_template - train_dataset, eval_dataset = data_processor.process_dataset_configs( - data_config.datasets - ) + ( + train_dataset, + eval_dataset, + sampling_weights, + ) = data_processor.process_dataset_configs(data_config.datasets) - return (train_dataset, eval_dataset, data_args.dataset_text_field) + return (train_dataset, eval_dataset, data_args.dataset_text_field, sampling_weights) # Data Format 1: Pretokenized Data @@ -445,11 +448,13 @@ def _process_raw_data_args( if is_eval_dataset_present: dataset_configs.append(eval_dataset_config) - train_dataset, eval_dataset = data_processor.process_dataset_configs( - dataset_configs - ) + ( + train_dataset, + eval_dataset, + sampling_weights, + ) = data_processor.process_dataset_configs(dataset_configs) - return (train_dataset, eval_dataset, dataset_text_field) + return (train_dataset, eval_dataset, dataset_text_field, sampling_weights) def dump_dataset( @@ -503,6 +508,7 @@ def setup_train_dataset_for_odm( train_dataset: Dict = None, reward_dataset: Dict = None, # eval_dataset is used for reward computation max_seq_length: str = None, + sampling_weights: List[float] = None, # cold start sampling weights for ODM ): # pylint: disable=import-outside-toplevel if not is_fms_accelerate_available(plugins="odm"): @@ -547,7 +553,7 @@ def setup_train_dataset_for_odm( collators, reward_dataset, eval_collators, - None, + sampling_weights, gamma=odm_config.odm.gamma, eta=odm_config.odm.eta, output_dir=train_args.output_dir, @@ -625,7 +631,12 @@ def process_dataargs( ) if data_args.data_config_path: - train_dataset, eval_dataset, dataset_text_field = process_dataconfig_file( + ( + train_dataset, + eval_dataset, + dataset_text_field, + sampling_weights, + ) = process_dataconfig_file( data_args, train_args, tokenizer, @@ -635,7 +646,12 @@ def process_dataargs( is_padding_free, ) else: - train_dataset, eval_dataset, dataset_text_field = _process_raw_data_args( + ( + train_dataset, + eval_dataset, + dataset_text_field, + sampling_weights, + ) = _process_raw_data_args( data_args, tokenizer, train_args.packing, @@ -696,6 +712,7 @@ def process_dataargs( train_dataset, eval_dataset, max_seq_length, + sampling_weights, ) else: # Note: This check should not be removed. From 7683205555e1a9cdea7fdcd25b10428063ac1e21 Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Thu, 9 Oct 2025 18:13:34 +0530 Subject: [PATCH 7/9] fix: refactor code Signed-off-by: Mehant Kammakomati --- .../multiple_datasets_with_odm.yaml | 6 +++--- tests/test_sft_trainer.py | 7 +++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/artifacts/predefined_data_configs/multiple_datasets_with_odm.yaml b/tests/artifacts/predefined_data_configs/multiple_datasets_with_odm.yaml index 6b5bda6986..74c72a5fdd 100644 --- a/tests/artifacts/predefined_data_configs/multiple_datasets_with_odm.yaml +++ b/tests/artifacts/predefined_data_configs/multiple_datasets_with_odm.yaml @@ -13,7 +13,7 @@ datasets: split: train: 0.8 validation: 0.2 # validation set is also used in ODM reward computation when reward_type is validation_loss. - sampling: 0.3 # ignored + sampling: 0.3 # used as starting weights for online data mixing data_paths: - "FILE_PATH" data_handlers: @@ -28,7 +28,7 @@ datasets: split: train: 0.6 validation: 0.2 # validation set is also used in ODM reward computation when reward_type is validation_loss. - sampling: 0.4 # ignored + sampling: 0.4 # used as starting weights for online data mixing data_paths: - "FILE_PATH" data_handlers: @@ -43,7 +43,7 @@ datasets: split: train: 0.4 validation: 0.1 # validation set is also used in ODM reward computation when reward_type is validation_loss. - sampling: 0.3 # ignored + sampling: 0.3 # used as starting weights for online data mixing data_paths: - "FILE_PATH" data_handlers: diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index aee20293ea..bea4921501 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -2263,8 +2263,12 @@ def test_online_data_mixing_plugin_sample_training( data = yaml.safe_load(f) data["dataprocessor"]["odm"]["reward_type"] = reward_type data["datasets"] = data["datasets"][:2] + sampling_weights = [0.4, 0.6] + i = 0 for d, df in zip(data["datasets"], datafiles): d["data_paths"] = [df] + d["sampling"] = sampling_weights[i] + i += 1 yaml.dump(data, temp_yaml_file) data_formatting_args.data_config_path = temp_yaml_file.name @@ -2342,9 +2346,12 @@ def test_online_data_mixing_plugin_sample_training_no_validation_split( data = yaml.safe_load(f) data["datasets"] = data["datasets"][:2] data["dataprocessor"]["odm"]["reward_type"] = reward_type + i = 0 for d, df in zip(data["datasets"], datafiles): d["data_paths"] = [df] + d["sampling"] = sampling_weights[i] del d["split"] + i += 1 yaml.dump(data, temp_yaml_file) data_formatting_args.data_config_path = temp_yaml_file.name From 090ee2f9c34035dff4b25b7dd9d8695398b4955a Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Thu, 9 Oct 2025 18:15:22 +0530 Subject: [PATCH 8/9] fix: refactor code Signed-off-by: Mehant Kammakomati --- tests/test_sft_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index bea4921501..f9f1d3810b 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -2347,6 +2347,7 @@ def test_online_data_mixing_plugin_sample_training_no_validation_split( data["datasets"] = data["datasets"][:2] data["dataprocessor"]["odm"]["reward_type"] = reward_type i = 0 + sampling_weights = [0.4, 0.6] for d, df in zip(data["datasets"], datafiles): d["data_paths"] = [df] d["sampling"] = sampling_weights[i] From dd2532ef163cdf7506ce36fb32da5a36901ffe6f Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Thu, 9 Oct 2025 18:53:04 +0530 Subject: [PATCH 9/9] fix: refactor sampling weight Signed-off-by: Mehant Kammakomati --- tests/data/test_data_preprocessing.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/data/test_data_preprocessing.py b/tests/data/test_data_preprocessing.py index 22ff407482..a1072d2ece 100644 --- a/tests/data/test_data_preprocessing.py +++ b/tests/data/test_data_preprocessing.py @@ -768,7 +768,7 @@ def test_process_dataconfig_file_with_streaming(data_config_path, data_path): output_dir="tmp", # Not needed but positional ) - (train_set, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer) + (train_set, _, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer) assert isinstance(train_set, IterableDataset) if datasets_name == "text_dataset_input_output_masking": column_names = set(["input_ids", "attention_mask", "labels"]) @@ -1017,7 +1017,7 @@ def test_process_dataconfig_file(data_config_path, data_path): output_dir="tmp", # Not needed but positional ) - (train_set, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer) + (train_set, _, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer) assert isinstance(train_set, Dataset) if datasets_name == "text_dataset_input_output_masking": column_names = set(["input_ids", "attention_mask", "labels"]) @@ -1107,7 +1107,7 @@ def test_process_datahandler_eos_token(data_config_path, data_path, add_eos_toke output_dir="tmp", # Not needed but positional ) - (train_set, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer) + (train_set, _, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer) assert isinstance(train_set, Dataset) if datasets_name == "text_dataset_input_output_masking": column_names = set(["input_ids", "attention_mask", "labels"]) @@ -1258,7 +1258,7 @@ def test_process_dataconfig_multiple_files(data_config_path, data_path_list): output_dir="tmp", # Not needed but positional ) - (train_set, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer) + (train_set, _, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer) assert isinstance(train_set, Dataset) if datasets_name == "text_dataset_input_output_masking": column_names = set(["input_ids", "attention_mask", "labels"]) @@ -1330,7 +1330,7 @@ def test_process_dataconfig_multiple_files_folders_with_globbing( output_dir="tmp", # Not needed but positional ) - (train_set, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer) + (train_set, _, _, _) = process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer) assert isinstance(train_set, Dataset) assert set(["input_ids", "attention_mask", "labels"]).issubset( set(train_set.column_names)