Skip to content

Commit 4d6516e

Browse files
authored
Simplify tie weights logic (huggingface#42895)
* fix
* let's not use source backup, clearer to use original name imo
* fix
* use a set
* simplify
* style
* add comment
1 parent 24b311e commit 4d6516e

1 file changed

Lines changed: 14 additions & 21 deletions

File tree

src/transformers/modeling_utils.py

Lines changed: 14 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -2387,11 +2387,6 @@ def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping
23872387

23882388
tied_keys = list(tied_keys.items())
23892389
for i, (target_param_name, source_param_name) in enumerate(tied_keys):
2390-
# Usually we tie a single target to a single source, but when both are missing we may later tie
2391-
# both the source and target to a third "backup" parameter that is present in the checkpoint, so we use
2392-
# a list here
2393-
target_param_names = [target_param_name]
2394-
23952390
# This is `from_pretrained` -> let's check symmetrically in case the source key is not present
23962391
if missing_keys is not None:
23972392
remove_from_missing = True
@@ -2412,7 +2407,6 @@ def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping
24122407
# We're missing the source but we have the target -> we swap them, tying the parameter that exists
24132408
elif not source_is_there and target_is_there:
24142409
target_param_name, source_param_name = source_param_name, target_param_name
2415-
target_param_names = [target_param_name]
24162410
# Both are missing -> check other keys in case more than 2 keys are tied to the same weight
24172411
elif not source_is_there and not target_is_there:
24182412
for target_backup, source_backup in tied_keys[i + 1 :]:
@@ -2421,10 +2415,10 @@ def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping
24212415
if source_backup == source_param_name:
24222416
target_backup_is_there = target_backup not in missing_keys
24232417
# If the target is present, we found the correct weight to tie into (we know the source is missing)
2418+
# Note here that we do not tie the missing source right now as well, as it will be done anyway when
2419+
# the pair (target_backup, source_backup) becomes the main pair (target_param_name, source_param_name)
24242420
if target_backup_is_there:
24252421
source_param_name = target_backup
2426-
# Append the source as well, since both are missing we'll tie both
2427-
target_param_names.append(source_param_name)
24282422
break
24292423
# If we did not break from the loop, it was impossible to find a source key -> let's raise
24302424
else:
@@ -2440,19 +2434,18 @@ def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping
24402434

24412435
# Perform the actual tying
24422436
source_param = self.get_parameter_or_buffer(source_param_name)
2443-
for target_param_name in target_param_names:
2444-
if "." in target_param_name:
2445-
parent_name, name = target_param_name.rsplit(".", 1)
2446-
parent = self.get_submodule(parent_name)
2447-
else:
2448-
name = target_param_name
2449-
parent = self
2450-
# Tie the weights
2451-
setattr(parent, name, source_param)
2452-
self._adjust_bias(parent, source_param)
2453-
# Remove from missing if necessary
2454-
if missing_keys is not None and remove_from_missing:
2455-
missing_keys.discard(target_param_name)
2437+
if "." in target_param_name:
2438+
parent_name, name = target_param_name.rsplit(".", 1)
2439+
parent = self.get_submodule(parent_name)
2440+
else:
2441+
name = target_param_name
2442+
parent = self
2443+
# Tie the weights
2444+
setattr(parent, name, source_param)
2445+
self._adjust_bias(parent, source_param)
2446+
# Remove from missing if necessary
2447+
if missing_keys is not None and remove_from_missing:
2448+
missing_keys.discard(target_param_name)
24562449

24572450
def _adjust_bias(self, output_embeddings, input_embeddings):
24582451
if getattr(output_embeddings, "bias", None) is not None and hasattr(output_embeddings, "weight"):

0 commit comments

Comments
 (0)