Skip to content

Commit 4e2a191

Browse files
authored
Merge pull request #456 from Modalities/3B_training_prep
Weight tying improvements
2 parents 8db7d24 + 392fe39 commit 4e2a191

4 files changed

Lines changed: 132 additions & 37 deletions

File tree

src/modalities/models/gpt2/gpt2_model.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1123,9 +1123,13 @@ def __init__(
11231123

11241124
@property
11251125
def has_tied_word_embeddings(self) -> bool:
1126-
token_embedding_weight = getattr(self.transformer.wte, "weight", None)
1127-
lm_head_weight = getattr(self.transformer.lm_head, "weight", None)
1128-
return token_embedding_weight is not None and token_embedding_weight is lm_head_weight
1126+
# In pipeline parallelism a stage's transformer may not contain the wte/lm_head submodules
1127+
# (e.g. a middle stage has neither). Such a stage has no tying to report, so return False when
1128+
# either submodule is absent. Whether tied embeddings are allowed at all (they are not, for PP)
1129+
# is enforced separately by the pipeline/TP config validators on the whole, unsplit model.
1130+
if "wte" not in self.transformer or "lm_head" not in self.transformer:
1131+
return False
1132+
return self.transformer.wte.weight is self.transformer.lm_head.weight
11291133

11301134
@overload
11311135
def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:

src/modalities/models/gpt2/llama3_like_initialization.py

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch.nn as nn
77
from pydantic import BaseModel, Field
88

9+
from modalities.models.gpt2.gpt2_model import GPT2LLM
910
from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
1011
from modalities.utils.logger_utils import get_logger
1112

@@ -15,7 +16,6 @@
1516
class Llama3InitializerConfig(BaseModel):
1617
num_layers: Annotated[int, Field(strict=True, gt=0)]
1718
n_embd: Annotated[int, Field(strict=True, gt=0)]
18-
use_weight_tying: bool
1919
depth_init: bool = True
2020

2121

@@ -24,7 +24,7 @@ class Llama3Initializer(ModelInitializationIF):
2424
Follows weight initialization distributions and parameterization for Llama3 as described in TorchTitan.
2525
"""
2626

27-
def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_tying: bool) -> None:
27+
def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None:
2828
"""
2929
Initializes the Llama3Initializer.
3030
Args:
@@ -35,11 +35,12 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_ty
3535
used for all layers baed on num_layers.
3636
"""
3737
super().__init__()
38+
self.num_layers = num_layers
39+
self.n_embd = n_embd
3840
self.depth_init = depth_init
3941

40-
self.regex_to_init = {
41-
# embedding weights
42-
r"transformer\.wte\.weight": (nn.init.normal_, {"mean": 0.0, "std": 1}),
42+
def _build_regex_to_init(self, use_weight_tying: bool) -> dict[str, tuple[Callable, dict]]:
43+
regex_to_init: dict[str, tuple[Callable, dict]] = {
4344
# qkv projections
4445
r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.weight": (
4546
trunc_normal_,
@@ -57,8 +58,8 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_ty
5758
"mean": 0.0,
5859
"std": (
5960
(lambda layer_id: 0.02 / math.sqrt(2 * (layer_id + 1)))
60-
if depth_init
61-
else 0.02 / math.sqrt(2 * num_layers)
61+
if self.depth_init
62+
else 0.02 / math.sqrt(2 * self.num_layers)
6263
),
6364
"a": -2,
6465
"b": 2,
@@ -80,28 +81,50 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_ty
8081
"mean": 0.0,
8182
"std": (
8283
(lambda layer_id: 0.02 / math.sqrt(2 * (layer_id + 1)))
83-
if depth_init
84-
else 0.02 / math.sqrt(2 * num_layers)
84+
if self.depth_init
85+
else 0.02 / math.sqrt(2 * self.num_layers)
8586
),
8687
"a": -2,
8788
"b": 2,
8889
},
8990
),
9091
}
91-
if not use_weight_tying:
92-
# lm head weights
93-
self.regex_to_init[r"transformer\.lm_head\.weight"] = (
94-
trunc_normal_,
95-
{
96-
"mean": 0.0,
97-
"std": 1 / math.sqrt(n_embd),
98-
"a": -3 / math.sqrt(n_embd),
99-
"b": 3 / math.sqrt(n_embd),
100-
},
101-
)
92+
93+
# Initialization of the output projection (the matrix that produces the logits): small std
94+
# 1/sqrt(n_embd) so the logits are well-scaled at init.
95+
output_projection_init = (
96+
trunc_normal_,
97+
{
98+
"mean": 0.0,
99+
"std": 1 / math.sqrt(self.n_embd),
100+
"a": -3 / math.sqrt(self.n_embd),
101+
"b": 3 / math.sqrt(self.n_embd),
102+
},
103+
)
104+
if use_weight_tying:
105+
# With weight tying, transformer.wte.weight IS the output projection (lm_head shares the
106+
# same tensor), so it must use the small output std instead of the embedding std of 1.
107+
# Otherwise the tied matrix produces logits ~sqrt(n_embd)x too large at init, causing the
108+
# initial loss/grad norm to explode.
109+
regex_to_init[r"transformer\.wte\.weight"] = output_projection_init
110+
else:
111+
# Untied: wte is the embedding (std=1) and lm_head is the separate output projection.
112+
regex_to_init[r"transformer\.wte\.weight"] = (nn.init.normal_, {"mean": 0.0, "std": 1})
113+
regex_to_init[r"transformer\.lm_head\.weight"] = output_projection_init
114+
return regex_to_init
102115

103116
def initialize_in_place(self, model: nn.Module):
104-
self._init_by_fqn_regex(model, self.regex_to_init)
117+
# The FQN regexes are specific to GPT2LLM, which is also the single source of truth for whether
118+
# the word embeddings are tied -- so we infer tying from the model rather than tracking a
119+
# separate flag that could disagree with it (wrong-std tied output projection / uninitialized
120+
# lm_head). Reject model types we cannot initialize.
121+
if not isinstance(model, GPT2LLM):
122+
raise TypeError(
123+
f"Llama3Initializer only supports GPT2LLM (its FQN regexes are specific to it), "
124+
f"but received {type(model).__name__}."
125+
)
126+
regex_to_init = self._build_regex_to_init(use_weight_tying=model.has_tied_word_embeddings)
127+
self._init_by_fqn_regex(model, regex_to_init)
105128

106129
@staticmethod
107130
def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, tuple[Callable, dict]]):

tests/test_weight_tying.py

Lines changed: 81 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import math
2+
13
import pytest
4+
import torch
25
import torch.nn as nn
36
from pydantic import ValidationError
47
from torch.distributed.device_mesh import DeviceMesh
58

69
from modalities.config.config import GPT2ModelTPConfig
7-
from modalities.models.components.layer_norms import LayerNormConfig
10+
from modalities.models.components.layer_norms import LayerNormConfig, PytorchRMSLayerNormConfig
811
from modalities.models.gpt2.gpt2_model import (
912
GPT2LLM,
1013
AttentionConfig,
@@ -13,6 +16,7 @@
1316
LayerNormWrapperConfig,
1417
PositionTypes,
1518
)
19+
from modalities.models.gpt2.llama3_like_initialization import Llama3Initializer
1620
from modalities.models.model import ActivationType
1721
from modalities.models.parallelism.pipeline_parallelism_configs import StagedPipelineConfig
1822
from modalities.models.parallelism.stages_generator import GPT2LLMStagesGenerator
@@ -27,7 +31,12 @@ def count_parameters(model: nn.Module) -> int:
2731
return sum(p.numel() for p in model.parameters())
2832

2933

30-
def create_gpt2_model(use_weight_tying: bool) -> GPT2LLM:
34+
def create_gpt2_model(
35+
use_weight_tying: bool,
36+
activation_type: ActivationType = ActivationType.GELU,
37+
bias: bool = True,
38+
norm_type: LayerNorms = LayerNorms.layer_norm,
39+
) -> GPT2LLM:
3140
vocab_size = VOCAB_SIZE
3241
n_embd = EMBEDDING_DIM
3342
sequence_length = 128
@@ -36,9 +45,7 @@ def create_gpt2_model(use_weight_tying: bool) -> GPT2LLM:
3645
n_head_kv = 2
3746
ffn_hidden = 256
3847
dropout = 0.1
39-
bias = True
4048
poe_type = PositionTypes.NOPE
41-
activation_type = ActivationType.GELU
4249
attention_implementation = AttentionImplementation.PYTORCH_FLASH
4350
attention_config = AttentionConfig(
4451
qkv_transforms=[
@@ -53,15 +60,17 @@ def create_gpt2_model(use_weight_tying: bool) -> GPT2LLM:
5360
)
5461
]
5562
)
56-
attention_norm_config = LayerNormWrapperConfig(
57-
norm_type=LayerNorms.layer_norm, config=LayerNormConfig(normalized_shape=n_embd)
58-
)
59-
ffn_norm_config = LayerNormWrapperConfig(
60-
norm_type=LayerNorms.layer_norm, config=LayerNormConfig(normalized_shape=n_embd)
61-
)
62-
lm_head_norm_config = LayerNormWrapperConfig(
63-
norm_type=LayerNorms.layer_norm, config=LayerNormConfig(normalized_shape=n_embd)
64-
)
63+
64+
def _make_norm_config() -> LayerNormWrapperConfig:
65+
if norm_type == LayerNorms.pytorch_rms_norm:
66+
return LayerNormWrapperConfig(
67+
norm_type=norm_type, config=PytorchRMSLayerNormConfig(normalized_shape=n_embd)
68+
)
69+
return LayerNormWrapperConfig(norm_type=norm_type, config=LayerNormConfig(normalized_shape=n_embd))
70+
71+
attention_norm_config = _make_norm_config()
72+
ffn_norm_config = _make_norm_config()
73+
lm_head_norm_config = _make_norm_config()
6574

6675
return GPT2LLM(
6776
sample_key="input_ids",
@@ -140,6 +149,17 @@ def test_has_tied_word_embeddings_requires_model_capability():
140149
has_tied_word_embeddings(nn.Linear(1, 1))
141150

142151

152+
@pytest.mark.parametrize("module_name", ["wte", "lm_head"])
153+
def test_has_tied_word_embeddings_handles_pipeline_stage(module_name: str):
154+
# In pipeline parallelism a stage's transformer ModuleDict only contains the submodules assigned
155+
# to that stage (the transformer container itself is always present), so a stage may lack wte
156+
# and/or lm_head. Such a stage has no tying to report and must not raise.
157+
model = create_gpt2_model(use_weight_tying=True)
158+
del model.transformer[module_name]
159+
160+
assert has_tied_word_embeddings(model) is False
161+
162+
143163
def test_tp_config_rejects_tied_word_embeddings():
144164
model = create_gpt2_model(use_weight_tying=True)
145165
device_mesh = create_device_mesh_stub(ParallelismDegrees.TP.value)
@@ -148,6 +168,54 @@ def test_tp_config_rejects_tied_word_embeddings():
148168
GPT2ModelTPConfig(model=model, device_mesh=device_mesh)
149169

150170

171+
@pytest.mark.parametrize("use_weight_tying", [True, False])
172+
def test_llama3_init_keeps_output_projection_small(use_weight_tying: bool):
173+
"""Regression test for the weight-tying init bug.
174+
175+
With weight tying, ``transformer.wte.weight`` *is* the output projection
176+
(``lm_head`` shares the same tensor), so it must be initialized with the small
177+
output std ``1 / sqrt(n_embd)`` -- not the embedding std of 1. Otherwise the tied
178+
matrix produces logits ~sqrt(n_embd)x too large at init and the loss/grad norm
179+
explode (observed: initial loss ~1685 instead of ~ln(vocab_size)).
180+
"""
181+
n_embd = EMBEDDING_DIM
182+
expected_output_std = 1 / math.sqrt(n_embd)
183+
184+
# SwiGLU + RMSNorm + no bias so the Llama3Initializer's FQN regexes fully match
185+
# the model and it rejects no parameters.
186+
model = create_gpt2_model(
187+
use_weight_tying=use_weight_tying,
188+
activation_type=ActivationType.SWIGLU,
189+
bias=False,
190+
norm_type=LayerNorms.pytorch_rms_norm,
191+
)
192+
# The initializer infers weight tying from the model itself, so no tying flag is passed.
193+
initializer = Llama3Initializer(num_layers=2, n_embd=n_embd, depth_init=True)
194+
# Mirror the production flow (model_factory applies the initializer under no_grad).
195+
with torch.no_grad():
196+
initializer.initialize_in_place(model)
197+
198+
# The logit-producing matrix must be small regardless of weight tying.
199+
output_proj_std = model.transformer.lm_head.weight.detach().float().std().item()
200+
assert output_proj_std == pytest.approx(expected_output_std, rel=0.15)
201+
202+
if use_weight_tying:
203+
# Tied: embedding and output projection are the same (small) tensor.
204+
assert model.transformer.wte.weight is model.transformer.lm_head.weight
205+
else:
206+
# Untied: the embedding keeps the Llama3/TorchTitan std of 1.
207+
embedding_std = model.transformer.wte.weight.detach().float().std().item()
208+
assert embedding_std == pytest.approx(1.0, rel=0.15)
209+
210+
211+
def test_llama3_init_rejects_non_gpt2_model():
212+
# The FQN regexes are GPT2LLM-specific, so the initializer must reject other model types
213+
# rather than silently leaving everything uninitialized.
214+
initializer = Llama3Initializer(num_layers=2, n_embd=EMBEDDING_DIM, depth_init=True)
215+
with pytest.raises(TypeError, match="only supports GPT2LLM"):
216+
initializer.initialize_in_place(nn.Linear(1, 1))
217+
218+
151219
def test_tp_config_allows_untied_word_embeddings():
152220
model = create_gpt2_model(use_weight_tying=False)
153221
device_mesh = create_device_mesh_stub(ParallelismDegrees.TP.value)

tutorials/instruction_tuning/experiments/.gitkeep

Whitespace-only changes.

0 commit comments

Comments
 (0)