11# Standard
22from datetime import datetime
3- from typing import Optional , Union
3+ from typing import Optional , Union , List
44import json
55import os , time
66
77# Third Party
8- from peft . utils . other import fsdp_auto_wrap_policy
8+ import transformers
99from transformers import (
1010 AutoModelForCausalLM ,
1111 AutoTokenizer ,
1616 TrainerCallback ,
1717)
1818from transformers .utils import logging
19+ from peft .utils .other import fsdp_auto_wrap_policy
1920from trl import DataCollatorForCompletionOnlyLM , SFTTrainer
2021import datasets
2122import fire
22- import transformers
2323
2424# Local
2525from tuning .config import configs , peft_config , tracker_configs
2929from tuning .tracker .tracker import Tracker
3030from tuning .tracker .aimstack_tracker import AimStackTracker
3131
32+ logger = logging .get_logger ("sft_trainer" )
3233
3334class PeftSavingCallback (TrainerCallback ):
3435 def on_save (self , args , state , control , ** kwargs ):
@@ -83,15 +84,15 @@ def _track_loss(self, loss_key, log_file, logs, state):
8384 with open (log_file , "a" ) as f :
8485 f .write (f"{ json .dumps (log_obj , sort_keys = True )} \n " )
8586
86-
8787def train (
8888 model_args : configs .ModelArguments ,
8989 data_args : configs .DataArguments ,
9090 train_args : configs .TrainingArguments ,
9191 peft_config : Optional [
9292 Union [peft_config .LoraConfig , peft_config .PromptTuningConfig ]
9393 ] = None ,
94- tracker_config : Optional [Union [tracker_configs .AimConfig ]] = None
94+ callbacks : Optional [List [TrainerCallback ]] = None ,
95+ tracker : Optional [Tracker ] = None ,
9596):
9697 """Call the SFTTrainer
9798
@@ -105,7 +106,6 @@ def train(
105106 The peft configuration to pass to trainer
106107 """
107108 run_distributed = int (os .environ .get ("WORLD_SIZE" , "1" )) > 1
108- logger = logging .get_logger ("sft_trainer" )
109109
110110 # Validate parameters
111111 if (not isinstance (train_args .num_train_epochs , float )) or (
@@ -122,17 +122,6 @@ def train(
122122 train_args .fsdp = ""
123123 train_args .fsdp_config = {"xla" : False }
124124
125- # Initialize the tracker early so we can calculate custom metrics like model_load_time.
126- tracker_name = train_args .tracker
127- if tracker_name == 'aim' :
128- if tracker_config is not None :
129- tracker = AimStackTracker (tracker_config )
130- else :
131- logger .error ("Tracker name is set to " + tracker_name + " but config is None." )
132- else :
133- logger .info ('No tracker set so just set a dummy API which does nothing' )
134- tracker = Tracker ()
135-
136125 task_type = "CAUSAL_LM"
137126
138127 model_load_time = time .time ()
@@ -259,15 +248,6 @@ def train(
259248 )
260249 packing = False
261250
262- # club and register callbacks
263- file_logger_callback = FileLoggingCallback (logger )
264- peft_saving_callback = PeftSavingCallback ()
265- callbacks = [peft_saving_callback , file_logger_callback ]
266-
267- tracker_callback = tracker .get_hf_callback ()
268- if tracker_callback is not None :
269- callbacks .append (tracker_callback )
270-
271251 trainer = SFTTrainer (
272252 model = model ,
273253 tokenizer = tokenizer ,
@@ -288,7 +268,6 @@ def train(
288268 )
289269 trainer .train ()
290270
291-
292271def main (** kwargs ):
293272 parser = transformers .HfArgumentParser (
294273 dataclass_types = (
@@ -331,8 +310,26 @@ def main(**kwargs):
331310 else :
332311 tracker_config = None
333312
334- train (model_args , data_args , training_args , tune_config , tracker_config )
313+ # Initialize the tracker early so we can calculate custom metrics like model_load_time.
314+ tracker_name = training_args .tracker
315+ if tracker_name == 'aim' :
316+ if tracker_config is not None :
317+ tracker = AimStackTracker (tracker_config )
318+ else :
319+ logger .error ("Tracker name is set to " + tracker_name + " but config is None." )
320+ else :
321+ tracker = Tracker ()
322+
323+ # Initialize callbacks
324+ file_logger_callback = FileLoggingCallback (logger )
325+ peft_saving_callback = PeftSavingCallback ()
326+ callbacks = [peft_saving_callback , file_logger_callback ]
327+
328+ tracker_callback = tracker .get_hf_callback ()
329+ if tracker_callback is not None :
330+ callbacks .append (tracker_callback )
335331
332+ train (model_args , data_args , training_args , tune_config , callbacks , tracker )
336333
337334if __name__ == "__main__" :
338335 fire .Fire (main )
0 commit comments