File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -420,6 +420,14 @@ def get_lr(lr_params):
420420 if missing_keys :
421421 target_state_dict = self .wrapper .state_dict ()
422422 slim_keys = []
423+ out_shape_list = [
424+ "model.Default.atomic_model.out_bias" ,
425+ "model.Default.atomic_model.out_std" ,
426+ ]
427+ for kk in out_shape_list :
428+ old_stat = state_dict [kk ].clone ().detach ()
429+ state_dict [kk ] = target_state_dict [kk ].clone ().detach ()
430+ state_dict [kk ][:1 , :, :1 ] = old_stat
423431 for item in missing_keys :
424432 state_dict [item ] = target_state_dict [item ].clone ().detach ()
425433 new_key = True
@@ -428,7 +436,7 @@ def get_lr(lr_params):
428436 new_key = False
429437 break
430438 if new_key :
431- tmp_keys = "." .join (item .split ("." )[:3 ])
439+ tmp_keys = "." .join (item .split ("." )[:- 2 ])
432440 slim_keys .append (tmp_keys )
433441 slim_keys = [i + ".*" for i in slim_keys ]
434442 log .warning (
You can’t perform that action at this time.
0 commit comments