|
| 1 | +from datasets import load_dataset, concatenate_datasets |
| 2 | +from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling |
| 3 | +from torch.utils.data import DataLoader |
| 4 | +from accelerate import Accelerator, DataLoaderConfiguration |
| 5 | +import torch |
| 6 | +from tqdm import tqdm |
| 7 | +from fms_acceleration_odm import OnlineMixingDataset |
| 8 | + |
| 9 | +model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" |
| 10 | +output_dir = "./odm_custom_use" |
| 11 | +max_steps = 50 |
| 12 | + |
| 13 | +# model |
| 14 | +model = AutoModelForCausalLM.from_pretrained(model_name) |
| 15 | + |
| 16 | +# tokenizer |
| 17 | +tokenizer = AutoTokenizer.from_pretrained(model_name) |
| 18 | +tokenizer.pad_token = tokenizer.eos_token |
| 19 | + |
| 20 | +# dataset related |
| 21 | +def tokenize_fn(examples): |
| 22 | + return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128) |
| 23 | + |
| 24 | +dataset_dict = { |
| 25 | + "bookcorpus": load_dataset("rojagtap/bookcorpus", split="train[:1%]"), |
| 26 | + "wikitext": load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]") |
| 27 | +} |
| 28 | + |
| 29 | +# tokenization |
| 30 | +dataset_dict["bookcorpus"] = dataset_dict["bookcorpus"].map(tokenize_fn, batched=True, remove_columns=dataset_dict["bookcorpus"].column_names) |
| 31 | +dataset_dict["wikitext"] = dataset_dict["wikitext"].map(tokenize_fn, batched=True, remove_columns=dataset_dict["wikitext"].column_names) |
| 32 | + |
| 33 | +collator_dict = { |
| 34 | + "bookcorpus": DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False), |
| 35 | + "wikitext": DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False), |
| 36 | +} |
| 37 | + |
| 38 | +# odm related |
| 39 | +update_interval=1 # every step |
| 40 | +dataset = OnlineMixingDataset(dataset_dict=dataset_dict, |
| 41 | + collators_dict=collator_dict, |
| 42 | + eval_dataset_dict=None, |
| 43 | + eval_collators_dict=None, |
| 44 | + output_dir=output_dir, |
| 45 | + reward_type="train_loss", |
| 46 | + sampling_interval=1) |
| 47 | +dataloader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=None) |
| 48 | + |
| 49 | +# distributed setup |
| 50 | +dataloader_config = DataLoaderConfiguration(split_batches=True, dispatch_batches=True) |
| 51 | +accelerator = Accelerator(split_batches=True, dataloader_config=dataloader_config) |
| 52 | +model, dataloader = accelerator.prepare(model, dataloader) |
| 53 | + |
| 54 | +# training setup |
| 55 | +optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5) |
| 56 | + |
| 57 | +model.train() |
| 58 | + |
| 59 | +# custom training loop |
| 60 | +for step, batch in enumerate(tqdm(dataloader, disable=not accelerator.is_local_main_process)): |
| 61 | + outputs = model(**batch) |
| 62 | + loss = outputs.loss |
| 63 | + accelerator.backward(loss) |
| 64 | + optimizer.step() |
| 65 | + optimizer.zero_grad() |
| 66 | + if step % 1 == 0: |
| 67 | + print(f"Step {step} | Loss: {loss.item():.4f}") |
| 68 | + max_steps -= 1 |
| 69 | + if max_steps == 0: |
| 70 | + break |
| 71 | + |
| 72 | +print("training completed!") |
0 commit comments