
Commit 3eccc8a

docs: add docs
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
1 parent dcc2d07 commit 3eccc8a

3 files changed

Lines changed: 54 additions & 20 deletions


plugins/online-data-mixing/README.md

Lines changed: 36 additions & 0 deletions
@@ -13,6 +13,42 @@ Plugin | Description | Depends | Loading | Augmentation | Callbacks

## Usage in Custom Training Loop

![](./artifacts/plot.png)

`OnlineMixingDataset` can be imported and integrated into existing training loops with minimal changes. A sample custom training loop implementation can be found [here](./artifacts/custom_loop_usage.py). The sample uses two instruction-tuning datasets and trains the `ibm-granite/granite-3.1-2b-instruct` model on a next-token prediction task.
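The shape of that integration is sketched below. This is only a rough illustration: the import path, the `batch_size=None` dataloader setting, and any constructor arguments beyond `dataset_dict` and `collators_dict` are assumptions, and names such as `tokenized_dataset_dict`, `collator_dict`, `model`, and `optimizer` are prepared exactly as in the linked sample script, which remains the authoritative version.

```python
# Minimal sketch, assuming the import path and arguments shown; see
# artifacts/custom_loop_usage.py for the full, working example.
from torch.utils.data import DataLoader

from fms_acceleration_odm import OnlineMixingDataset  # assumed import path

dataset = OnlineMixingDataset(
    dataset_dict=tokenized_dataset_dict,  # per-category tokenized datasets
    collators_dict=collator_dict,         # per-category data collators
    # ... remaining constructor arguments as in the sample script
)

# assumption: the mixing dataset yields already-collated batches
dataloader = DataLoader(dataset, batch_size=None)

model.train()
for step, batch in enumerate(dataloader):
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```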
## Metrics

All metrics related to online data mixing are logged to the `odm.jsonl` file in the checkpoint output directory; a short snippet for inspecting this log is sketched after the table below.

Metric | Description
--|--
`samples_produced_so_far` | Total number of samples produced by the dataset at the time of logging.
`sampling_interval` | Sample count `n` at which mixing decisions are made: every `n` steps a category/dataset is chosen by weighted random sampling, with weights provided by the Multi-Armed Bandit (MAB) algorithm.
`total_categories` | Total number of categories/datasets involved in mixing.
`current_sampling_weights` | Current state of the sampling weights at the time of logging.
`current_sampling_ratio` | Current state of the sampling ratios at the time of logging.
`arm_idx` | Index of the last sampled category. Categories/datasets are sorted in ascending order by name and indexed from 0, so each index corresponds to one category/dataset.
`category_level_counts_so_far` | Per-category/dataset split of the sample count at the time of logging.
`rewards` | State of the rewards at the time of logging, i.e. the most recently provided reward for each category/dataset.
`action` | Type of action that took place at the time of logging: either `"update"` (a MAB weight update) or `"sample"` (a category sampling step).

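As an illustration of how these records can be consumed (this snippet is not part of the plugin), the log can be read with standard-library JSON handling. The only assumptions are that each line of `odm.jsonl` is one JSON object with the fields listed above and that the example output directory matches the sample script; adjust the path to your own run.

```python
# Minimal sketch: track how the sampling weights evolve during training.
import json

weights_over_time = []
with open("odm_custom_use/odm.jsonl") as f:  # example path only
    for line in f:
        record = json.loads(line)
        if record.get("action") == "sample":
            weights_over_time.append(record["current_sampling_weights"])

print(f"{len(weights_over_time)} sampling events logged")
if weights_over_time:
    print("latest sampling weights:", weights_over_time[-1])
```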
## Rewards

The currently available rewards are listed below. We are constantly looking to improve the existing rewards and to add new ones, and we encourage users to identify rewards that help their use cases. A small sketch of the entropy computation is shown after the table.

Reward | Description
--|--
`ENTROPY` | Shannon entropy of the logits, averaged across all tokens. Higher entropy means the model requires more samples from that category/dataset.
`ENTROPY3_VARENT1` | Combination of 3 parts Shannon entropy and 1 part variance of the entropy. Higher values mean more samples are required.
`ENTROPY_LAST_TOKEN` | Shannon entropy of the last token in the sample. Higher values mean more samples are required.
`TRAIN_LOSS` | Training loss, maintained per category and updated with the latest loss of the sampled category/dataset. Higher values mean more samples are required.
`VALIDATION_LOSS` | Validation loss per category, computed on the evaluation dataset of each category. Higher values mean more samples are required.
`GRADNORM` | Gradient norm, maintained per category and updated with the latest value of the sampled category/dataset. Higher values mean fewer samples should be drawn from that category/dataset.

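For intuition, the `ENTROPY` reward described in the table can be sketched as the token-averaged Shannon entropy of the model's output distribution. The snippet below only illustrates that formula; it is not the plugin's actual `compute_reward` implementation, which may differ in details.

```python
# Illustrative sketch of a token-averaged Shannon entropy reward signal;
# the plugin's real computation lives in compute_reward and may differ.
import torch
import torch.nn.functional as F


def mean_token_entropy(logits: torch.Tensor) -> float:
    """logits: (batch, seq_len, vocab_size) raw model outputs."""
    log_probs = F.log_softmax(logits, dim=-1)
    token_entropy = -(log_probs.exp() * log_probs).sum(dim=-1)  # (batch, seq_len)
    return token_entropy.mean().item()  # higher -> sample this category more
```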
### Adding a custom reward

Custom rewards can be added by extending the `compute_reward` function and adding a corresponding entry to the `Reward` enum. If the custom reward requires a specific set of information from the training loop, the `_extract_information_from_state_for_reward` function, a member function of `OnlineMixingDataset`, has to be extended to extract that information from the trainer state.
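To make the pattern concrete, a standalone sketch is shown below. It does not use the plugin's real `Reward` enum or `compute_reward` signature (both may differ; check `src/fms_acceleration_odm/odm/dataset.py`), and the `TOKEN_COUNT` reward and `info` dictionary are hypothetical, existing only for illustration.

```python
# Standalone sketch of the extension pattern described above; the real
# Reward enum and compute_reward signature in the plugin may differ.
from enum import Enum


class Reward(Enum):
    ENTROPY = "entropy"          # existing reward (see table above)
    TOKEN_COUNT = "token_count"  # hypothetical custom reward


def compute_reward(reward: Reward, info: dict) -> float:
    """Dispatch on the reward type; only the hypothetical branch is shown."""
    if reward == Reward.TOKEN_COUNT:
        # hypothetical: favour categories whose samples carry more tokens,
        # using information surfaced from the trainer state
        return float(info["num_tokens"])
    raise ValueError(f"unsupported reward: {reward}")
```

If the custom reward needs values that only the trainer knows (for example the latest logged loss), `_extract_information_from_state_for_reward` on `OnlineMixingDataset` is where that information would be pulled out of the trainer state and made available to the reward computation.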

### Planned TODOs
Please see issue [#153](https://github.com/foundation-model-stack/fms-acceleration/issues/153).

plugins/online-data-mixing/artifacts/custom_loop_usage.py

Lines changed: 16 additions & 18 deletions
@@ -1,10 +1,14 @@
+# Run command
+# CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file fms-acceleration/scripts/benchmarks/accelerate.yaml
+# --num_processes=2 --main_process_port=29511 custom_loop_usage.py
+
 # Standard
 import json
 import os

 # Third Party
 from accelerate import Accelerator, DataLoaderConfiguration
-from datasets import concatenate_datasets, load_dataset
+from datasets import load_dataset
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import (
@@ -19,10 +23,14 @@

 model_name = "ibm-granite/granite-3.1-2b-instruct"
 output_dir = "./odm_custom_use"
-max_steps = 400
+max_steps = 125
 batch_size = 12
 log_file = os.path.join(output_dir, "loss.jsonl")

+# odm related
+step_idx = 0
+update_interval = 1 # every step
+
 # model
 model = AutoModelForCausalLM.from_pretrained(model_name)

@@ -38,10 +46,6 @@ def tokenize_fn(examples):
 )


-# Third Party
-from datasets import load_dataset
-from transformers import AutoTokenizer, DataCollatorForLanguageModeling
-
 dataset_dict = {
     "alpaca": load_dataset("tatsu-lab/alpaca", split="train[:1%]"),
     "oasst": load_dataset("hakurei/open-instruct-v1", split="train[:1%]"),
@@ -53,8 +57,6 @@ def format_example(example):
         prompt = f"Instruction: {example['instruction']}\nInput: {example.get('input','')}\nOutput: {example['output']}"
     elif "text" in example:
         prompt = example["text"]
-    else:
-        raise ValueError("Dataset schema not supported")
     return {"text": prompt}


@@ -83,8 +85,7 @@ def tokenize_fn(examples):
     for name in dataset_dict
 }

-# odm related
-update_interval = 1 # every step
+# dataset preparation
 dataset = OnlineMixingDataset(
     dataset_dict=dataset_dict,
     collators_dict=collator_dict,
@@ -104,17 +105,17 @@ def tokenize_fn(examples):
 # training setup
 optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

-model.train()
-
-step_idx = 0
-

+# Trainer state
 class State:
     log_history: list = []


 state = State()
+
+
 # custom training loop
+model.train()
 for step, batch in enumerate(
     tqdm(dataloader, disable=not accelerator.is_local_main_process)
 ):
@@ -141,7 +142,4 @@ class State:
     if step_idx > max_steps:
         break

-print("training completed!")
-
-
-# CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file /workspace/fms-acceleration/scripts/benchmarks/accelerate.yaml --num_processes=2 --main_process_port=29511 custom_loop_usage.py
+print("Training completed!")

plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py

Lines changed: 2 additions & 2 deletions
@@ -154,7 +154,7 @@ def __init__(
             "total_categories": self.total_categories,
             "current_sampling_weights": self.sampling_weights.tolist(),
             "current_sampling_ratio": self.sampling_ratio,
-            "arm_dix": self.arm_idx,
+            "arm_idx": self.arm_idx,
             "category_level_counts_so_far": self.curr_cat_count,
             "rewards": [0] * self.total_categories,
             "count": 0,
@@ -223,7 +223,7 @@ def __next__(self):

         self.log_to_file(
             {
-                "arm_dix": self.arm_idx,
+                "arm_idx": self.arm_idx,
                 "samples_produced_so_far": self.produced,
                 "category_level_counts_so_far": self.curr_cat_count,
                 "action": "sample",
