File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -613,16 +613,13 @@ def single_model_finetune(
613613
614614 if init_frz_model is not None :
615615 frz_model = torch .jit .load (init_frz_model , map_location = DEVICE )
616- try :
617- self .model .load_state_dict (frz_model .state_dict ())
618- except RuntimeError as err_msg :
619- if "Missing key(s) in state_dict" in str (
620- err_msg
621- ) or "Unexpected key(s) in state_dict" in str (err_msg ):
622- self .model .load_state_dict (frz_model .state_dict (), strict = False )
623- log .warning ("Loaded with strict=False to ignore non-matching keys." )
624- else :
625- raise
616+ state = frz_model .state_dict ()
617+ missing , unexpected = self .model .load_state_dict (state , strict = False )
618+ if missing or unexpected :
619+ log .warning (
620+ "Checkpoint loaded non-strictly. "
621+ f"Missing keys: { missing } , Unexpected keys: { unexpected } "
622+ )
626623
627624 # Get model prob for multi-task
628625 if self .multi_task :
You can’t perform that action at this time.
0 commit comments