Skip to content

Commit 392fe39

Browse files
committed
fix: update Llama3Initializer to infer weight tying from model and reject non-GPT2 models
1 parent dca6cc1 commit 392fe39

2 files changed

Lines changed: 53 additions & 36 deletions

File tree

src/modalities/models/gpt2/llama3_like_initialization.py

Lines changed: 43 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch.nn as nn
77
from pydantic import BaseModel, Field
88

9+
from modalities.models.gpt2.gpt2_model import GPT2LLM
910
from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
1011
from modalities.utils.logger_utils import get_logger
1112

@@ -15,7 +16,6 @@
1516
class Llama3InitializerConfig(BaseModel):
1617
num_layers: Annotated[int, Field(strict=True, gt=0)]
1718
n_embd: Annotated[int, Field(strict=True, gt=0)]
18-
use_weight_tying: bool = False
1919
depth_init: bool = True
2020

2121

@@ -24,7 +24,7 @@ class Llama3Initializer(ModelInitializationIF):
2424
Follows weight initialization distributions and parameterization for Llama3 as described in TorchTitan.
2525
"""
2626

27-
def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_tying: bool) -> None:
27+
def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None:
2828
"""
2929
Initializes the Llama3Initializer.
3030
Args:
@@ -35,11 +35,12 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_ty
3535
used for all layers baed on num_layers.
3636
"""
3737
super().__init__()
38+
self.num_layers = num_layers
39+
self.n_embd = n_embd
3840
self.depth_init = depth_init
3941

40-
self.regex_to_init = {
41-
# embedding weights
42-
r"transformer\.wte\.weight": (nn.init.normal_, {"mean": 0.0, "std": 1}),
42+
def _build_regex_to_init(self, use_weight_tying: bool) -> dict[str, tuple[Callable, dict]]:
43+
regex_to_init: dict[str, tuple[Callable, dict]] = {
4344
# qkv projections
4445
r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.weight": (
4546
trunc_normal_,
@@ -57,8 +58,8 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_ty
5758
"mean": 0.0,
5859
"std": (
5960
(lambda layer_id: 0.02 / math.sqrt(2 * (layer_id + 1)))
60-
if depth_init
61-
else 0.02 / math.sqrt(2 * num_layers)
61+
if self.depth_init
62+
else 0.02 / math.sqrt(2 * self.num_layers)
6263
),
6364
"a": -2,
6465
"b": 2,
@@ -80,43 +81,50 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_ty
8081
"mean": 0.0,
8182
"std": (
8283
(lambda layer_id: 0.02 / math.sqrt(2 * (layer_id + 1)))
83-
if depth_init
84-
else 0.02 / math.sqrt(2 * num_layers)
84+
if self.depth_init
85+
else 0.02 / math.sqrt(2 * self.num_layers)
8586
),
8687
"a": -2,
8788
"b": 2,
8889
},
8990
),
9091
}
91-
if not use_weight_tying:
92-
# lm head weights (separate output projection matrix)
93-
self.regex_to_init[r"transformer\.lm_head\.weight"] = (
94-
trunc_normal_,
95-
{
96-
"mean": 0.0,
97-
"std": 1 / math.sqrt(n_embd),
98-
"a": -3 / math.sqrt(n_embd),
99-
"b": 3 / math.sqrt(n_embd),
100-
},
101-
)
92+
93+
# Initialization of the output projection (the matrix that produces the logits): small std
94+
# 1/sqrt(n_embd) so the logits are well-scaled at init.
95+
output_projection_init = (
96+
trunc_normal_,
97+
{
98+
"mean": 0.0,
99+
"std": 1 / math.sqrt(self.n_embd),
100+
"a": -3 / math.sqrt(self.n_embd),
101+
"b": 3 / math.sqrt(self.n_embd),
102+
},
103+
)
104+
if use_weight_tying:
105+
# With weight tying, transformer.wte.weight IS the output projection (lm_head shares the
106+
# same tensor), so it must use the small output std instead of the embedding std of 1.
107+
# Otherwise the tied matrix produces logits ~sqrt(n_embd)x too large at init, causing the
108+
# initial loss/grad norm to explode.
109+
regex_to_init[r"transformer\.wte\.weight"] = output_projection_init
102110
else:
103-
# With weight tying, transformer.wte.weight IS the output projection
104-
# (lm_head shares the same tensor), so it must be initialized with the
105-
# small output std (1/sqrt(n_embd)) instead of the embedding std of 1.
106-
# Otherwise the tied matrix produces logits that are ~sqrt(n_embd)x too
107-
# large at init, causing the initial loss/grad norm to explode.
108-
self.regex_to_init[r"transformer\.wte\.weight"] = (
109-
trunc_normal_,
110-
{
111-
"mean": 0.0,
112-
"std": 1 / math.sqrt(n_embd),
113-
"a": -3 / math.sqrt(n_embd),
114-
"b": 3 / math.sqrt(n_embd),
115-
},
116-
)
111+
# Untied: wte is the embedding (std=1) and lm_head is the separate output projection.
112+
regex_to_init[r"transformer\.wte\.weight"] = (nn.init.normal_, {"mean": 0.0, "std": 1})
113+
regex_to_init[r"transformer\.lm_head\.weight"] = output_projection_init
114+
return regex_to_init
117115

118116
def initialize_in_place(self, model: nn.Module):
119-
self._init_by_fqn_regex(model, self.regex_to_init)
117+
# The FQN regexes are specific to GPT2LLM, which is also the single source of truth for whether
118+
# the word embeddings are tied -- so we infer tying from the model rather than tracking a
119+
# separate flag that could disagree with it (wrong-std tied output projection / uninitialized
120+
# lm_head). Reject model types we cannot initialize.
121+
if not isinstance(model, GPT2LLM):
122+
raise TypeError(
123+
f"Llama3Initializer only supports GPT2LLM (its FQN regexes are specific to it), "
124+
f"but received {type(model).__name__}."
125+
)
126+
regex_to_init = self._build_regex_to_init(use_weight_tying=model.has_tied_word_embeddings)
127+
self._init_by_fqn_regex(model, regex_to_init)
120128

121129
@staticmethod
122130
def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, tuple[Callable, dict]]):

tests/test_weight_tying.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,8 @@ def test_llama3_init_keeps_output_projection_small(use_weight_tying: bool):
189189
bias=False,
190190
norm_type=LayerNorms.pytorch_rms_norm,
191191
)
192-
initializer = Llama3Initializer(num_layers=2, n_embd=n_embd, depth_init=True, use_weight_tying=use_weight_tying)
192+
# The initializer infers weight tying from the model itself, so no tying flag is passed.
193+
initializer = Llama3Initializer(num_layers=2, n_embd=n_embd, depth_init=True)
193194
# Mirror the production flow (model_factory applies the initializer under no_grad).
194195
with torch.no_grad():
195196
initializer.initialize_in_place(model)
@@ -207,6 +208,14 @@ def test_llama3_init_keeps_output_projection_small(use_weight_tying: bool):
207208
assert embedding_std == pytest.approx(1.0, rel=0.15)
208209

209210

211+
def test_llama3_init_rejects_non_gpt2_model():
212+
# The FQN regexes are GPT2LLM-specific, so the initializer must reject other model types
213+
# rather than silently leaving everything uninitialized.
214+
initializer = Llama3Initializer(num_layers=2, n_embd=EMBEDDING_DIM, depth_init=True)
215+
with pytest.raises(TypeError, match="only supports GPT2LLM"):
216+
initializer.initialize_in_place(nn.Linear(1, 1))
217+
218+
210219
def test_tp_config_allows_untied_word_embeddings():
211220
model = create_gpt2_model(use_weight_tying=False)
212221
device_mesh = create_device_mesh_stub(ParallelismDegrees.TP.value)

0 commit comments

Comments
 (0)