Skip to content

Commit 3f2643c

Browse files
fix: only trigger on_save if the output checkpoint path actually contains files
Signed-off-by: Harikrishnan Balagopal <harikrishmenon@gmail.com>
1 parent b3a8a78 commit 3f2643c

1 file changed

Lines changed: 29 additions & 3 deletions

File tree

tuning/sft_trainer.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -576,10 +576,36 @@ def save(path: str, trainer: SFTTrainer, tc_callback, log_level="WARNING", args=
576576

577577
logger.info("Saving tuned model to path: %s", path)
578578
trainer.save_model(path)
579+
actually_saved = False
579580
if tc_callback and args:
580-
tc_callback.on_save(
581-
args, trainer.state, trainer.control, path=path, is_final=True
582-
)
581+
if os.path.exists(path):
582+
try:
583+
saved_files = os.listdir(path)
584+
logger.info(
585+
"sanity check, we found %d files at checkpoint path '%s'",
586+
len(saved_files),
587+
path,
588+
)
589+
actually_saved = len(saved_files) > 0
590+
except Exception as e: # pylint: disable=broad-exception-caught
591+
logger.error(
592+
"sanity check, failed to list files in checkpoint path '%s' , error: %s",
593+
path,
594+
e,
595+
)
596+
else:
597+
logger.warning(
598+
"sanity check, failed because checkpoint path '%s' doesn't exist", path
599+
)
600+
if actually_saved:
601+
tc_callback.on_save(
602+
args, trainer.state, trainer.control, path=path, is_final=True
603+
)
604+
else:
605+
logger.warning(
606+
"skip triggering on_save event since checkpoint path is empty: '%s'",
607+
path,
608+
)
583609

584610

585611
def get_parser():

0 commit comments

Comments
 (0)