1+ import math
2+
13import pytest
4+ import torch
25import torch .nn as nn
36from pydantic import ValidationError
47from torch .distributed .device_mesh import DeviceMesh
58
69from modalities .config .config import GPT2ModelTPConfig
7- from modalities .models .components .layer_norms import LayerNormConfig
10+ from modalities .models .components .layer_norms import LayerNormConfig , PytorchRMSLayerNormConfig
811from modalities .models .gpt2 .gpt2_model import (
912 GPT2LLM ,
1013 AttentionConfig ,
1316 LayerNormWrapperConfig ,
1417 PositionTypes ,
1518)
19+ from modalities .models .gpt2 .llama3_like_initialization import Llama3Initializer
1620from modalities .models .model import ActivationType
1721from modalities .models .parallelism .pipeline_parallelism_configs import StagedPipelineConfig
1822from modalities .models .parallelism .stages_generator import GPT2LLMStagesGenerator
@@ -27,7 +31,12 @@ def count_parameters(model: nn.Module) -> int:
2731 return sum (p .numel () for p in model .parameters ())
2832
2933
30- def create_gpt2_model (use_weight_tying : bool ) -> GPT2LLM :
34+ def create_gpt2_model (
35+ use_weight_tying : bool ,
36+ activation_type : ActivationType = ActivationType .GELU ,
37+ bias : bool = True ,
38+ norm_type : LayerNorms = LayerNorms .layer_norm ,
39+ ) -> GPT2LLM :
3140 vocab_size = VOCAB_SIZE
3241 n_embd = EMBEDDING_DIM
3342 sequence_length = 128
@@ -36,9 +45,7 @@ def create_gpt2_model(use_weight_tying: bool) -> GPT2LLM:
3645 n_head_kv = 2
3746 ffn_hidden = 256
3847 dropout = 0.1
39- bias = True
4048 poe_type = PositionTypes .NOPE
41- activation_type = ActivationType .GELU
4249 attention_implementation = AttentionImplementation .PYTORCH_FLASH
4350 attention_config = AttentionConfig (
4451 qkv_transforms = [
@@ -53,15 +60,17 @@ def create_gpt2_model(use_weight_tying: bool) -> GPT2LLM:
5360 )
5461 ]
5562 )
56- attention_norm_config = LayerNormWrapperConfig (
57- norm_type = LayerNorms .layer_norm , config = LayerNormConfig (normalized_shape = n_embd )
58- )
59- ffn_norm_config = LayerNormWrapperConfig (
60- norm_type = LayerNorms .layer_norm , config = LayerNormConfig (normalized_shape = n_embd )
61- )
62- lm_head_norm_config = LayerNormWrapperConfig (
63- norm_type = LayerNorms .layer_norm , config = LayerNormConfig (normalized_shape = n_embd )
64- )
63+
64+ def _make_norm_config () -> LayerNormWrapperConfig :
65+ if norm_type == LayerNorms .pytorch_rms_norm :
66+ return LayerNormWrapperConfig (
67+ norm_type = norm_type , config = PytorchRMSLayerNormConfig (normalized_shape = n_embd )
68+ )
69+ return LayerNormWrapperConfig (norm_type = norm_type , config = LayerNormConfig (normalized_shape = n_embd ))
70+
71+ attention_norm_config = _make_norm_config ()
72+ ffn_norm_config = _make_norm_config ()
73+ lm_head_norm_config = _make_norm_config ()
6574
6675 return GPT2LLM (
6776 sample_key = "input_ids" ,
@@ -140,6 +149,17 @@ def test_has_tied_word_embeddings_requires_model_capability():
140149 has_tied_word_embeddings (nn .Linear (1 , 1 ))
141150
142151
@pytest.mark.parametrize("module_name", ["wte", "lm_head"])
def test_has_tied_word_embeddings_handles_pipeline_stage(module_name: str):
    """Check that a model missing ``wte`` or ``lm_head`` reports no tying instead of raising.

    Under pipeline parallelism the transformer ModuleDict of a stage only holds the
    submodules assigned to that stage (the container itself is always present), so
    either the embedding or the output head may be absent. Such a stage has no
    tying to report.
    """
    stage_model = create_gpt2_model(use_weight_tying=True)

    # Emulate a pipeline stage by dropping one of the two modules involved in tying.
    del stage_model.transformer[module_name]

    is_tied = has_tied_word_embeddings(stage_model)
    assert is_tied is False
161+
162+
143163def test_tp_config_rejects_tied_word_embeddings ():
144164 model = create_gpt2_model (use_weight_tying = True )
145165 device_mesh = create_device_mesh_stub (ParallelismDegrees .TP .value )
@@ -148,6 +168,54 @@ def test_tp_config_rejects_tied_word_embeddings():
148168 GPT2ModelTPConfig (model = model , device_mesh = device_mesh )
149169
150170
@pytest.mark.parametrize("use_weight_tying", [True, False])
def test_llama3_init_keeps_output_projection_small(use_weight_tying: bool):
    """Guard against the weight-tying initialization regression.

    When weights are tied, ``transformer.wte.weight`` and ``lm_head`` share one
    tensor, i.e. the embedding matrix *is* the output projection. It therefore has
    to receive the small output std ``1 / sqrt(n_embd)`` rather than the embedding
    std of 1. If it does not, the logits at init are ~sqrt(n_embd)x too large and
    the loss/grad norm explode (observed: initial loss ~1685 instead of
    ~ln(vocab_size)).
    """

    def _empirical_std(weight: torch.Tensor) -> float:
        # Measure the std in float32 on a detached copy, as a plain Python float.
        return weight.detach().float().std().item()

    n_embd = EMBEDDING_DIM
    expected_output_std = 1 / math.sqrt(n_embd)

    # Use SwiGLU, RMSNorm and no bias so that the FQN regexes of the Llama3Initializer
    # cover the whole model and no parameter is rejected.
    model_kwargs = {
        "use_weight_tying": use_weight_tying,
        "activation_type": ActivationType.SWIGLU,
        "bias": False,
        "norm_type": LayerNorms.pytorch_rms_norm,
    }
    model = create_gpt2_model(**model_kwargs)

    # No tying flag is passed: the initializer infers weight tying from the model itself.
    initializer = Llama3Initializer(num_layers=2, n_embd=n_embd, depth_init=True)

    # Same as the production flow, where model_factory applies the initializer under no_grad.
    with torch.no_grad():
        initializer.initialize_in_place(model)

    # Tied or not, the matrix that produces the logits has to be small.
    lm_head_std = _empirical_std(model.transformer.lm_head.weight)
    assert lm_head_std == pytest.approx(expected_output_std, rel=0.15)

    if not use_weight_tying:
        # Untied: the embedding keeps the Llama3/TorchTitan std of 1.
        wte_std = _empirical_std(model.transformer.wte.weight)
        assert wte_std == pytest.approx(1.0, rel=0.15)
    else:
        # Tied: embedding and output projection are one and the same (small) tensor.
        assert model.transformer.wte.weight is model.transformer.lm_head.weight
209+
210+
def test_llama3_init_rejects_non_gpt2_model():
    """Check that the initializer raises ``TypeError`` for a model that is not a GPT2LLM.

    The FQN regexes of the initializer are GPT2LLM-specific, so any other model type
    must be rejected rather than silently left completely uninitialized.
    """
    unsupported_model = nn.Linear(1, 1)
    llama3_initializer = Llama3Initializer(num_layers=2, n_embd=EMBEDDING_DIM, depth_init=True)

    with pytest.raises(TypeError, match="only supports GPT2LLM"):
        llama3_initializer.initialize_in_place(unsupported_model)
217+
218+
151219def test_tp_config_allows_untied_word_embeddings ():
152220 model = create_gpt2_model (use_weight_tying = False )
153221 device_mesh = create_device_mesh_stub (ParallelismDegrees .TP .value )
0 commit comments