File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -570,8 +570,30 @@ def collect_single_finetune_params(
570570 state_dict ["_extra_state" ] = self .wrapper .state_dict ()[
571571 "_extra_state"
572572 ]
573-
574- self .wrapper .load_state_dict (state_dict )
573+ try :
574+ self .wrapper .load_state_dict (state_dict )
575+ except RuntimeError as e :
576+ # init from direct fitting
577+ rm_list = []
578+ for kk in state_dict :
579+ # delete direct heads
580+ if (
581+ "fitting_net.force_embed." in kk
582+ or "fitting_net.noise_embed" in kk
583+ ):
584+ rm_list .append (kk )
585+ for kk in rm_list :
586+ state_dict .pop (kk )
587+ state_dict ["_extra_state" ] = self .wrapper .state_dict ()[
588+ "_extra_state"
589+ ]
590+ out_shape_list = [
591+ "model.Default.atomic_model.out_bias" ,
592+ "model.Default.atomic_model.out_std" ,
593+ ]
594+ for kk in out_shape_list :
595+ state_dict [kk ] = state_dict [kk ][:1 , :, :1 ]
596+ self .wrapper .load_state_dict (state_dict )
575597
576598 # change bias for fine-tuning
577599 if finetune_model is not None :
You can’t perform that action at this time.
0 commit comments