1- from datasets import load_dataset , concatenate_datasets
2- from transformers import AutoTokenizer , AutoModelForCausalLM , DataCollatorForLanguageModeling
3- from torch .utils .data import DataLoader
1+ # Third Party
42from accelerate import Accelerator , DataLoaderConfiguration
5- import torch
3+ from datasets import concatenate_datasets , load_dataset
4+ from torch .utils .data import DataLoader
65from tqdm import tqdm
6+ from transformers import (
7+ AutoModelForCausalLM ,
8+ AutoTokenizer ,
9+ DataCollatorForLanguageModeling ,
10+ )
11+ import torch
12+
13+ # First Party
714from fms_acceleration_odm import OnlineMixingDataset
815
# Hugging Face checkpoint id; the tokenizer below is loaded from it.
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the end-of-sequence token as the padding token; tokenize_fn pads
# every sequence to a fixed length, so a pad token has to be set.
tokenizer.pad_token = tokenizer.eos_token
1926
27+
2028# dataset related
def tokenize_fn(examples):
    """Tokenize the ``"text"`` column of a batch of examples.

    Intended for ``Dataset.map(..., batched=True)``: ``examples["text"]`` is
    handed to the module-level ``tokenizer`` with truncation enabled and
    padding set to ``"max_length"`` with ``max_length=128``.

    Args:
        examples: Batch mapping that contains a ``"text"`` entry.

    Returns:
        Whatever the tokenizer call returns for that batch.
    """
    encode_kwargs = {
        "truncation": True,
        "padding": "max_length",
        "max_length": 128,
    }
    return tokenizer(examples["text"], **encode_kwargs)
33+
2334
# Raw corpora: the first 1% of each train split, keyed by dataset name.
dataset_dict = {
    "bookcorpus": load_dataset("rojagtap/bookcorpus", split="train[:1%]"),
    "wikitext": load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]"),
}

# tokenization: swap each raw split for its tokenized counterpart and drop
# every original column so only the tokenizer outputs remain.
dataset_dict = {
    name: split.map(tokenize_fn, batched=True, remove_columns=split.column_names)
    for name, split in dataset_dict.items()
}
3247
# One collator instance per dataset, keyed exactly like dataset_dict.
# mlm=False selects the non-masked (causal) language-modeling mode.
collator_dict = {
    name: DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    for name in dataset_dict
}
3752
# odm related
update_interval = 1  # every step

# Online-mixing dataset over the tokenized corpora. No evaluation datasets or
# collators are supplied. NOTE(review): output_dir is not defined in the
# visible part of the script - confirm it is set earlier in the file.
dataset = OnlineMixingDataset(
    dataset_dict=dataset_dict,
    eval_dataset_dict=None,
    collators_dict=collator_dict,
    eval_collators_dict=None,
    reward_type="train_loss",
    sampling_interval=1,
    output_dir=output_dir,
)

# The collators are handed to the dataset above; the loader itself is given
# collate_fn=None and does no shuffling.
dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,
    collate_fn=None,
)
4865
4966# distributed setup
@@ -57,7 +74,9 @@ def tokenize_fn(examples):
5774model .train ()
5875
5976# custom training loop
60- for step , batch in enumerate (tqdm (dataloader , disable = not accelerator .is_local_main_process )):
77+ for step , batch in enumerate (
78+ tqdm (dataloader , disable = not accelerator .is_local_main_process )
79+ ):
6180 outputs = model (** batch )
6281 loss = outputs .loss
6382 accelerator .backward (loss )
0 commit comments