Skip to content

Commit 2fb6e6a

Browse files
committed
Fix for preprocessor_config in checkpoint folder
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
1 parent 168cde0 commit 2fb6e6a

3 files changed

Lines changed: 4 additions & 4 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ dependencies = [
 34  34      "sentencepiece>=0.1.99,<0.3",
 35  35      "tokenizers>=0.13.3,<1.0",
 36  36      "tqdm>=4.66.2,<5.0",
 37     -    "trl>=0.13,<0.15",
     37 +    "trl>=0.13,<0.16",
 38  38      "peft>=0.8.0,<0.14",
 39  39      "protobuf>=5.28.0,<6.0.0",
 40  40      "datasets>=2.15.0,<4.0",

tuning/data/setup_dataprocessor.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -485,5 +485,5 @@ def process_dataargs(
 485 485         dataset_text_field,
 486 486         data_collator,
 487 487         max_seq_length,
 488     -       None,
     488 +       dataset_kwargs,
 489 489     )

tuning/sft_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -229,7 +229,6 @@ def train(
 229 229         attn_implementation="flash_attention_2"
 230 230         if model_args.use_flash_attn
 231 231         else None,
 232     -       use_cache=(not train_args.gradient_checkpointing),
 233 232     )
 234 233
 235 234     processor = AutoProcessor.from_pretrained(model_args.model_name_or_path)
@@ -260,6 +259,7 @@ def train(
 260 259         cache_dir=train_args.cache_dir,
 261 260         use_fast=True,
 262 261         legacy=True,
     262 +       use_cache=(not train_args.gradient_checkpointing),
 263 263     )
 264 264 except Exception as e:  # pylint: disable=broad-except
 265 265     logger.error(traceback.format_exc())
@@ -373,7 +373,7 @@ def train(
 373 373
 374 374     trainer = SFTTrainer(
 375 375         model=model,
 376     -       processing_class=tokenizer,
     376 +       processing_class=tokenizer if processor is None else processor,
 377 377         train_dataset=formatted_train_dataset,
 378 378         eval_dataset=formatted_validation_dataset,
 379 379         data_collator=data_collator,

0 commit comments

Comments (0)