Skip to content

Commit dc3ef05

Browse files
committed
fix: add custom loop and docs
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
1 parent 87fe0e8 commit dc3ef05

3 files changed

Lines changed: 92 additions & 0 deletions

File tree

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,21 @@
11
# Online Data Mixing
2+
3+
This library contains a plugin for an online, reward-based (learnable) data mixing framework. The framework mixes datasets dynamically during training, and the mixture is adapted based on signals from training (e.g. training loss, gradient norm).
4+
5+
## Plugins
6+
7+
Plugin | Description | Depends | Loading | Augmentation | Callbacks
8+
--|--|--|--|--|--
9+
[odm](./src/fms_acceleration_odm/framework_plugin_odm.py) | OnlineMixingDataset PyTorch IterableDataset and custom rewards | | ✅ | ✅ | ✅
10+
11+
## Design
12+
![](./artifacts/Design.png)
13+
14+
## Usage in Custom Training Loop
15+
16+
17+
### Planned TODOs
18+
Please see issue [#153](https://github.com/foundation-model-stack/fms-acceleration/issues/153).
19+
20+
21+
521 KB
Loading
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from datasets import load_dataset, concatenate_datasets
2+
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling
3+
from torch.utils.data import DataLoader
4+
from accelerate import Accelerator, DataLoaderConfiguration
5+
import torch
6+
from tqdm import tqdm
7+
from fms_acceleration_odm import OnlineMixingDataset
8+
9+
# run configuration
max_steps = 50
output_dir = "./odm_custom_use"
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# tokenizer: the EOS token doubles as the padding token
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# model
model = AutoModelForCausalLM.from_pretrained(model_name)
19+
20+
# dataset related
21+
def tokenize_fn(examples, max_length=128):
    """Tokenize the ``"text"`` entry of a batch of examples.

    Args:
        examples: Mapping with a ``"text"`` entry, as handed over by
            ``Dataset.map(..., batched=True)``.
        max_length: Length every sequence is truncated and padded to.
            Defaults to 128, the value that used to be hard-coded.

    Returns:
        Whatever the module-level ``tokenizer`` returns for the batch.
    """
    return tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=max_length,
    )
23+
24+
# raw text datasets, keyed by the name used for mixing
raw_datasets = {
    "bookcorpus": load_dataset("rojagtap/bookcorpus", split="train[:1%]"),
    "wikitext": load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]"),
}

# tokenization: every original column is dropped, only tokenizer outputs remain
dataset_dict = {
    name: raw.map(tokenize_fn, batched=True, remove_columns=raw.column_names)
    for name, raw in raw_datasets.items()
}

# one causal-LM (mlm=False) collator instance per dataset
collator_dict = {
    name: DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    for name in dataset_dict
}
37+
38+
# odm related
39+
update_interval = 1  # every step
# NOTE(review): update_interval is not referenced again in the visible
# script — confirm whether it is still needed.

# No evaluation datasets/collators are supplied; the reward is the train loss.
odm_kwargs = {
    "dataset_dict": dataset_dict,
    "collators_dict": collator_dict,
    "eval_dataset_dict": None,
    "eval_collators_dict": None,
    "output_dir": output_dir,
    "reward_type": "train_loss",
    "sampling_interval": 1,
}
dataset = OnlineMixingDataset(**odm_kwargs)

dataloader = DataLoader(
    dataset=dataset,
    batch_size=2,
    shuffle=False,
    collate_fn=None,
)
48+
49+
# distributed setup
50+
# Batch splitting/dispatching is configured only through
# DataLoaderConfiguration. The former `split_batches=True` keyword on
# Accelerator itself duplicated this setting and is a deprecated argument
# that newer Accelerate releases no longer accept.
dataloader_config = DataLoaderConfiguration(split_batches=True, dispatch_batches=True)
accelerator = Accelerator(dataloader_config=dataloader_config)

# Let Accelerate wrap the model and dataloader for the current setup.
model, dataloader = accelerator.prepare(model, dataloader)
53+
54+
# training setup
55+
# The optimizer is built from the already-prepared model's parameters and then
# prepared itself, so that `optimizer.step()` cooperates with
# `accelerator.backward()` instead of bypassing Accelerate.
optimizer = accelerator.prepare(torch.optim.AdamW(model.parameters(), lr=5e-5))

log_interval = 1  # print the loss every `log_interval` steps

model.train()

# custom training loop
progress_bar = tqdm(dataloader, disable=not accelerator.is_local_main_process)
for step, batch in enumerate(progress_bar):
    outputs = model(**batch)
    loss = outputs.loss
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
    if step % log_interval == 0:
        # accelerator.print keeps the output to one line per step when
        # several processes are running.
        accelerator.print(f"Step {step} | Loss: {loss.item():.4f}")
    # `step` is zero-based; stop once `max_steps` batches were processed,
    # without mutating the `max_steps` setting itself.
    if step + 1 >= max_steps:
        break

accelerator.print("training completed!")

0 commit comments

Comments
 (0)