@@ -1,4 +1,5 @@
 import os
+import time
 from typing import Optional, Union

 import datasets
@@ -10,11 +11,12 @@
 from transformers.utils import logging
 from transformers import TrainerCallback
 from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
-from tuning.aim_loader import get_aimstack_callback
-from tuning.config import configs, peft_config
+from tuning.config import configs, peft_config, tracker_configs
 from tuning.data import tokenizer_data_utils
 from tuning.utils.config_utils import get_hf_peft_config
 from tuning.utils.data_type_utils import get_torch_dtype
+from tuning.tracker.tracker import Tracker
+from tuning.tracker.aimstack_tracker import AimStackTracker

 class PeftSavingCallback(TrainerCallback):
     def on_save(self, args, state, control, **kwargs):
@@ -24,13 +26,13 @@ def on_save(self, args, state, control, **kwargs):
         if "pytorch_model.bin" in os.listdir(checkpoint_path):
             os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))

-
-
 def train(
     model_args: configs.ModelArguments,
     data_args: configs.DataArguments,
     train_args: configs.TrainingArguments,
     peft_config: Optional[Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]] = None,
+    tracker_name: Optional[str] = None,
+    tracker_config: Optional[Union[tracker_configs.AimConfig]] = None
 ):
     """Call the SFTTrainer

@@ -44,7 +46,6 @@ def train(
         The peft configuration to pass to trainer
     """
     run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1
-
     logger = logging.get_logger("sft_trainer")

     # Validate parameters
@@ -58,14 +59,29 @@ def train(
         train_args.fsdp = ""
         train_args.fsdp_config = {'xla': False}

+    # Initialize the tracker early so we can calculate custom metrics like model_load_time.
+    if tracker_name == 'aim':
+        if tracker_config is not None:
+            tracker = AimStackTracker(tracker_config)
+        else:
+            logger.error("Tracker name is set to " + tracker_name + " but config is None; falling back to a no-op tracker.")
+            tracker = Tracker()
+    else:
+        logger.info("No tracker requested; using a no-op Tracker which does nothing.")
+        tracker = Tracker()
+
     task_type = "CAUSAL_LM"
+
+    model_load_time = time.time()
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=train_args.cache_dir,
         torch_dtype=get_torch_dtype(model_args.torch_dtype),
         use_flash_attention_2=model_args.use_flash_attn,
     )
-
+    model_load_time = time.time() - model_load_time
+    tracker.track(metric=model_load_time, name='model_load_time')
+
     peft_config = get_hf_peft_config(task_type, peft_config)

     model.gradient_checkpointing_enable()
@@ -130,8 +146,12 @@ def train(
         formatted_dataset = json_dataset['train'].map(lambda example: {f"{data_args.dataset_text_field}": example[f"{data_args.dataset_text_field}"] + tokenizer.eos_token})
         logger.info(f"Dataset length is {len(formatted_dataset)}")

-    aim_callback = get_aimstack_callback()
-    callbacks = [aim_callback, PeftSavingCallback()]
+    # combine and register callbacks
+    callbacks = [PeftSavingCallback()]
+
+    tracker_callback = tracker.get_hf_callback()
+    if tracker_callback is not None:
+        callbacks.append(tracker_callback)

     if train_args.packing:
         logger.info("Packing is set to True")
@@ -173,16 +193,30 @@ def main(**kwargs):
         configs.DataArguments,
         configs.TrainingArguments,
         peft_config.LoraConfig,
-        peft_config.PromptTuningConfig))
+        peft_config.PromptTuningConfig,
+        tracker_configs.AimConfig))
     parser.add_argument('--peft_method', type=str.lower, choices=['pt', 'lora', None, 'none'], default="pt")
-    model_args, data_args, training_args, lora_config, prompt_tuning_config, peft_method, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
-    if peft_method.peft_method == "lora":
+    parser.add_argument('--tracker', type=str.lower, choices=['aim', None, 'none'], default="aim")
+    (model_args, data_args, training_args,
+     lora_config, prompt_tuning_config, aim_config,
+     additional, _) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
+
+    peft_method = additional.peft_method
+    tracker_name = additional.tracker
+
+    if peft_method == "lora":
         tune_config = lora_config
-    elif peft_method.peft_method == "pt":
+    elif peft_method == "pt":
         tune_config = prompt_tuning_config
     else:
         tune_config = None
-    train(model_args, data_args, training_args, tune_config)
+
+    if tracker_name == "aim":
+        tracker_config = aim_config
+    else:
+        tracker_config = None
+
+    train(model_args, data_args, training_args, tune_config, tracker_name, tracker_config)

 if __name__ == "__main__":
     fire.Fire(main)
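
For context, a minimal sketch of how the new tracker arguments might be exercised once this change is in, assuming the tuning package is importable and that AimConfig and the other config dataclasses can be constructed as shown (their exact fields are not part of this diff):

# Hypothetical usage sketch; field names not visible in the diff are assumptions.
from tuning import sft_trainer
from tuning.config import configs, tracker_configs

model_args = configs.ModelArguments(model_name_or_path="facebook/opt-125m")  # any causal LM
data_args = configs.DataArguments(dataset_text_field="output")               # field seen in the diff
train_args = configs.TrainingArguments(output_dir="./out")

sft_trainer.train(
    model_args,
    data_args,
    train_args,
    peft_config=None,                            # full fine-tuning, no PEFT
    tracker_name="aim",                          # anything else falls back to the no-op Tracker
    tracker_config=tracker_configs.AimConfig(),  # assumed constructible with defaults
)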