diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 7a6a11d46..f89ec162e 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -576,10 +576,36 @@ def save(path: str, trainer: SFTTrainer, tc_callback, log_level="WARNING", args= logger.info("Saving tuned model to path: %s", path) trainer.save_model(path) + actually_saved = False if tc_callback and args: - tc_callback.on_save( - args, trainer.state, trainer.control, path=path, is_final=True - ) + if os.path.exists(path): + try: + saved_files = os.listdir(path) + logger.info( + "sanity check, we found %d files at checkpoint path '%s'", + len(saved_files), + path, + ) + actually_saved = len(saved_files) > 0 + except Exception as e: # pylint: disable=broad-exception-caught + logger.error( + "sanity check, failed to list files in checkpoint path '%s' , error: %s", + path, + e, + ) + else: + logger.warning( + "sanity check, failed because checkpoint path '%s' doesn't exist", path + ) + if actually_saved: + tc_callback.on_save( + args, trainer.state, trainer.control, path=path, is_final=True + ) + else: + logger.warning( + "skip triggering on_save event since checkpoint path is empty: '%s'", + path, + ) def get_parser():