 from typing import Optional, Union
 import json
 import os
+import sys
 
 # Third Party
 from peft.utils.other import fsdp_auto_wrap_policy
@@ -94,15 +95,15 @@ def _track_loss(self, loss_key, log_file, logs, state):
             return
 
         # append the current log to the jsonl file
-        with open(log_file, "a") as f:
+        with open(log_file, "a", encoding="utf-8") as f:
             f.write(f"{json.dumps(log_obj, sort_keys=True)}\n")
 
 
 def train(
     model_args: configs.ModelArguments,
     data_args: configs.DataArguments,
     train_args: configs.TrainingArguments,
-    peft_config: Optional[
+    peft_config: Optional[  # pylint: disable=redefined-outer-name
         Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]
     ] = None,
 ):
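
For readers unfamiliar with this callback: the file touched by the `encoding="utf-8"` change above is a JSON-lines loss log, one JSON object per training log event. A minimal, hypothetical sketch of consuming it (the file name and keys are illustrative, not taken from this diff):

```python
import json

# Hypothetical path; the real name/location is decided by the training script's config.
with open("train_loss.jsonl", "r", encoding="utf-8") as f:
    records = [json.loads(line) for line in f if line.strip()]

# Each record is the dict that _track_loss serialized, e.g. a loss value plus step/epoch metadata.
for record in records:
    print(record)
```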
@@ -154,9 +155,7 @@ def train(
     )
 
     # TODO: understand if we need to hardcode these here or just use defaults in model
-    if isinstance(tokenizer, LlamaTokenizer) or isinstance(
-        tokenizer, LlamaTokenizerFast
-    ):
+    if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
         tokenizer.add_special_tokens(
             {
                 "bos_token": "<s>",
@@ -165,33 +164,36 @@ def train(
                 "pad_token": "<pad>",
             }
         )
-    elif isinstance(tokenizer, GPTNeoXTokenizerFast) or isinstance(
-        tokenizer, GPT2Tokenizer
-    ):
+    elif isinstance(tokenizer, (GPT2Tokenizer, GPTNeoXTokenizerFast)):
         tokenizer.add_special_tokens(
             {
                 "pad_token": "<pad>",
             }
         )
 
-    """ TODO: near term - how response template ids are parsed out needs to be cleaned.
-    The [2:] here applies if response template has \n prefix, it is needed to strip \n, otherwise template is not found.
-    We will create issue to clean this out after we discuss data formats and collators we will support
-    """
+    # TODO: near term - how response template ids are parsed out needs to be cleaned.
+    # The [2:] here applies if response template has \n prefix, it is needed to strip \n,
+    # otherwise template is not found. We will create issue to clean this out after we discuss
+    # data formats and collators we will support.
     response_template_ids = tokenizer.encode(
         data_args.response_template, add_special_tokens=False
     )[2:]
-    # TODO: This is actually max_seq_length and not model_max_length. we should not override model_max_length
-    # as in current main. We need to change name of this parameter we expose to users.
+    # TODO: This is actually max_seq_length and not model_max_length. we should not override
+    # model_max_length as in current main. We need to change name of this parameter we expose
+    # to users.
     model_max_length = min(train_args.model_max_length, tokenizer.model_max_length)
-    logger.info(f"Model max length {model_max_length}")
+    logger.info("Model max length %s", model_max_length)
     if train_args.model_max_length > tokenizer.model_max_length:
         logger.warning(
-            f"model_max_length {train_args.model_max_length} exceeds tokenizer.model_max_length {tokenizer.model_max_length}, using tokenizer.model_max_length {tokenizer.model_max_length}"
+            "model_max_length %s exceeds tokenizer.model_max_length %s, "
+            "using tokenizer.model_max_length %s",
+            train_args.model_max_length,
+            tokenizer.model_max_length,
+            tokenizer.model_max_length,
         )
 
     # TODO: we need to change this, perhaps follow what open instruct does?
-    special_tokens_dict = dict()
+    special_tokens_dict = {}
     if tokenizer.pad_token is None:
         logger.warning("PAD token set to default, missing in tokenizer")
         special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN
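
As background for the `[2:]` slice above: with SentencePiece-style tokenizers such as Llama's, a response template that begins with `\n` tokenizes differently on its own than it does inside a longer sequence, so the collator would never match the raw ids. A hedged sketch of checking this (the checkpoint name and template string are illustrative assumptions, not from this diff):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("some/llama-style-checkpoint")  # illustrative name
response_template = "\n### Response:"

standalone_ids = tokenizer.encode(response_template, add_special_tokens=False)
in_context_ids = tokenizer.encode("some prompt" + response_template, add_special_tokens=False)

# For Llama-style tokenizers the first couple of standalone ids encode the leading newline in a
# way that never occurs mid-sequence; dropping them lets the template be found in real examples.
print(standalone_ids)
print(standalone_ids[2:])
print(in_context_ids[-len(standalone_ids[2:]):])
```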
@@ -219,19 +221,21 @@ def train(
     if data_args.validation_data_path:
         data_files["validation"] = data_args.validation_data_path
 
-    format_dataset = lambda example: {
+    format_dataset = lambda example: {  # pylint: disable=unnecessary-lambda-assignment
         f"{data_args.dataset_text_field}": example[f"{data_args.dataset_text_field}"]
         + tokenizer.eos_token
     }
 
     json_dataset = datasets.load_dataset("json", data_files=data_files)
     formatted_train_dataset = json_dataset["train"].map(format_dataset)
-    logger.info(f"Training dataset length is {len(formatted_train_dataset)}")
+    logger.info("Training dataset length is %s", len(formatted_train_dataset))
 
     formatted_validation_dataset = None
     if data_args.validation_data_path:
         formatted_validation_dataset = json_dataset["validation"].map(format_dataset)
-        logger.info(f"Validation dataset length is {len(formatted_validation_dataset)}")
+        logger.info(
+            "Validation dataset length is %s", len(formatted_validation_dataset)
+        )
 
     aim_callback = get_aimstack_callback()
     file_logger_callback = FileLoggingCallback(logger)
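
One note on the `unnecessary-lambda-assignment` suppression above: the same EOS-appending step can be written as a named closure, which avoids the pylint warning altogether. A small sketch under the same assumptions (names mirror the diff; this is an alternative, not what the PR does):

```python
def make_formatter(text_field, eos_token):
    """Build a function for Dataset.map() that appends the EOS token to the text field."""
    def _format(example):
        return {text_field: example[text_field] + eos_token}
    return _format

# Equivalent usage to the lambda above:
# formatted_train_dataset = json_dataset["train"].map(
#     make_formatter(data_args.dataset_text_field, tokenizer.eos_token)
# )
```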
@@ -248,13 +252,13 @@ def train(
         logger.error(
             "Error, response template is None, needs to be set for training"
         )
-        exit(-1)
+        sys.exit(-1)
 
     if data_args.dataset_text_field is None:
         logger.error(
             "Error, dataset_text_field is None, needs to be set for training"
         )
-        exit(-1)
+        sys.exit(-1)
 
     data_collator = DataCollatorForCompletionOnlyLM(
         response_template_ids,
@@ -284,7 +288,7 @@ def train(
     trainer.train()
 
 
-def main(**kwargs):
+def main(**kwargs):  # pylint: disable=unused-argument
     parser = transformers.HfArgumentParser(
         dataclass_types=(
             configs.ModelArguments,