File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -34,7 +34,7 @@ dependencies = [
3434" sentencepiece>=0.1.99,<0.3" ,
3535" tokenizers>=0.13.3,<1.0" ,
3636" tqdm>=4.66.2,<5.0" ,
37- " trl>=0.13,<0.15 " ,
37+ " trl>=0.13,<0.16 " ,
3838" peft>=0.8.0,<0.14" ,
3939" protobuf>=5.28.0,<6.0.0" ,
4040" datasets>=2.15.0,<4.0" ,
Original file line number Diff line number Diff line change @@ -485,5 +485,5 @@ def process_dataargs(
485485 dataset_text_field ,
486486 data_collator ,
487487 max_seq_length ,
488- None ,
488+ dataset_kwargs ,
489489 )
Original file line number Diff line number Diff line change @@ -229,7 +229,6 @@ def train(
229229 attn_implementation = "flash_attention_2"
230230 if model_args .use_flash_attn
231231 else None ,
232- use_cache = (not train_args .gradient_checkpointing ),
233232 )
234233
235234 processor = AutoProcessor .from_pretrained (model_args .model_name_or_path )
@@ -260,6 +259,7 @@ def train(
260259 cache_dir = train_args .cache_dir ,
261260 use_fast = True ,
262261 legacy = True ,
262+ use_cache = (not train_args .gradient_checkpointing ),
263263 )
264264 except Exception as e : # pylint: disable=broad-except
265265 logger .error (traceback .format_exc ())
@@ -373,7 +373,7 @@ def train(
373373
374374 trainer = SFTTrainer (
375375 model = model ,
376- processing_class = tokenizer ,
376+ processing_class = tokenizer if processor is None else processor ,
377377 train_dataset = formatted_train_dataset ,
378378 eval_dataset = formatted_validation_dataset ,
379379 data_collator = data_collator ,
You can’t perform that action at this time.
0 commit comments