66import torch .nn as nn
77from pydantic import BaseModel , Field
88
9+ from modalities .models .gpt2 .gpt2_model import GPT2LLM
910from modalities .nn .model_initialization .initialization_if import ModelInitializationIF
1011from modalities .utils .logger_utils import get_logger
1112
1516class Llama3InitializerConfig (BaseModel ):
1617 num_layers : Annotated [int , Field (strict = True , gt = 0 )]
1718 n_embd : Annotated [int , Field (strict = True , gt = 0 )]
18- use_weight_tying : bool = False
1919 depth_init : bool = True
2020
2121
@@ -24,7 +24,7 @@ class Llama3Initializer(ModelInitializationIF):
2424 Follows weight initialization distributions and parameterization for Llama3 as described in TorchTitan.
2525 """
2626
27- def __init__ (self , num_layers : int , n_embd : int , depth_init : bool , use_weight_tying : bool ) -> None :
27+ def __init__ (self , num_layers : int , n_embd : int , depth_init : bool ) -> None :
2828 """
2929 Initializes the Llama3Initializer.
3030 Args:
@@ -35,11 +35,12 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_ty
3535 used for all layers baed on num_layers.
3636 """
3737 super ().__init__ ()
38+ self .num_layers = num_layers
39+ self .n_embd = n_embd
3840 self .depth_init = depth_init
3941
40- self .regex_to_init = {
41- # embedding weights
42- r"transformer\.wte\.weight" : (nn .init .normal_ , {"mean" : 0.0 , "std" : 1 }),
42+ def _build_regex_to_init (self , use_weight_tying : bool ) -> dict [str , tuple [Callable , dict ]]:
43+ regex_to_init : dict [str , tuple [Callable , dict ]] = {
4344 # qkv projections
4445 r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.weight" : (
4546 trunc_normal_ ,
@@ -57,8 +58,8 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_ty
5758 "mean" : 0.0 ,
5859 "std" : (
5960 (lambda layer_id : 0.02 / math .sqrt (2 * (layer_id + 1 )))
60- if depth_init
61- else 0.02 / math .sqrt (2 * num_layers )
61+ if self . depth_init
62+ else 0.02 / math .sqrt (2 * self . num_layers )
6263 ),
6364 "a" : - 2 ,
6465 "b" : 2 ,
@@ -80,43 +81,50 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_ty
8081 "mean" : 0.0 ,
8182 "std" : (
8283 (lambda layer_id : 0.02 / math .sqrt (2 * (layer_id + 1 )))
83- if depth_init
84- else 0.02 / math .sqrt (2 * num_layers )
84+ if self . depth_init
85+ else 0.02 / math .sqrt (2 * self . num_layers )
8586 ),
8687 "a" : - 2 ,
8788 "b" : 2 ,
8889 },
8990 ),
9091 }
91- if not use_weight_tying :
92- # lm head weights (separate output projection matrix)
93- self .regex_to_init [r"transformer\.lm_head\.weight" ] = (
94- trunc_normal_ ,
95- {
96- "mean" : 0.0 ,
97- "std" : 1 / math .sqrt (n_embd ),
98- "a" : - 3 / math .sqrt (n_embd ),
99- "b" : 3 / math .sqrt (n_embd ),
100- },
101- )
92+
93+ # Initialization of the output projection (the matrix that produces the logits): small std
94+ # 1/sqrt(n_embd) so the logits are well-scaled at init.
95+ output_projection_init = (
96+ trunc_normal_ ,
97+ {
98+ "mean" : 0.0 ,
99+ "std" : 1 / math .sqrt (self .n_embd ),
100+ "a" : - 3 / math .sqrt (self .n_embd ),
101+ "b" : 3 / math .sqrt (self .n_embd ),
102+ },
103+ )
104+ if use_weight_tying :
105+ # With weight tying, transformer.wte.weight IS the output projection (lm_head shares the
106+ # same tensor), so it must use the small output std instead of the embedding std of 1.
107+ # Otherwise the tied matrix produces logits ~sqrt(n_embd)x too large at init, causing the
108+ # initial loss/grad norm to explode.
109+ regex_to_init [r"transformer\.wte\.weight" ] = output_projection_init
102110 else :
103- # With weight tying, transformer.wte.weight IS the output projection
104- # (lm_head shares the same tensor), so it must be initialized with the
105- # small output std (1/sqrt(n_embd)) instead of the embedding std of 1.
106- # Otherwise the tied matrix produces logits that are ~sqrt(n_embd)x too
107- # large at init, causing the initial loss/grad norm to explode.
108- self .regex_to_init [r"transformer\.wte\.weight" ] = (
109- trunc_normal_ ,
110- {
111- "mean" : 0.0 ,
112- "std" : 1 / math .sqrt (n_embd ),
113- "a" : - 3 / math .sqrt (n_embd ),
114- "b" : 3 / math .sqrt (n_embd ),
115- },
116- )
111+ # Untied: wte is the embedding (std=1) and lm_head is the separate output projection.
112+ regex_to_init [r"transformer\.wte\.weight" ] = (nn .init .normal_ , {"mean" : 0.0 , "std" : 1 })
113+ regex_to_init [r"transformer\.lm_head\.weight" ] = output_projection_init
114+ return regex_to_init
117115
118116 def initialize_in_place (self , model : nn .Module ):
119- self ._init_by_fqn_regex (model , self .regex_to_init )
117+ # The FQN regexes are specific to GPT2LLM, which is also the single source of truth for whether
118+ # the word embeddings are tied -- so we infer tying from the model rather than tracking a
119+ # separate flag that could disagree with it (wrong-std tied output projection / uninitialized
120+ # lm_head). Reject model types we cannot initialize.
121+ if not isinstance (model , GPT2LLM ):
122+ raise TypeError (
123+ f"Llama3Initializer only supports GPT2LLM (its FQN regexes are specific to it), "
124+ f"but received { type (model ).__name__ } ."
125+ )
126+ regex_to_init = self ._build_regex_to_init (use_weight_tying = model .has_tied_word_embeddings )
127+ self ._init_by_fqn_regex (model , regex_to_init )
120128
121129 @staticmethod
122130 def _init_by_fqn_regex (model : nn .Module , regex_to_init : dict [str , tuple [Callable , dict ]]):
0 commit comments