Skip to content

Commit 49c185c

Browse files
committed
fix: fixed initialization of tied weights in Llama3Initializer
1 parent e88b8aa commit 49c185c

1 file changed

Lines changed: 17 additions & 2 deletions

File tree

src/modalities/models/gpt2/llama3_like_initialization.py

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
class Llama3InitializerConfig(BaseModel):
1616
num_layers: Annotated[int, Field(strict=True, gt=0)]
1717
n_embd: Annotated[int, Field(strict=True, gt=0)]
18-
use_weight_tying: bool
18+
use_weight_tying: bool = False
1919
depth_init: bool = True
2020

2121

@@ -89,7 +89,7 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_ty
8989
),
9090
}
9191
if not use_weight_tying:
92-
# lm head weights
92+
# lm head weights (separate output projection matrix)
9393
self.regex_to_init[r"transformer\.lm_head\.weight"] = (
9494
trunc_normal_,
9595
{
@@ -99,6 +99,21 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_ty
9999
"b": 3 / math.sqrt(n_embd),
100100
},
101101
)
102+
else:
103+
# With weight tying, transformer.wte.weight IS the output projection
104+
# (lm_head shares the same tensor), so it must be initialized with the
105+
# small output std (1/sqrt(n_embd)) instead of the embedding std of 1.
106+
# Otherwise the tied matrix produces logits that are ~sqrt(n_embd)x too
107+
# large at init, causing the initial loss/grad norm to explode.
108+
self.regex_to_init[r"transformer\.wte\.weight"] = (
109+
trunc_normal_,
110+
{
111+
"mean": 0.0,
112+
"std": 1 / math.sqrt(n_embd),
113+
"a": -3 / math.sqrt(n_embd),
114+
"b": 3 / math.sqrt(n_embd),
115+
},
116+
)
102117

103118
def initialize_in_place(self, model: nn.Module):
104119
self._init_by_fqn_regex(model, self.regex_to_init)

0 commit comments

Comments
 (0)