File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -541,8 +541,30 @@ def collect_single_finetune_params(
541541 state_dict ["_extra_state" ] = self .wrapper .state_dict ()[
542542 "_extra_state"
543543 ]
544-
545- self .wrapper .load_state_dict (state_dict )
544+ try :
545+ self .wrapper .load_state_dict (state_dict )
546+ except RuntimeError as e :
547+ # init from direct fitting
548+ rm_list = []
549+ for kk in state_dict :
550+ # delete direct heads
551+ if (
552+ "fitting_net.force_embed." in kk
553+ or "fitting_net.noise_embed" in kk
554+ ):
555+ rm_list .append (kk )
556+ for kk in rm_list :
557+ state_dict .pop (kk )
558+ state_dict ["_extra_state" ] = self .wrapper .state_dict ()[
559+ "_extra_state"
560+ ]
561+ out_shape_list = [
562+ "model.Default.atomic_model.out_bias" ,
563+ "model.Default.atomic_model.out_std" ,
564+ ]
565+ for kk in out_shape_list :
566+ state_dict [kk ] = state_dict [kk ][:1 , :, :1 ]
567+ self .wrapper .load_state_dict (state_dict )
546568
547569 # change bias for fine-tuning
548570 if finetune_model is not None :
You can’t perform that action at this time.
0 commit comments