Commit 09b348a

fix(merger): re-tie weights to avoid duplicating tied parameters
FSDP saves tied parameters (e.g. lm_head <-> embed_tokens) as independent shards. After load_state_dict(..., assign=True) they become separate tensors, and save_pretrained writes both, bloating the merged checkpoint. Re-tie when the model declares tying and the saved tensors agree; otherwise warn and skip.
1 parent ae4513b commit 09b348a
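
To make the failure mode concrete, here is a minimal sketch of what the commit message describes, using gpt2 as a stand-in for any model with tie_word_embeddings=True (the model choice and the per-key clone, which mimics an FSDP full state dict, are illustrative assumptions, not part of this commit):

```python
# Sketch of the failure mode, assuming any HF model that ties its LM head
# to its input embeddings; gpt2 is used purely as an illustration.
import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("gpt2")  # gpt2 ties lm_head <-> wte
model = AutoModelForCausalLM.from_config(config)

emb = model.get_input_embeddings().weight
assert model.lm_head.weight.data_ptr() == emb.data_ptr()  # tied: one storage

# A full FSDP state dict materializes every key as its own tensor; the
# clone here stands in for that. assign=True then breaks the sharing.
state_dict = {k: v.clone() for k, v in model.state_dict().items()}
model.load_state_dict(state_dict, assign=True)
emb = model.get_input_embeddings().weight
assert model.lm_head.weight.data_ptr() != emb.data_ptr()  # two storages now

# Re-tying restores the sharing, so the weight is serialized once.
model.tie_weights()
emb = model.get_input_embeddings().weight
assert model.lm_head.weight.data_ptr() == emb.data_ptr()
```

save_pretrained only deduplicates tensors that actually share storage, which is why the two separate copies above would both end up in the checkpoint.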

1 file changed (31 additions, 0 deletions)

src/lmms_engine/merger/fsdp2.py
```diff
@@ -146,6 +146,36 @@ def _resolve_checkpoint_path(self, path: Path) -> Path:
         latest_checkpoint = checkpoint_folders[-1]
         return latest_checkpoint
 
+    def maybe_tie_weights(self, model: torch.nn.Module, config: object, state_dict: dict) -> None:
+        """Re-tie weights if the model declares weight tying.
+
+        FSDP saves tied parameters (e.g. ``lm_head`` <-> ``embed_tokens``) as
+        independent shards, so after ``load_state_dict(..., assign=True)`` they
+        become separate tensors and ``save_pretrained`` would write both.
+
+        Only re-ties when the model declares tying AND the saved tensors
+        actually agree, to avoid silently dropping divergent weights.
+        """
+        tied_keys_map = getattr(model, "_tied_weights_keys", None)
+        tie_word_embeddings = getattr(config, "tie_word_embeddings", False) or getattr(
+            getattr(config, "text_config", None), "tie_word_embeddings", False
+        )
+        if not (tied_keys_map and tie_word_embeddings):
+            return
+
+        if isinstance(tied_keys_map, dict):
+            for tied_key, source_key in tied_keys_map.items():
+                t1 = state_dict.get(tied_key)
+                t2 = state_dict.get(source_key)
+                if t1 is not None and t2 is not None and not torch.equal(t1, t2):
+                    logger.warning(
+                        f"Tied weights mismatch: '{tied_key}' != '{source_key}'. Skipping tie_weights()."
+                    )
+                    return
+
+        logger.info("Re-tying weights (tie_word_embeddings=True).")
+        model.tie_weights()
+
     def merge(
         self,
         checkpoint_path: Path,
@@ -194,6 +224,7 @@ def merge(
         with init_empty_weights():
             model = model_cls.from_config(config)
         model.load_state_dict(full_state_dict, assign=True)
+        self.maybe_tie_weights(model, config, full_state_dict)
         processor = AutoProcessor.from_pretrained(checkpoint_path)
         processor.save_pretrained(output_path)
         config.save_pretrained(output_path)
```
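
After a merge, the effect can be spot-checked on the serialized output. A hedged sketch, assuming a single-shard safetensors file and Llama-style key names (both vary by model; the path is hypothetical):

```python
# Sketch: confirm the duplicate head weight is gone from the merged output.
# "merged_output/model.safetensors" and the key names are assumptions.
from safetensors import safe_open

with safe_open("merged_output/model.safetensors", framework="pt") as f:
    keys = set(f.keys())

# Once weights are re-tied, save_pretrained drops the tied key and
# serializes only the source embedding weight.
print("lm_head.weight" in keys)             # expected: False
print("model.embed_tokens.weight" in keys)  # expected: True
```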
