File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -616,13 +616,14 @@ def single_model_finetune(
616616 frz_model = torch .jit .load (init_frz_model , map_location = DEVICE )
617617 try :
618618 self .model .load_state_dict (frz_model .state_dict ())
619- except RuntimeError as e :
620- if "Missing key(s) in state_dict" in str (e ):
619+ except RuntimeError as err_msg :
620+ if "Missing key(s) in state_dict" in str (
621+ err_msg
622+ ) or "Unexpected key(s) in state_dict" in str (err_msg ):
621623 self .model .load_state_dict (frz_model .state_dict (), strict = False )
622- log .warning ("Use strict=False to ignore non-matching keys." )
623- log .warning (f"Model state_dict mismatch detected: { e } " )
624+ log .warning ("Loaded with strict=False to ignore non-matching keys." )
624625 else :
625- raise e
626+ raise
626627
627628 # Get model prob for multi-task
628629 if self .multi_task :
You can’t perform that action at this time.
0 commit comments