Skip to content

Commit 30b8dd7

Browse files
fix: only trigger on_save if the output checkpoint path actually contains files
Signed-off-by: Harikrishnan Balagopal <harikrishmenon@gmail.com>
1 parent b3a8a78 commit 30b8dd7

1 file changed

Lines changed: 16 additions & 4 deletions

File tree

tuning/sft_trainer.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -576,11 +576,23 @@ def save(path: str, trainer: SFTTrainer, tc_callback, log_level="WARNING", args=
576576

577577
logger.info("Saving tuned model to path: %s", path)
578578
trainer.save_model(path)
579+
actually_saved = False
579580
if tc_callback and args:
580-
tc_callback.on_save(
581-
args, trainer.state, trainer.control, path=path, is_final=True
582-
)
583-
581+
if os.path.exists(path):
582+
try:
583+
saved_files = os.listdir(path)
584+
logger.info("sanity check, we found %d files at checkpoint path '%s'", len(saved_files), path)
585+
actually_saved = len(saved_files) > 0
586+
except Exception as e:
587+
logger.warning("sanity check, failed to list files in checkpoint path '%s' , error: %s", path, e)
588+
else:
589+
logger.warning("sanity check, failed because checkpoint path '%s' doesn't exist", path)
590+
if actually_saved:
591+
tc_callback.on_save(
592+
args, trainer.state, trainer.control, path=path, is_final=True
593+
)
594+
else:
595+
logger.warning("skip triggering on_save event since checkpoint path is empty: '%s'", path)
584596

585597
def get_parser():
586598
"""Get the command-line argument parser."""

0 commit comments

Comments
 (0)