11# Standard
22from datetime import datetime
3- from typing import Optional , Union , List , Dict
3+ from typing import Dict , List , Optional , Union
44import json
5- import os , time
5+ import os
6+ import time
67
78# Third Party
8- import transformers
9+ from peft . utils . other import fsdp_auto_wrap_policy
910from transformers import (
1011 AutoModelForCausalLM ,
1112 AutoTokenizer ,
1617 TrainerCallback ,
1718)
1819from transformers .utils import logging
19- from peft .utils .other import fsdp_auto_wrap_policy
2020from trl import DataCollatorForCompletionOnlyLM , SFTTrainer
2121import datasets
2222import fire
23+ import transformers
2324
2425# Local
2526from tuning .config import configs , peft_config , tracker_configs
2627from tuning .data import tokenizer_data_utils
27- from tuning .utils .config_utils import get_hf_peft_config
28- from tuning .utils .data_type_utils import get_torch_dtype
2928from tuning .trackers .tracker import Tracker
3029from tuning .trackers .tracker_factory import get_tracker
30+ from tuning .utils .config_utils import get_hf_peft_config
31+ from tuning .utils .data_type_utils import get_torch_dtype
3132
3233logger = logging .get_logger ("sft_trainer" )
3334
35+
3436class PeftSavingCallback (TrainerCallback ):
3537 def on_save (self , args , state , control , ** kwargs ):
3638 checkpoint_path = os .path .join (
@@ -41,6 +43,7 @@ def on_save(self, args, state, control, **kwargs):
4143 if "pytorch_model.bin" in os .listdir (checkpoint_path ):
4244 os .remove (os .path .join (checkpoint_path , "pytorch_model.bin" ))
4345
46+
4447class FileLoggingCallback (TrainerCallback ):
4548 """Exports metrics, e.g., training loss to a file in the checkpoint directory."""
4649
@@ -84,6 +87,7 @@ def _track_loss(self, loss_key, log_file, logs, state):
8487 with open (log_file , "a" ) as f :
8588 f .write (f"{ json .dumps (log_obj , sort_keys = True )} \n " )
8689
90+
8791def train (
8892 model_args : configs .ModelArguments ,
8993 data_args : configs .DataArguments ,
@@ -93,7 +97,7 @@ def train(
9397 ] = None ,
9498 callbacks : Optional [List [TrainerCallback ]] = None ,
9599 tracker : Optional [Tracker ] = None ,
96- exp_metadata : Optional [Dict ] = None
100+ exp_metadata : Optional [Dict ] = None ,
97101):
98102 """Call the SFTTrainer
99103
@@ -105,6 +109,11 @@ def train(
105109 peft_config.PromptTuningConfig for prompt tuning | \
106110 None for fine tuning
107111 The peft configuration to pass to trainer
112+ callbacks: List of callbacks to attach with SFTtrainer.
113+ tracker: One of the available trackers in tuning.trackers.tracker_factory.REGISTERED_TRACKERS
114+ Initialized using tuning.trackers.tracker_factory.get_tracker
115+ Using configs in tuning.config.tracker_configs
116+ exp_metadata: Dict of key value pairs passed to train to be recoreded by the tracker.
108117 """
109118 run_distributed = int (os .environ .get ("WORLD_SIZE" , "1" )) > 1
110119
@@ -133,7 +142,7 @@ def train(
133142 torch_dtype = get_torch_dtype (model_args .torch_dtype ),
134143 use_flash_attention_2 = model_args .use_flash_attn ,
135144 )
136- additional_metrics [' model_load_time' ] = time .time () - model_load_time
145+ additional_metrics [" model_load_time" ] = time .time () - model_load_time
137146
138147 peft_config = get_hf_peft_config (task_type , peft_config )
139148
@@ -269,16 +278,17 @@ def train(
269278 if tracker is not None :
270279 # Currently tracked only on process zero.
271280 if trainer .is_world_process_zero ():
272- for k ,v in additional_metrics .items ():
273- tracker .track (metric = v , name = k , stage = ' additional_metrics' )
274- tracker .set_params (params = exp_metadata , name = ' experiment_metadata' )
281+ for k , v in additional_metrics .items ():
282+ tracker .track (metric = v , name = k , stage = " additional_metrics" )
283+ tracker .set_params (params = exp_metadata , name = " experiment_metadata" )
275284
276285 if run_distributed and peft_config is not None :
277286 trainer .accelerator .state .fsdp_plugin .auto_wrap_policy = fsdp_auto_wrap_policy (
278287 model
279288 )
280289 trainer .train ()
281290
291+
282292def main (** kwargs ):
283293 parser = transformers .HfArgumentParser (
284294 dataclass_types = (
@@ -300,6 +310,7 @@ def main(**kwargs):
300310 "--exp_metadata" ,
301311 type = str ,
302312 default = None ,
313+ help = 'Pass a json string representing K:V pairs to be associated to the tuning run in the tracker. e.g. \' {"gpu":"A100-80G"}\' ' ,
303314 )
304315 (
305316 model_args ,
@@ -313,18 +324,18 @@ def main(**kwargs):
313324 ) = parser .parse_args_into_dataclasses (return_remaining_strings = True )
314325
315326 peft_method = additional .peft_method
316- if peft_method == "lora" :
317- tune_config = lora_config
318- elif peft_method == "pt" :
319- tune_config = prompt_tuning_config
327+ if peft_method == "lora" :
328+ tune_config = lora_config
329+ elif peft_method == "pt" :
330+ tune_config = prompt_tuning_config
320331 else :
321- tune_config = None
332+ tune_config = None
322333
323334 tracker_name = training_args .tracker
324335 if tracker_name == "aim" :
325- tracker_config = aim_config
336+ tracker_config = aim_config
326337 else :
327- tracker_config = None
338+ tracker_config = None
328339
329340 # Initialize callbacks
330341 file_logger_callback = FileLoggingCallback (logger )
@@ -343,7 +354,9 @@ def main(**kwargs):
343354 try :
344355 metadata = json .loads (additional .exp_metadata )
345356 if metadata is None or not isinstance (metadata , Dict ):
346- logger .warning ('metadata cannot be converted to simple k:v dict ignoring' )
357+ logger .warning (
358+ "metadata cannot be converted to simple k:v dict ignoring"
359+ )
347360 metadata = None
348361 except :
349362 logger .error ("failed while parsing extra metadata. pass a valid json" )
@@ -355,8 +368,9 @@ def main(**kwargs):
355368 peft_config = tune_config ,
356369 callbacks = callbacks ,
357370 tracker = tracker ,
358- exp_metadata = metadata
371+ exp_metadata = metadata ,
359372 )
360373
374+
361375if __name__ == "__main__" :
362376 fire .Fire (main )
0 commit comments