@@ -2387,11 +2387,6 @@ def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping
23872387
23882388 tied_keys = list (tied_keys .items ())
23892389 for i , (target_param_name , source_param_name ) in enumerate (tied_keys ):
2390- # Usually we tie a single target to a single source, but when both are missing we may later tie
2391- # both the source and target to a third "backup" parameter that is present in the checkpoint, so we use
2392- # a list here
2393- target_param_names = [target_param_name ]
2394-
23952390 # This is `from_pretrained` -> let's check symmetrically in case the source key is not present
23962391 if missing_keys is not None :
23972392 remove_from_missing = True
@@ -2412,7 +2407,6 @@ def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping
24122407 # We're missing the source but we have the target -> we swap them, tying the parameter that exists
24132408 elif not source_is_there and target_is_there :
24142409 target_param_name , source_param_name = source_param_name , target_param_name
2415- target_param_names = [target_param_name ]
24162410 # Both are missing -> check other keys in case more than 2 keys are tied to the same weight
24172411 elif not source_is_there and not target_is_there :
24182412 for target_backup , source_backup in tied_keys [i + 1 :]:
@@ -2421,10 +2415,10 @@ def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping
24212415 if source_backup == source_param_name :
24222416 target_backup_is_there = target_backup not in missing_keys
24232417 # If the target is present, we found the correct weight to tie into (we know the source is missing)
2418+ # Note that we do not tie the missing source here: it will be tied later anyway, when
2419+ # the pair (target_backup, source_backup) becomes the main pair (target_param_name, source_param_name)
24242420 if target_backup_is_there :
24252421 source_param_name = target_backup
2426- # Append the source as well, since both are missing we'll tie both
2427- target_param_names .append (source_param_name )
24282422 break
24292423 # If we did not break from the loop, it was impossible to find a source key -> let's raise
24302424 else :
@@ -2440,19 +2434,18 @@ def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping
24402434
24412435 # Perform the actual tying
24422436 source_param = self .get_parameter_or_buffer (source_param_name )
2443- for target_param_name in target_param_names :
2444- if "." in target_param_name :
2445- parent_name , name = target_param_name .rsplit ("." , 1 )
2446- parent = self .get_submodule (parent_name )
2447- else :
2448- name = target_param_name
2449- parent = self
2450- # Tie the weights
2451- setattr (parent , name , source_param )
2452- self ._adjust_bias (parent , source_param )
2453- # Remove from missing if necesary
2454- if missing_keys is not None and remove_from_missing :
2455- missing_keys .discard (target_param_name )
2437+ if "." in target_param_name :
2438+ parent_name , name = target_param_name .rsplit ("." , 1 )
2439+ parent = self .get_submodule (parent_name )
2440+ else :
2441+ name = target_param_name
2442+ parent = self
2443+ # Tie the weights
2444+ setattr (parent , name , source_param )
2445+ self ._adjust_bias (parent , source_param )
2446+ # Remove from missing if necessary
2447+ if missing_keys is not None and remove_from_missing :
2448+ missing_keys .discard (target_param_name )
24562449
24572450 def _adjust_bias (self , output_embeddings , input_embeddings ):
24582451 if getattr (output_embeddings , "bias" , None ) is not None and hasattr (output_embeddings , "weight" ):
0 commit comments