 from typing import Optional, Union
 import json
 import os
+import sys

 # Third Party
 from peft.utils.other import fsdp_auto_wrap_policy
@@ -80,15 +81,15 @@ def _track_loss(self, loss_key, log_file, logs, state):
             return

         # append the current log to the jsonl file
-        with open(log_file, "a") as f:
+        with open(log_file, "a", encoding="utf-8") as f:
             f.write(f"{json.dumps(log_obj, sort_keys=True)}\n")


 def train(
     model_args: configs.ModelArguments,
     data_args: configs.DataArguments,
     train_args: configs.TrainingArguments,
-    peft_config: Optional[
+    peft_configs: Optional[
         Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]
     ] = None,
 ):
@@ -98,7 +99,7 @@ def train(
         model_args: tuning.config.configs.ModelArguments
         data_args: tuning.config.configs.DataArguments
         train_args: tuning.config.configs.TrainingArguments
-        peft_config: peft_config.LoraConfig for Lora tuning | \
+        peft_configs: peft_config.LoraConfig for Lora tuning | \
             peft_config.PromptTuningConfig for prompt tuning | \
             None for fine tuning
             The peft configuration to pass to trainer
@@ -130,7 +131,7 @@ def train(
         use_flash_attention_2=model_args.use_flash_attn,
     )

-    peft_config = get_hf_peft_config(task_type, peft_config)
+    peft_configs = get_hf_peft_config(task_type, peft_configs)

     model.gradient_checkpointing_enable()

@@ -140,9 +141,7 @@ def train(
     )

     # TODO: understand if we need to hardcode these here or just use defaults in model
-    if isinstance(tokenizer, LlamaTokenizer) or isinstance(
-        tokenizer, LlamaTokenizerFast
-    ):
+    if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
         tokenizer.add_special_tokens(
             {
                 "bos_token": "<s>",
@@ -151,33 +150,36 @@ def train(
                 "pad_token": "<pad>",
             }
         )
-    elif isinstance(tokenizer, GPTNeoXTokenizerFast) or isinstance(
-        tokenizer, GPT2Tokenizer
-    ):
+    elif isinstance(tokenizer, (GPT2Tokenizer, GPTNeoXTokenizerFast)):
         tokenizer.add_special_tokens(
             {
                 "pad_token": "<pad>",
             }
         )

-    """ TODO: near term - how response template ids are parsed out needs to be cleaned.
-    The [2:] here applies if response template has \n prefix, it is needed to strip \n, otherwise template is not found.
-    We will create issue to clean this out after we discuss data formats and collators we will support
-    """
+    # TODO: near term - how response template ids are parsed out needs to be cleaned.
+    # The [2:] here applies if response template has \n prefix, it is needed to strip \n,
+    # otherwise template is not found. We will create issue to clean this out after we discuss
+    # data formats and collators we will support.
     response_template_ids = tokenizer.encode(
         data_args.response_template, add_special_tokens=False
     )[2:]
-    # TODO: This is actually max_seq_length and not model_max_length. we should not override model_max_length
-    # as in current main. We need to change name of this parameter we expose to users.
+    # TODO: This is actually max_seq_length and not model_max_length. we should not override
+    # model_max_length as in current main. We need to change name of this parameter we expose
+    # to users.
     model_max_length = min(train_args.model_max_length, tokenizer.model_max_length)
-    logger.info(f"Model max length {model_max_length}")
+    logger.info("Model max length %s", model_max_length)
     if train_args.model_max_length > tokenizer.model_max_length:
         logger.warning(
-            f"model_max_length {train_args.model_max_length} exceeds tokenizer.model_max_length {tokenizer.model_max_length}, using tokenizer.model_max_length {tokenizer.model_max_length}"
+            "model_max_length %s exceeds tokenizer.model_max_length %s, "
+            "using tokenizer.model_max_length %s",
+            train_args.model_max_length,
+            tokenizer.model_max_length,
+            tokenizer.model_max_length,
         )

     # TODO: we need to change this, perhaps follow what open instruct does?
-    special_tokens_dict = dict()
+    special_tokens_dict = {}
     if tokenizer.pad_token is None:
         logger.warning("PAD token set to default, missing in tokenizer")
         special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN
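As the TODO above notes, the `[2:]` slice strips the ids produced for the response template's leading `\n` so the remaining template ids match how the template appears inside a full tokenized example. A rough sketch of the kind of check this enables, assuming a Llama-style tokenizer and an illustrative `"\n### Response:"` template (neither is taken from this patch):

```python
# Illustrative sketch only -- tokenizer name and template are assumed, not from this patch.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

response_template = "\n### Response:"
# Encoded standalone, the first two ids belong to the leading "\n" context;
# dropping them leaves the ids as they appear mid-sequence.
template_ids = tokenizer.encode(response_template, add_special_tokens=False)[2:]

example_ids = tokenizer.encode(
    "### Input:\nhello\n### Response:\nhi there", add_special_tokens=False
)

def contains(haystack, needle):
    """True if `needle` occurs as a contiguous run inside `haystack`."""
    return any(
        haystack[i : i + len(needle)] == needle
        for i in range(len(haystack) - len(needle) + 1)
    )

print(contains(example_ids, template_ids))  # expected True once the prefix is stripped
```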
@@ -205,19 +207,21 @@ def train(
     if data_args.validation_data_path:
         data_files["validation"] = data_args.validation_data_path

-    format_dataset = lambda example: {
+    format_dataset = lambda example: {  # pylint: disable=unnecessary-lambda-assignment
         f"{data_args.dataset_text_field}": example[f"{data_args.dataset_text_field}"]
         + tokenizer.eos_token
     }

     json_dataset = datasets.load_dataset("json", data_files=data_files)
     formatted_train_dataset = json_dataset["train"].map(format_dataset)
-    logger.info(f"Training dataset length is {len(formatted_train_dataset)}")
+    logger.info("Training dataset length is %s", len(formatted_train_dataset))

     formatted_validation_dataset = None
     if data_args.validation_data_path:
         formatted_validation_dataset = json_dataset["validation"].map(format_dataset)
-        logger.info(f"Validation dataset length is {len(formatted_validation_dataset)}")
+        logger.info(
+            "Validation dataset length is %s", len(formatted_validation_dataset)
+        )

     aim_callback = get_aimstack_callback()
     file_logger_callback = FileLoggingCallback(logger)
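The `format_dataset` mapping above simply appends the tokenizer's EOS token to the configured text field before training. A small sketch of that behavior on an in-memory dataset, using an assumed field name and EOS token rather than the real config values:

```python
# Illustrative sketch -- field name "output" and eos token "</s>" are assumptions.
from datasets import Dataset

eos_token = "</s>"
ds = Dataset.from_dict({"output": ["first example", "second example"]})

# Mirror the lambda used in the patch: append EOS to the text column.
format_dataset = lambda example: {"output": example["output"] + eos_token}

formatted = ds.map(format_dataset)
print(formatted["output"])  # ['first example</s>', 'second example</s>']
```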
@@ -234,13 +238,13 @@ def train(
         logger.error(
             "Error, response template is None, needs to be set for training"
         )
-        exit(-1)
+        sys.exit(-1)

     if data_args.dataset_text_field is None:
         logger.error(
             "Error, dataset_text_field is None, needs to be set for training"
         )
-        exit(-1)
+        sys.exit(-1)

     data_collator = DataCollatorForCompletionOnlyLM(
         response_template_ids,
@@ -260,17 +264,17 @@ def train(
         args=train_args,
         max_seq_length=model_max_length,
         callbacks=callbacks,
-        peft_config=peft_config,
+        peft_config=peft_configs,
     )

-    if run_distributed and peft_config is not None:
+    if run_distributed and peft_configs is not None:
         trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(
             model
         )
     trainer.train()


-def main(**kwargs):
+def main(**kwargs):  # pylint: disable=unused-argument
     parser = transformers.HfArgumentParser(
         dataclass_types=(
             configs.ModelArguments,
@@ -286,7 +290,7 @@ def main(**kwargs):
         choices=["pt", "lora", None, "none"],
         default="pt",
     )
-    (
+    (  # pylint: disable=unbalanced-tuple-unpacking
         model_args,
         data_args,
         training_args,
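The logging changes throughout this diff follow one pattern: f-string interpolation is replaced with lazy `%`-style arguments, so the message is only formatted when the record is actually emitted and pylint's `logging-fstring-interpolation` check is satisfied. A small before/after sketch with an illustrative value:

```python
import logging

logger = logging.getLogger(__name__)
model_max_length = 4096  # illustrative value, not from this patch

# Before: the f-string is built eagerly, even if INFO records are filtered out.
logger.info(f"Model max length {model_max_length}")

# After: the value is passed as an argument and interpolated lazily on emit.
logger.info("Model max length %s", model_max_length)
```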