Skip to content

Commit e206f23

Browse files
committed
fix: refactor
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
1 parent d5db867 commit e206f23

4 files changed

Lines changed: 102 additions & 130 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
- [Advanced Data Processing](./docs/advanced-data-preprocessing.md#data-config)
99
- [Guidelines on supported data formats](./docs/advanced-data-preprocessing.md#use-cases-supported-via-command-line-argument-training_data_path)
1010
- [Offline data processing](#offline-data-preprocessing)
11+
- [Online data mixing](./docs/online-data-mixing.md)
1112
- [Additional Frameworks](#additional-frameworks)
1213
- [Inference](#inference)
1314
- [Validation](#validation)

docs/advanced-data-preprocessing.md

Lines changed: 8 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ Our library also supports a powerful data processing backend which can be used b
44
1. Creating custom data processing pipeline for the datasets.
55
1. Combining multiple datasets into one, even if they have different formats.
66
1. Mixing datasets as required and sampling each dataset with different weights.
7-
1. Dynamically mixing datasets online based on training signals through fms_acceleration_odm plugin.
87

98
These things are supported via what we call a [`data_config`](#data-config) which can be passed as an argument to sft trainer.
109

@@ -137,6 +136,14 @@ Users can create a data config file in any of YAML or JSON format they choose (w
137136
- `chat_template` (optional, str): pass `chat_template` via data_config for multi-turn data, replaces existing default chat template.
138137
- `odm` (optional): if `type` is odm, this field is required to be specific to provide configuration for online data mixing.
139138

139+
Data handlers are customizable components within the data config that allow users to preprocess or manipulate individual datasets. We use [Hugging Face Map API](https://huggingface.co/docs/datasets/en/process#map) to apply these routines.
140+
These functions can process the dataset in any way users require and the `list` of data handlers specified for each dataset are applied in order.
141+
Each data handler has:
142+
- `name`: The handler's unique identifier.
143+
- `arguments`: A dictionary of parameters specific to the handler.
144+
145+
#### Online data mixing section
146+
140147
`odm` config has the following fields and is required when `datapreprocessor` `type` is `odm`.
141148

142149
`odm`:
@@ -154,11 +161,6 @@ Users can create a data config file in any of YAML or JSON format they choose (w
154161
- `split` (optional, dict[str: float]): Defines how to split the dataset into training and validation sets. Requires both `train` and `validation` keys.
155162
- `data_handlers` (optional, list): A list of data handler configurations which preprocess the dataset.
156163

157-
Data handlers are customizable components within the data config that allow users to preprocess or manipulate individual datasets. We use [Hugging Face Map API](https://huggingface.co/docs/datasets/en/process#map) to apply these routines.
158-
These functions can process the dataset in any way users require and the `list` of data handlers specified for each dataset are applied in order.
159-
Each data handler has:
160-
- `name`: The handler's unique identifier.
161-
- `arguments`: A dictionary of parameters specific to the handler.
162164

163165
We do provide some sample `data_configs` here, [predefined_data_configs](../tests/artifacts/predefined_data_configs/).
164166

@@ -203,58 +205,6 @@ We also allow users to pass a [`seed`](https://huggingface.co/docs/datasets/v3.2
203205

204206
Note: If a user specifies data sampling they can expect the datasets to be mixed and individual samples in the dataset to not be broken unless the max_seq_len argument is smaller than the length of individual samples in the dataset
205207

206-
### Online Data Mixing
207-
Dataset mixing can be dynamic in nature that adapts online during the training based on the training signals. We provide this feature through fms_acceleration_odm plugin and more details can be found [here](https://github.com/foundation-model-stack/fms-acceleration/tree/main/plugins/online-data-mixing).
208-
209-
#### How to Use
210-
211-
`dataprocessor` `type` has to be set to `odm` and then `odm` config should be provided in the `odm` section of the data config file. An example is shown below:
212-
213-
```yaml
214-
dataprocessor:
215-
type: odm
216-
odm:
217-
update_interval: 1 # update every step
218-
sampling_interval: 1 # sample category for every sample
219-
reward_type: validation_loss # uses eval loss of each dataset as reward
220-
gamma: 0.1 # MAB hyper-parameter
221-
eta: 0.2 # MAB hyper-parameter
222-
```
223-
224-
Here `update_interval` is set to `1` which is to update MAB on every step with validation loss as reward across the datasets. `sampling_interval` is set to `1` which is to choose a dataset to sample for every sample. `reward_type` is set to `validation_loss` to use validation loss across datasets as a training signal to reward MAB decisions during training. Example `datasets` section can look like below:
225-
226-
```yaml
227-
datasets:
228-
- name: dataset_1
229-
split:
230-
train: 0.8
231-
validation: 0.2
232-
data_paths:
233-
- "FILE_PATH"
234-
data_handlers:
235-
- name: tokenize_and_apply_input_masking
236-
arguments:
237-
remove_columns: all
238-
batched: false
239-
fn_kwargs:
240-
input_column_name: input
241-
output_column_name: output
242-
- name: dataset_2
243-
split:
244-
train: 0.9
245-
validation: 0.1
246-
data_paths:
247-
- "FILE_PATH"
248-
data_handlers:
249-
- name: tokenize_and_apply_input_masking
250-
arguments:
251-
remove_columns: all
252-
batched: false
253-
fn_kwargs:
254-
input_column_name: input
255-
output_column_name: output
256-
```
257-
As you notice, `validation` under `split` is provided for each of the datasets and is necessary to be provided since the `reward_type` is `validation_loss` which requires validation datasets to be available. Same applies to the following rewards: `validation_loss`, `entropy`, `entropy3_varent1`, and `entropy_last_token`. While reward_types `train_loss` and `gradnorm` do not require validation split.
258208

259209
### Dataset Splitting
260210

tests/artifacts/predefined_data_configs/multiple_datasets_with_odm.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ datasets:
1212
- name: dataset_1
1313
split:
1414
train: 0.8
15-
validation: 0.2
15+
validation: 0.2 # validation set is also used in reward computation when reward_type is validation_loss.
1616
sampling: 0.3 # ignored
1717
data_paths:
1818
- "FILE_PATH"
@@ -27,7 +27,7 @@ datasets:
2727
- name: dataset_2
2828
split:
2929
train: 0.6
30-
validation: 0.2
30+
validation: 0.2 # validation set is also used in reward computation when reward_type is validation_loss.
3131
sampling: 0.4 # ignored
3232
data_paths:
3333
- "FILE_PATH"
@@ -42,7 +42,7 @@ datasets:
4242
- name: dataset_3
4343
split:
4444
train: 0.4
45-
validation: 0.1
45+
validation: 0.1 # validation set is also used in reward computation when reward_type is validation_loss.
4646
sampling: 0.3 # ignored
4747
data_paths:
4848
- "FILE_PATH"
@@ -57,7 +57,7 @@ datasets:
5757
- name: dataset_4
5858
split:
5959
train: 0.0
60-
validation: 0.3 # ignored
60+
validation: 0.3 # validation set is also used in reward computation when reward_type is validation_loss.
6161
data_paths:
6262
- "FILE_PATH"
6363
data_handlers:
@@ -67,4 +67,4 @@ datasets:
6767
batched: false
6868
fn_kwargs:
6969
input_column_name: input
70-
output_column_name: output
70+
output_column_name: output

tuning/data/setup_dataprocessor.py

Lines changed: 88 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# Unless required by applicable law or agreed to in writing, software
1010
# distributed under the License is distributed on an "AS IS" BASIS,
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
# See the License for the specific language governing permissions and
12+
# See the License for the specificm language governing permissions and
1313
# limitations under the License.
1414

1515
# Standard
@@ -495,6 +495,72 @@ def dump_dataset(
495495
raise RuntimeError(f"Failed to dump dataset due to error {e}") from e
496496

497497

498+
def process_dataargs_odm(
499+
data_args: DataArguments,
500+
tokenizer: AutoTokenizer,
501+
train_args: TrainingArguments,
502+
is_padding_free: bool = False,
503+
processor: AutoProcessor = None,
504+
odm_config: ODMConfig = None,
505+
train_dataset: Dict = None,
506+
eval_dataset: Dict = None,
507+
max_seq_length: str = None,
508+
):
509+
collators = {}
510+
eval_collators = {}
511+
for k, v in train_dataset.items():
512+
is_tokenized_dataset = is_pretokenized_dataset(v)
513+
collators[k] = get_data_collator(
514+
train_args.packing,
515+
data_args.response_template,
516+
tokenizer,
517+
is_tokenized_dataset,
518+
max_seq_length,
519+
data_args.instruction_template,
520+
is_padding_free=is_padding_free,
521+
processor=processor,
522+
)
523+
data_collator = collators[k]
524+
for k, v in eval_dataset.items():
525+
is_tokenized_dataset = is_pretokenized_dataset(v)
526+
eval_collators[k] = get_data_collator(
527+
train_args.packing,
528+
data_args.response_template,
529+
tokenizer,
530+
is_tokenized_dataset,
531+
max_seq_length,
532+
data_args.instruction_template,
533+
is_padding_free=is_padding_free,
534+
processor=processor,
535+
)
536+
537+
# pylint: disable=import-outside-toplevel
538+
if not is_fms_accelerate_available(plugins="odm"):
539+
raise ImportError(
540+
"use of odm data config feature requires"
541+
"installation of fms_acceleration_odm package"
542+
)
543+
# Third Party
544+
# pylint: disable=import-error
545+
from fms_acceleration_odm import OnlineMixingDataset
546+
547+
train_dataset = OnlineMixingDataset(
548+
train_dataset,
549+
collators,
550+
eval_dataset,
551+
eval_collators,
552+
None,
553+
gamma=odm_config.odm.gamma,
554+
eta=odm_config.odm.eta,
555+
output_dir=train_args.output_dir,
556+
sampling_interval=odm_config.odm.sampling_interval,
557+
eval_batch_size=train_args.per_device_eval_batch_size,
558+
reward_type=odm_config.odm.reward_type,
559+
)
560+
train_args.accelerator_config = {"split_batches": True}
561+
return (True, train_dataset, True, data_collator)
562+
563+
498564
# If a data config file is provided, load it to get the training dataset.
499565
# - Assumes only the training dataset is specified in the config file.
500566
# - Expects a complete and valid data config file from the user.
@@ -595,10 +661,6 @@ def process_dataargs(
595661
"Check your data config or ensure split sizes are valid."
596662
)
597663
if data_args.do_dataprocessing_only:
598-
if odm_config:
599-
raise ValueError(
600-
"data processing with online data mixing is not currently supported"
601-
)
602664
dump_dir = Path(train_args.output_dir)
603665
if not dump_dir.is_absolute():
604666
dump_dir = dump_dir.absolute()
@@ -621,44 +683,31 @@ def process_dataargs(
621683
)
622684
return (train_dataset, eval_dataset, None, None, None, None)
623685

624-
# Note: This check should not be removed.
625-
# Its important to recompute this post handling to
626-
# check if we already tokenized the dataset or not.
686+
dataset_kwargs = {}
687+
data_collator = None
627688
if odm_config:
628689
is_tokenized_dataset = True
690+
(
691+
dataset_kwargs["skip_prepare_dataset"],
692+
train_dataset,
693+
dataset_kwargs,
694+
data_collator,
695+
) = process_dataargs_odm(
696+
data_args,
697+
tokenizer,
698+
train_args,
699+
is_padding_free,
700+
processor,
701+
odm_config,
702+
train_dataset,
703+
eval_dataset,
704+
max_seq_length,
705+
)
629706
else:
707+
# Note: This check should not be removed.
708+
# Its important to recompute this post handling to
709+
# check if we already tokenized the dataset or not.
630710
is_tokenized_dataset = is_pretokenized_dataset(train_dataset or eval_dataset)
631-
632-
data_collator = None
633-
if odm_config:
634-
collators = {}
635-
eval_collators = {}
636-
for k, v in train_dataset.items():
637-
is_tokenized_dataset = is_pretokenized_dataset(v)
638-
collators[k] = get_data_collator(
639-
train_args.packing,
640-
data_args.response_template,
641-
tokenizer,
642-
is_tokenized_dataset,
643-
max_seq_length,
644-
data_args.instruction_template,
645-
is_padding_free=is_padding_free,
646-
processor=processor,
647-
)
648-
data_collator = collators[k]
649-
for k, v in eval_dataset.items():
650-
is_tokenized_dataset = is_pretokenized_dataset(v)
651-
eval_collators[k] = get_data_collator(
652-
train_args.packing,
653-
data_args.response_template,
654-
tokenizer,
655-
is_tokenized_dataset,
656-
max_seq_length,
657-
data_args.instruction_template,
658-
is_padding_free=is_padding_free,
659-
processor=processor,
660-
)
661-
else:
662711
data_collator = get_data_collator(
663712
train_args.packing,
664713
data_args.response_template,
@@ -669,34 +718,6 @@ def process_dataargs(
669718
is_padding_free=is_padding_free,
670719
processor=processor,
671720
)
672-
dataset_kwargs = {}
673-
if odm_config:
674-
# Third Party
675-
# pylint: disable=import-outside-toplevel
676-
if not is_fms_accelerate_available(plugins="odm"):
677-
raise ImportError(
678-
"use of odm data config feature requires"
679-
"installation of fms_acceleration_odm package"
680-
)
681-
# Third Party
682-
# pylint: disable=import-error
683-
from fms_acceleration_odm import OnlineMixingDataset
684-
685-
train_dataset = OnlineMixingDataset(
686-
train_dataset,
687-
collators,
688-
eval_dataset,
689-
eval_collators,
690-
None,
691-
gamma=odm_config.odm.gamma,
692-
eta=odm_config.odm.eta,
693-
output_dir=train_args.output_dir,
694-
sampling_interval=odm_config.odm.sampling_interval,
695-
eval_batch_size=train_args.per_device_eval_batch_size,
696-
reward_type=odm_config.odm.reward_type,
697-
)
698-
dataset_kwargs["skip_prepare_dataset"] = True
699-
train_args.accelerator_config = {"split_batches": True}
700721

701722
# For vision model tuning prepare_dataset is skipped.
702723
if processor is not None:

0 commit comments

Comments
 (0)