Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions docs/advanced-data-preprocessing.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,13 @@ Each data handler has:
`odm` config has the following fields and is required when `datapreprocessor` `type` is `odm`.

`odm`:
`update_interval` (optional, int, defaults to `1`): Multi-Armed Bandit (MAB) is used to learn from the training signals and then provide mixing probabilities across datasets. `update_interval` defines the frequency of updating the MAB with training signals in terms of step count.
`sampling_interval` (optional, int, defaults to `1`): Defines the frequency of choosing a dataset to sample from through MAB. The value is provided in terms of sample count.
`reward_type` (optional, str, defaults to `entropy`): Type of reward to be used to update MAB. Currently supported rewards are `train_loss`, `validation_loss`, `entropy`, `entropy3_varent1`, `entropy_last_token`, `gradnorm`. More details can be found [here](https://github.com/foundation-model-stack/fms-acceleration/tree/main/plugins/online-data-mixing#rewards).
`gamma` (optional, int, defaults to `0.1`): MAB hyper-parameter which is similar to exploration factor.
`eta` (optional, int, defaults to `0.1`): MAB hyper-parameter which is similar to learning rate.
- `update_interval` (optional, int, defaults to `None`): Multi-Armed Bandit (MAB) is used to learn from the training signals and then provide mixing probabilities across datasets. `update_interval` defines the frequency of updating the MAB with training signals in terms of step count. If not provided, it defaults to `eval_steps`
- `sampling_interval` (optional, int, defaults to `1`): Defines the frequency of choosing a dataset to sample from through MAB. The value is provided in terms of sample count.
- `reward_type` (optional, str, defaults to `entropy`): Type of reward to be used to update MAB. Currently supported rewards are `train_loss`, `validation_loss`, `entropy`, `entropy3_varent1`, `entropy_last_token`, `gradnorm`. More details can be found [here](https://github.com/foundation-model-stack/fms-acceleration/tree/main/plugins/online-data-mixing#rewards).
- `gamma` (optional, int, defaults to `0.1`): MAB hyper-parameter which is similar to exploration factor.
- `eta` (optional, int, defaults to `0.3`): MAB hyper-parameter which is similar to learning rate.
- `auto_categorize_input_column` (optional, str, defaults to `None`): If only a single dataset is provided, this field is required to determin the column name which should be used to categorize the data into psuedo categories
- `auto_categorize_num_categories` (optional, int, defaults to `None`): Used in conjunction with the above field, this field specifies the number of psuedo categories to be assigned in the dataset

`datasets` (list):
- `name` (optional, str): A unique identifier for the dataset.
Expand Down
6 changes: 3 additions & 3 deletions tuning/config/acceleration_configs/odm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
@dataclass
class ODM:
update_interval: int = None
sampling_interval: int = None
reward_type: str = None
sampling_interval: int = 1
reward_type: str = "entropy"
gamma: float = 0.1
eta: float = 0.1
eta: float = 0.3
resume_from_checkpoint: Union[bool, str] = False
auto_categorize_input_column: str = None
auto_categorize_num_categories: Optional[int] = None
Expand Down
2 changes: 1 addition & 1 deletion tuning/data/setup_dataprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ def setup_train_dataset_for_odm(
)

auto_categorize_config = {}
if hasattr(odm_config.odm, "auto_categorize_input_column"):
if odm_config.odm.auto_categorize_input_column:
auto_categorize_config = {
"input_column": "input_ids",
"num_categories": int(odm_config.odm.auto_categorize_num_categories),
Expand Down
3 changes: 3 additions & 0 deletions tuning/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ def train(
"resume_from_checkpoint"
] = resume_from_checkpoint
odm_config = ODMConfig(odm=ODM(**_dataconfig.dataprocessor.odm))
odm_config.odm.update_interval = (
odm_config.odm.update_interval or train_args.eval_steps
)

# Validate parameters
if (not isinstance(model_args.model_name_or_path, str)) or (
Expand Down
Loading