Skip to content

Commit 0927a2f

Browse files
committed
chore: added weight tying tests and Llama3 initialization checks
1 parent 49c185c commit 0927a2f

2 files changed

Lines changed: 74 additions & 14 deletions

File tree

tests/fsdp2_parallelization/test_tensor_parallelism.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from pathlib import Path
23
from typing import Tuple
34

@@ -27,7 +28,7 @@ def patch_config_file(original_config_path: Path, activation_type: str, tmp_dir:
2728

2829
config_dict["model_raw"]["config"]["activation_type"] = activation_type
2930

30-
tmp_file_path = tmp_dir / original_config_path.name
31+
tmp_file_path = tmp_dir / f"{activation_type}_{os.getpid()}_{original_config_path.name}"
3132
with tmp_file_path.open("w", encoding="utf-8") as f:
3233
yaml.safe_dump(config_dict, f)
3334

tests/test_weight_tying.py

Lines changed: 72 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import math
2+
13
import pytest
4+
import torch
25
import torch.nn as nn
36
from pydantic import ValidationError
47
from torch.distributed.device_mesh import DeviceMesh
58

69
from modalities.config.config import GPT2ModelTPConfig
7-
from modalities.models.components.layer_norms import LayerNormConfig
10+
from modalities.models.components.layer_norms import LayerNormConfig, PytorchRMSLayerNormConfig
811
from modalities.models.gpt2.gpt2_model import (
912
GPT2LLM,
1013
AttentionConfig,
@@ -13,6 +16,7 @@
1316
LayerNormWrapperConfig,
1417
PositionTypes,
1518
)
19+
from modalities.models.gpt2.llama3_like_initialization import Llama3Initializer
1620
from modalities.models.model import ActivationType
1721
from modalities.models.parallelism.pipeline_parallelism_configs import StagedPipelineConfig
1822
from modalities.models.parallelism.stages_generator import GPT2LLMStagesGenerator
@@ -27,7 +31,12 @@ def count_parameters(model: nn.Module) -> int:
2731
return sum(p.numel() for p in model.parameters())
2832

2933

30-
def create_gpt2_model(use_weight_tying: bool) -> GPT2LLM:
34+
def create_gpt2_model(
35+
use_weight_tying: bool,
36+
activation_type: ActivationType = ActivationType.GELU,
37+
bias: bool = True,
38+
norm_type: LayerNorms = LayerNorms.layer_norm,
39+
) -> GPT2LLM:
3140
vocab_size = VOCAB_SIZE
3241
n_embd = EMBEDDING_DIM
3342
sequence_length = 128
@@ -36,9 +45,7 @@ def create_gpt2_model(use_weight_tying: bool) -> GPT2LLM:
3645
n_head_kv = 2
3746
ffn_hidden = 256
3847
dropout = 0.1
39-
bias = True
4048
poe_type = PositionTypes.NOPE
41-
activation_type = ActivationType.GELU
4249
attention_implementation = AttentionImplementation.PYTORCH_FLASH
4350
attention_config = AttentionConfig(
4451
qkv_transforms=[
@@ -53,15 +60,17 @@ def create_gpt2_model(use_weight_tying: bool) -> GPT2LLM:
5360
)
5461
]
5562
)
56-
attention_norm_config = LayerNormWrapperConfig(
57-
norm_type=LayerNorms.layer_norm, config=LayerNormConfig(normalized_shape=n_embd)
58-
)
59-
ffn_norm_config = LayerNormWrapperConfig(
60-
norm_type=LayerNorms.layer_norm, config=LayerNormConfig(normalized_shape=n_embd)
61-
)
62-
lm_head_norm_config = LayerNormWrapperConfig(
63-
norm_type=LayerNorms.layer_norm, config=LayerNormConfig(normalized_shape=n_embd)
64-
)
63+
64+
def _make_norm_config() -> LayerNormWrapperConfig:
65+
if norm_type == LayerNorms.pytorch_rms_norm:
66+
return LayerNormWrapperConfig(
67+
norm_type=norm_type, config=PytorchRMSLayerNormConfig(normalized_shape=n_embd)
68+
)
69+
return LayerNormWrapperConfig(norm_type=norm_type, config=LayerNormConfig(normalized_shape=n_embd))
70+
71+
attention_norm_config = _make_norm_config()
72+
ffn_norm_config = _make_norm_config()
73+
lm_head_norm_config = _make_norm_config()
6574

6675
return GPT2LLM(
6776
sample_key="input_ids",
@@ -140,6 +149,17 @@ def test_has_tied_word_embeddings_requires_model_capability():
140149
has_tied_word_embeddings(nn.Linear(1, 1))
141150

142151

152+
@pytest.mark.parametrize("module_name", ["transformer", "wte", "lm_head"])
def test_has_tied_word_embeddings_handles_pipeline_stage(module_name: str):
    """Check that a model with a missing module is reported as not tied.

    A weight-tied model is built and then either the whole ``transformer``
    container or one of its ``wte`` / ``lm_head`` entries is deleted.
    ``has_tied_word_embeddings`` must return ``False`` for the stripped model
    rather than raise.

    NOTE(review): presumably this mimics a pipeline-parallel stage that does
    not own the embedding and/or the output head -- confirm against the
    pipeline stage builder.
    """
    stage_model = create_gpt2_model(use_weight_tying=True)

    if module_name != "transformer":
        # Drop a single entry from the transformer container.
        del stage_model.transformer[module_name]
    else:
        # Drop the container itself.
        del stage_model.transformer

    assert has_tied_word_embeddings(stage_model) is False
161+
162+
143163
def test_tp_config_rejects_tied_word_embeddings():
144164
model = create_gpt2_model(use_weight_tying=True)
145165
device_mesh = create_device_mesh_stub(ParallelismDegrees.TP.value)
@@ -148,6 +168,45 @@ def test_tp_config_rejects_tied_word_embeddings():
148168
GPT2ModelTPConfig(model=model, device_mesh=device_mesh)
149169

150170

171+
@pytest.mark.parametrize("use_weight_tying", [True, False])
def test_llama3_init_keeps_output_projection_small(use_weight_tying: bool):
    """Regression test for the weight-tying init bug.

    With weight tying, ``transformer.wte.weight`` *is* the output projection
    (``lm_head`` shares the same tensor), so it must be initialized with the small
    output std ``1 / sqrt(n_embd)`` -- not the embedding std of 1. Otherwise the tied
    matrix produces logits ~sqrt(n_embd)x too large at init and the loss/grad norm
    explode (observed: initial loss ~1685 instead of ~ln(vocab_size)).

    The global torch RNG is seeded up front so that the sampled weights, and
    therefore the measured standard deviations, are identical on every run.
    (Assumes the initializer draws from torch's default generator -- TODO confirm.)
    """
    # Statistical assertions below depend on random draws; pin them down.
    torch.manual_seed(0)

    n_embd = EMBEDDING_DIM
    expected_output_std = 1 / math.sqrt(n_embd)

    # SwiGLU + RMSNorm + no bias so the Llama3Initializer's FQN regexes fully match
    # the model and it rejects no parameters.
    model = create_gpt2_model(
        use_weight_tying=use_weight_tying,
        activation_type=ActivationType.SWIGLU,
        bias=False,
        norm_type=LayerNorms.pytorch_rms_norm,
    )
    initializer = Llama3Initializer(num_layers=2, n_embd=n_embd, depth_init=True, use_weight_tying=use_weight_tying)
    # Mirror the production flow (model_factory applies the initializer under no_grad).
    with torch.no_grad():
        initializer.initialize_in_place(model)

    # The logit-producing matrix must be small regardless of weight tying.
    output_proj_std = model.transformer.lm_head.weight.detach().float().std().item()
    assert output_proj_std == pytest.approx(expected_output_std, rel=0.15)

    if use_weight_tying:
        # Tied: embedding and output projection are the same (small) tensor.
        assert model.transformer.wte.weight is model.transformer.lm_head.weight
    else:
        # Untied: the embedding keeps the Llama3/TorchTitan std of 1.
        embedding_std = model.transformer.wte.weight.detach().float().std().item()
        assert embedding_std == pytest.approx(1.0, rel=0.15)
208+
209+
151210
def test_tp_config_allows_untied_word_embeddings():
152211
model = create_gpt2_model(use_weight_tying=False)
153212
device_mesh = create_device_mesh_stub(ParallelismDegrees.TP.value)

0 commit comments

Comments
 (0)