1515class Llama3InitializerConfig (BaseModel ):
1616 num_layers : Annotated [int , Field (strict = True , gt = 0 )]
1717 n_embd : Annotated [int , Field (strict = True , gt = 0 )]
18- use_weight_tying : bool
18+ use_weight_tying : bool = False
1919 depth_init : bool = True
2020
2121
@@ -89,7 +89,7 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_ty
8989 ),
9090 }
9191 if not use_weight_tying :
92- # lm head weights
92+ # lm head weights (separate output projection matrix)
9393 self .regex_to_init [r"transformer\.lm_head\.weight" ] = (
9494 trunc_normal_ ,
9595 {
@@ -99,6 +99,21 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_ty
9999 "b" : 3 / math .sqrt (n_embd ),
100100 },
101101 )
102+ else :
103+ # With weight tying, transformer.wte.weight IS the output projection
104+ # (lm_head shares the same tensor), so it must be initialized with the
105+ # small output std (1/sqrt(n_embd)) instead of the embedding std of 1.
106+ # Otherwise the tied matrix produces logits that are ~sqrt(n_embd)x too
107+ # large at init, causing the initial loss/grad norm to explode.
108+ self .regex_to_init [r"transformer\.wte\.weight" ] = (
109+ trunc_normal_ ,
110+ {
111+ "mean" : 0.0 ,
112+ "std" : 1 / math .sqrt (n_embd ),
113+ "a" : - 3 / math .sqrt (n_embd ),
114+ "b" : 3 / math .sqrt (n_embd ),
115+ },
116+ )
102117
103118 def initialize_in_place (self , model : nn .Module ):
104119 self ._init_by_fqn_regex (model , self .regex_to_init )
0 commit comments