Skip to content

Commit c0106ff

Browse files
committed
fix: add custom loop and docs
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
1 parent dc3ef05 commit c0106ff

1 file changed

Lines changed: 36 additions & 17 deletions

File tree

plugins/online-data-mixing/artifacts/custom_loop_usage.py

Lines changed: 36 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,16 @@
1-
from datasets import load_dataset, concatenate_datasets
2-
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling
3-
from torch.utils.data import DataLoader
1+
# Third Party
4 2
from accelerate import Accelerator, DataLoaderConfiguration
5-
import torch
3+
from datasets import concatenate_datasets, load_dataset
4+
from torch.utils.data import DataLoader
6 5
from tqdm import tqdm
6+
from transformers import (
7+
AutoModelForCausalLM,
8+
AutoTokenizer,
9+
DataCollatorForLanguageModeling,
10+
)
11+
import torch
12+
13+
# First Party
7 14
from fms_acceleration_odm import OnlineMixingDataset
8 15

9 16
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
@@ -17,33 +24,43 @@
17 24
tokenizer = AutoTokenizer.from_pretrained(model_name)
18 25
tokenizer.pad_token = tokenizer.eos_token
19 26

27+
20 28
# dataset related
21 29
def tokenize_fn(examples):
22-
return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)
30+
return tokenizer(
31+
examples["text"], truncation=True, padding="max_length", max_length=128
32+
)
33+
23 34

24 35
dataset_dict = {
25 36
"bookcorpus": load_dataset("rojagtap/bookcorpus", split="train[:1%]"),
26-
"wikitext": load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")
37+
"wikitext": load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]"),
27 38
}
28 39

29 40
# tokenization
30-
dataset_dict["bookcorpus"] = dataset_dict["bookcorpus"].map(tokenize_fn, batched=True, remove_columns=dataset_dict["bookcorpus"].column_names)
31-
dataset_dict["wikitext"] = dataset_dict["wikitext"].map(tokenize_fn, batched=True, remove_columns=dataset_dict["wikitext"].column_names)
41+
dataset_dict["bookcorpus"] = dataset_dict["bookcorpus"].map(
42+
tokenize_fn, batched=True, remove_columns=dataset_dict["bookcorpus"].column_names
43+
)
44+
dataset_dict["wikitext"] = dataset_dict["wikitext"].map(
45+
tokenize_fn, batched=True, remove_columns=dataset_dict["wikitext"].column_names
46+
)
32 47

33 48
collator_dict = {
34 49
"bookcorpus": DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
35 50
"wikitext": DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
36 51
}
37 52

38 53
# odm related
39-
update_interval=1 # every step
40-
dataset = OnlineMixingDataset(dataset_dict=dataset_dict,
41-
collators_dict=collator_dict,
42-
eval_dataset_dict=None,
43-
eval_collators_dict=None,
44-
output_dir=output_dir,
45-
reward_type="train_loss",
46-
sampling_interval=1)
54+
update_interval = 1 # every step
55+
dataset = OnlineMixingDataset(
56+
dataset_dict=dataset_dict,
57+
collators_dict=collator_dict,
58+
eval_dataset_dict=None,
59+
eval_collators_dict=None,
60+
output_dir=output_dir,
61+
reward_type="train_loss",
62+
sampling_interval=1,
63+
)
47 64
dataloader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=None)
48 65

49 66
# distributed setup
@@ -57,7 +74,9 @@ def tokenize_fn(examples):
57 74
model.train()
58 75

59 76
# custom training loop
60-
for step, batch in enumerate(tqdm(dataloader, disable=not accelerator.is_local_main_process)):
77+
for step, batch in enumerate(
78+
tqdm(dataloader, disable=not accelerator.is_local_main_process)
79+
):
61 80
outputs = model(**batch)
62 81
loss = outputs.loss
63 82
accelerator.backward(loss)

0 commit comments

Comments (0)