diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index c093a4d819..fd759229c7 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -513,6 +513,7 @@ def train( tc_callback = TrainerControllerCallback( trainer_controller_args.trainer_controller_config_file, ) + tc_callback.on_init_end(trainer.args, trainer.state, trainer.control) trainer.add_callback(tc_callback) trainer.train(resume_from_checkpoint)