Skip to content

Commit 6d4d3ca

Browse files
committed
chore: fix has_tied_word_embeddings for pipeline parallelism
1 parent 0927a2f commit 6d4d3ca

3 files changed

Lines changed: 12 additions & 8 deletions

File tree

src/modalities/models/gpt2/gpt2_model.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -940,9 +940,13 @@ def __init__(
940940

941941
@property
942942
def has_tied_word_embeddings(self) -> bool:
943-
token_embedding_weight = getattr(self.transformer.wte, "weight", None)
944-
lm_head_weight = getattr(self.transformer.lm_head, "weight", None)
945-
return token_embedding_weight is not None and token_embedding_weight is lm_head_weight
943+
# In pipeline parallelism, a stage's transformer may not contain the wte/lm_head submodules
944+
# (e.g. a middle stage has neither). Such a stage has no tying to report, so return False when
945+
# either submodule is absent. Whether tied embeddings are allowed at all (they are not, for PP)
946+
# is enforced separately by the pipeline/TP config validators on the whole, unsplit model.
947+
if "wte" not in self.transformer or "lm_head" not in self.transformer:
948+
return False
949+
return self.transformer.wte.weight is self.transformer.lm_head.weight
946950

947951
@overload
948952
def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:

tests/test_weight_tying.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,13 @@ def test_has_tied_word_embeddings_requires_model_capability():
149149
has_tied_word_embeddings(nn.Linear(1, 1))
150150

151151

152-
@pytest.mark.parametrize("module_name", ["transformer", "wte", "lm_head"])
152+
@pytest.mark.parametrize("module_name", ["wte", "lm_head"])
153153
def test_has_tied_word_embeddings_handles_pipeline_stage(module_name: str):
154+
# In pipeline parallelism, a stage's transformer ModuleDict only contains the submodules assigned
155+
# to that stage (the transformer container itself is always present), so a stage may lack wte
156+
# and/or lm_head. Such a stage has no tying to report and must not raise.
154157
model = create_gpt2_model(use_weight_tying=True)
155-
if module_name == "transformer":
156-
del model.transformer
157-
else:
158-
del model.transformer[module_name]
158+
del model.transformer[module_name]
159159

160160
assert has_tied_word_embeddings(model) is False
161161

tutorials/instruction_tuning/experiments/.gitkeep

Whitespace-only changes.

0 commit comments

Comments
 (0)