
Commit b928344

fix(merger): re-tie weights to avoid duplicating tied parameters (#157)
* fix(merger): re-tie weights to avoid duplicating tied parameters

  FSDP saves tied parameters (e.g. lm_head <-> embed_tokens) as independent
  shards. After load_state_dict(..., assign=True) they become separate tensors,
  and save_pretrained writes both, bloating the merged checkpoint. Re-tie when
  the model declares tying and the saved tensors agree; otherwise warn and skip.

* style: auto-fix lint (black + isort)

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent ae4513b commit b928344
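
To make the failure mode concrete, here is a minimal, self-contained sketch of how loading a consolidated state dict breaks weight tying. The `TinyLM` module and its key names are hypothetical stand-ins for a real FSDP-consolidated checkpoint, not code from this repository:

```python
# Hypothetical toy model illustrating the bug described above: with weight tying,
# lm_head.weight and embed_tokens.weight share one tensor, but a consolidated FSDP
# state dict stores the value under both keys, and load_state_dict(..., assign=True)
# assigns each key its own tensor.
import torch
from torch import nn


class TinyLM(nn.Module):
    def __init__(self, vocab: int = 10, dim: int = 4):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)
        self.lm_head = nn.Linear(dim, vocab, bias=False)
        self.lm_head.weight = self.embed_tokens.weight  # tie the weights


model = TinyLM()
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()

# A consolidated checkpoint carries an independent tensor for every key.
full_state_dict = {k: v.clone() for k, v in model.state_dict().items()}
model.load_state_dict(full_state_dict, assign=True)

# The tie is broken: each key now holds its own tensor, so a save_pretrained-style
# export would serialize the same weight twice unless it is re-tied first.
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
```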

1 file changed

Lines changed: 32 additions & 1 deletion

File tree

src/lmms_engine/merger/fsdp2.py

@@ -109,7 +109,9 @@ def consolidate(self, shard_state_dicts: list[dict]) -> dict:
             if not isinstance(state_dict[key], list):
                 continue
             # Non-sharded tensors are duplicated across ranks; just take the first one
-            if all(t.shape == state_dict[key][0].shape and torch.equal(t, state_dict[key][0]) for t in state_dict[key][1:]):
+            if all(
+                t.shape == state_dict[key][0].shape and torch.equal(t, state_dict[key][0]) for t in state_dict[key][1:]
+            ):
                 state_dict[key] = state_dict[key][0]
             else:
                 state_dict[key] = torch.cat(state_dict[key], dim=0)
@@ -146,6 +148,34 @@ def _resolve_checkpoint_path(self, path: Path) -> Path:
         latest_checkpoint = checkpoint_folders[-1]
         return latest_checkpoint
 
+    def maybe_tie_weights(self, model: torch.nn.Module, config: object, state_dict: dict) -> None:
+        """Re-tie weights if the model declares weight tying.
+
+        FSDP saves tied parameters (e.g. ``lm_head`` <-> ``embed_tokens``) as
+        independent shards, so after ``load_state_dict(..., assign=True)`` they
+        become separate tensors and ``save_pretrained`` would write both.
+
+        Only re-ties when the model declares tying AND the saved tensors
+        actually agree, to avoid silently dropping divergent weights.
+        """
+        tied_keys_map = getattr(model, "_tied_weights_keys", None)
+        tie_word_embeddings = getattr(config, "tie_word_embeddings", False) or getattr(
+            getattr(config, "text_config", None), "tie_word_embeddings", False
+        )
+        if not (tied_keys_map and tie_word_embeddings):
+            return
+
+        if isinstance(tied_keys_map, dict):
+            for tied_key, source_key in tied_keys_map.items():
+                t1 = state_dict.get(tied_key)
+                t2 = state_dict.get(source_key)
+                if t1 is not None and t2 is not None and not torch.equal(t1, t2):
+                    logger.warning(f"Tied weights mismatch: '{tied_key}' != '{source_key}'. Skipping tie_weights().")
+                    return
+
+        logger.info("Re-tying weights (tie_word_embeddings=True).")
+        model.tie_weights()
+
     def merge(
         self,
         checkpoint_path: Path,
@@ -194,6 +224,7 @@ def merge(
         with init_empty_weights():
             model = model_cls.from_config(config)
         model.load_state_dict(full_state_dict, assign=True)
+        self.maybe_tie_weights(model, config, full_state_dict)
         processor = AutoProcessor.from_pretrained(checkpoint_path)
         processor.save_pretrained(output_path)
         config.save_pretrained(output_path)
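
One way to sanity-check a merged checkpoint after this change is to list the tensors that were actually written to disk. The snippet below is an assumed verification helper, not part of the repository: the output directory path and the `lm_head.weight` key are placeholders, and other models may use different tied keys.

```python
# Hedged verification sketch: with tie_word_embeddings=True and the weights re-tied,
# save_pretrained skips the duplicate tensor, so the tied head key should not appear
# as a separately stored tensor in the exported safetensors shards.
from pathlib import Path

from safetensors import safe_open


def saved_tensor_keys(output_path: str) -> set[str]:
    """Collect every tensor name across all .safetensors shards in a directory."""
    keys: set[str] = set()
    for shard in Path(output_path).glob("*.safetensors"):
        with safe_open(str(shard), framework="pt") as f:
            keys.update(f.keys())
    return keys


keys = saved_tensor_keys("path/to/merged_checkpoint")  # placeholder path
print("lm_head.weight stored separately:", "lm_head.weight" in keys)
```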
