From 7294240a426f78dc1812411568b0360f52571e11 Mon Sep 17 00:00:00 2001 From: romitjain Date: Wed, 3 Dec 2025 11:59:47 +0000 Subject: [PATCH] Updated ODM defaults Signed-off-by: romitjain --- docs/advanced-data-preprocessing.md | 12 +++++++----- tuning/config/acceleration_configs/odm.py | 6 +++--- tuning/data/setup_dataprocessor.py | 2 +- tuning/sft_trainer.py | 3 +++ 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/docs/advanced-data-preprocessing.md b/docs/advanced-data-preprocessing.md index b08e5476be..0289575969 100644 --- a/docs/advanced-data-preprocessing.md +++ b/docs/advanced-data-preprocessing.md @@ -147,11 +147,13 @@ Each data handler has: `odm` config has the following fields and is required when `datapreprocessor` `type` is `odm`. `odm`: - `update_interval` (optional, int, defaults to `1`): Multi-Armed Bandit (MAB) is used to learn from the training signals and then provide mixing probabilities across datasets. `update_interval` defines the frequency of updating the MAB with training signals in terms of step count. - `sampling_interval` (optional, int, defaults to `1`): Defines the frequency of choosing a dataset to sample from through MAB. The value is provided in terms of sample count. - `reward_type` (optional, str, defaults to `entropy`): Type of reward to be used to update MAB. Currently supported rewards are `train_loss`, `validation_loss`, `entropy`, `entropy3_varent1`, `entropy_last_token`, `gradnorm`. More details can be found [here](https://github.com/foundation-model-stack/fms-acceleration/tree/main/plugins/online-data-mixing#rewards). - `gamma` (optional, int, defaults to `0.1`): MAB hyper-parameter which is similar to exploration factor. - `eta` (optional, int, defaults to `0.1`): MAB hyper-parameter which is similar to learning rate. +- `update_interval` (optional, int, defaults to `None`): Multi-Armed Bandit (MAB) is used to learn from the training signals and then provide mixing probabilities across datasets. 
`update_interval` defines the frequency of updating the MAB with training signals in terms of step count. If not provided, it defaults to `eval_steps`. +- `sampling_interval` (optional, int, defaults to `1`): Defines the frequency of choosing a dataset to sample from through MAB. The value is provided in terms of sample count. +- `reward_type` (optional, str, defaults to `entropy`): Type of reward to be used to update MAB. Currently supported rewards are `train_loss`, `validation_loss`, `entropy`, `entropy3_varent1`, `entropy_last_token`, `gradnorm`. More details can be found [here](https://github.com/foundation-model-stack/fms-acceleration/tree/main/plugins/online-data-mixing#rewards). +- `gamma` (optional, float, defaults to `0.1`): MAB hyper-parameter which is similar to exploration factor. +- `eta` (optional, float, defaults to `0.3`): MAB hyper-parameter which is similar to learning rate. +- `auto_categorize_input_column` (optional, str, defaults to `None`): If only a single dataset is provided, this field is required to determine the column name which should be used to categorize the data into pseudo categories. +- `auto_categorize_num_categories` (optional, int, defaults to `None`): Used in conjunction with the above field, this field specifies the number of pseudo categories to be assigned in the dataset. `datasets` (list): - `name` (optional, str): A unique identifier for the dataset. 
diff --git a/tuning/config/acceleration_configs/odm.py b/tuning/config/acceleration_configs/odm.py index 83a3b69cbb..31ede29692 100644 --- a/tuning/config/acceleration_configs/odm.py +++ b/tuning/config/acceleration_configs/odm.py @@ -24,10 +24,10 @@ @dataclass class ODM: update_interval: int = None - sampling_interval: int = None - reward_type: str = None + sampling_interval: int = 1 + reward_type: str = "entropy" gamma: float = 0.1 - eta: float = 0.1 + eta: float = 0.3 resume_from_checkpoint: Union[bool, str] = False auto_categorize_input_column: str = None auto_categorize_num_categories: Optional[int] = None diff --git a/tuning/data/setup_dataprocessor.py b/tuning/data/setup_dataprocessor.py index 9eb837232f..9b6b4ce23f 100644 --- a/tuning/data/setup_dataprocessor.py +++ b/tuning/data/setup_dataprocessor.py @@ -549,7 +549,7 @@ def setup_train_dataset_for_odm( ) auto_categorize_config = {} - if hasattr(odm_config.odm, "auto_categorize_input_column"): + if odm_config.odm.auto_categorize_input_column: auto_categorize_config = { "input_column": "input_ids", "num_categories": int(odm_config.odm.auto_categorize_num_categories), diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index dc6a74174b..7a6a11d46a 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -160,6 +160,9 @@ def train( "resume_from_checkpoint" ] = resume_from_checkpoint odm_config = ODMConfig(odm=ODM(**_dataconfig.dataprocessor.odm)) + odm_config.odm.update_interval = ( + odm_config.odm.update_interval or train_args.eval_steps + ) # Validate parameters if (not isinstance(model_args.model_name_or_path, str)) or (