Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/modalities/models/gpt2/gpt2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,9 +1123,13 @@ def __init__(

@property
def has_tied_word_embeddings(self) -> bool:
token_embedding_weight = getattr(self.transformer.wte, "weight", None)
lm_head_weight = getattr(self.transformer.lm_head, "weight", None)
return token_embedding_weight is not None and token_embedding_weight is lm_head_weight
# In pipeline parallelism a stage's transformer may not contain the wte/lm_head submodules
# (e.g. a middle stage has neither). Such a stage has no tying to report, so return False when
# either submodule is absent. Whether tied embeddings are allowed at all (they are not, for PP)
# is enforced separately by the pipeline/TP config validators on the whole, unsplit model.
if "wte" not in self.transformer or "lm_head" not in self.transformer:
return False
return self.transformer.wte.weight is self.transformer.lm_head.weight

@overload
def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
Expand Down
65 changes: 44 additions & 21 deletions src/modalities/models/gpt2/llama3_like_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch.nn as nn
from pydantic import BaseModel, Field

from modalities.models.gpt2.gpt2_model import GPT2LLM
from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
from modalities.utils.logger_utils import get_logger

Expand All @@ -15,7 +16,6 @@
class Llama3InitializerConfig(BaseModel):
num_layers: Annotated[int, Field(strict=True, gt=0)]
n_embd: Annotated[int, Field(strict=True, gt=0)]
use_weight_tying: bool
depth_init: bool = True


Expand All @@ -24,7 +24,7 @@ class Llama3Initializer(ModelInitializationIF):
Follows weight initialization distributions and parameterization for Llama3 as described in TorchTitan.
"""

def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_tying: bool) -> None:
def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None:
"""
Initializes the Llama3Initializer.
Args:
Expand All @@ -35,11 +35,12 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_ty
used for all layers baed on num_layers.
"""
super().__init__()
self.num_layers = num_layers
self.n_embd = n_embd
self.depth_init = depth_init

self.regex_to_init = {
# embedding weights
r"transformer\.wte\.weight": (nn.init.normal_, {"mean": 0.0, "std": 1}),
def _build_regex_to_init(self, use_weight_tying: bool) -> dict[str, tuple[Callable, dict]]:
regex_to_init: dict[str, tuple[Callable, dict]] = {
# qkv projections
r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.weight": (
trunc_normal_,
Expand All @@ -57,8 +58,8 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_ty
"mean": 0.0,
"std": (
(lambda layer_id: 0.02 / math.sqrt(2 * (layer_id + 1)))
if depth_init
else 0.02 / math.sqrt(2 * num_layers)
if self.depth_init
else 0.02 / math.sqrt(2 * self.num_layers)
),
"a": -2,
"b": 2,
Expand All @@ -80,28 +81,50 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_ty
"mean": 0.0,
"std": (
(lambda layer_id: 0.02 / math.sqrt(2 * (layer_id + 1)))
if depth_init
else 0.02 / math.sqrt(2 * num_layers)
if self.depth_init
else 0.02 / math.sqrt(2 * self.num_layers)
),
"a": -2,
"b": 2,
},
),
}
if not use_weight_tying:
# lm head weights
self.regex_to_init[r"transformer\.lm_head\.weight"] = (
trunc_normal_,
{
"mean": 0.0,
"std": 1 / math.sqrt(n_embd),
"a": -3 / math.sqrt(n_embd),
"b": 3 / math.sqrt(n_embd),
},
)

# Initialization of the output projection (the matrix that produces the logits): small std
# 1/sqrt(n_embd) so the logits are well-scaled at init.
output_projection_init = (
trunc_normal_,
{
"mean": 0.0,
"std": 1 / math.sqrt(self.n_embd),
"a": -3 / math.sqrt(self.n_embd),
"b": 3 / math.sqrt(self.n_embd),
},
)
if use_weight_tying:
# With weight tying, transformer.wte.weight IS the output projection (lm_head shares the
# same tensor), so it must use the small output std instead of the embedding std of 1.
# Otherwise the tied matrix produces logits ~sqrt(n_embd)x too large at init, causing the
# initial loss/grad norm to explode.
regex_to_init[r"transformer\.wte\.weight"] = output_projection_init
else:
# Untied: wte is the embedding (std=1) and lm_head is the separate output projection.
regex_to_init[r"transformer\.wte\.weight"] = (nn.init.normal_, {"mean": 0.0, "std": 1})
regex_to_init[r"transformer\.lm_head\.weight"] = output_projection_init
return regex_to_init

def initialize_in_place(self, model: nn.Module):
self._init_by_fqn_regex(model, self.regex_to_init)
# The FQN regexes are specific to GPT2LLM, which is also the single source of truth for whether
# the word embeddings are tied -- so we infer tying from the model rather than tracking a
# separate flag that could disagree with it (wrong-std tied output projection / uninitialized
# lm_head). Reject model types we cannot initialize.
if not isinstance(model, GPT2LLM):
raise TypeError(
f"Llama3Initializer only supports GPT2LLM (its FQN regexes are specific to it), "
f"but received {type(model).__name__}."
)
regex_to_init = self._build_regex_to_init(use_weight_tying=model.has_tied_word_embeddings)
self._init_by_fqn_regex(model, regex_to_init)

@staticmethod
def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, tuple[Callable, dict]]):
Expand Down
94 changes: 81 additions & 13 deletions tests/test_weight_tying.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import math

import pytest
import torch
import torch.nn as nn
from pydantic import ValidationError
from torch.distributed.device_mesh import DeviceMesh

from modalities.config.config import GPT2ModelTPConfig
from modalities.models.components.layer_norms import LayerNormConfig
from modalities.models.components.layer_norms import LayerNormConfig, PytorchRMSLayerNormConfig
from modalities.models.gpt2.gpt2_model import (
GPT2LLM,
AttentionConfig,
Expand All @@ -13,6 +16,7 @@
LayerNormWrapperConfig,
PositionTypes,
)
from modalities.models.gpt2.llama3_like_initialization import Llama3Initializer
from modalities.models.model import ActivationType
from modalities.models.parallelism.pipeline_parallelism_configs import StagedPipelineConfig
from modalities.models.parallelism.stages_generator import GPT2LLMStagesGenerator
Expand All @@ -27,7 +31,12 @@ def count_parameters(model: nn.Module) -> int:
return sum(p.numel() for p in model.parameters())


def create_gpt2_model(use_weight_tying: bool) -> GPT2LLM:
def create_gpt2_model(
use_weight_tying: bool,
activation_type: ActivationType = ActivationType.GELU,
bias: bool = True,
norm_type: LayerNorms = LayerNorms.layer_norm,
) -> GPT2LLM:
vocab_size = VOCAB_SIZE
n_embd = EMBEDDING_DIM
sequence_length = 128
Expand All @@ -36,9 +45,7 @@ def create_gpt2_model(use_weight_tying: bool) -> GPT2LLM:
n_head_kv = 2
ffn_hidden = 256
dropout = 0.1
bias = True
poe_type = PositionTypes.NOPE
activation_type = ActivationType.GELU
attention_implementation = AttentionImplementation.PYTORCH_FLASH
attention_config = AttentionConfig(
qkv_transforms=[
Expand All @@ -53,15 +60,17 @@ def create_gpt2_model(use_weight_tying: bool) -> GPT2LLM:
)
]
)
attention_norm_config = LayerNormWrapperConfig(
norm_type=LayerNorms.layer_norm, config=LayerNormConfig(normalized_shape=n_embd)
)
ffn_norm_config = LayerNormWrapperConfig(
norm_type=LayerNorms.layer_norm, config=LayerNormConfig(normalized_shape=n_embd)
)
lm_head_norm_config = LayerNormWrapperConfig(
norm_type=LayerNorms.layer_norm, config=LayerNormConfig(normalized_shape=n_embd)
)

def _make_norm_config() -> LayerNormWrapperConfig:

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Low (cleanup) — test helper re-encodes a mapping the enum already owns.

_make_norm_config hand-maps pytorch_rms_norm -> PytorchRMSLayerNormConfig and everything else -> LayerNormConfig. The LayerNorms enum already pairs each norm type with its config class; a future test passing a third norm type (e.g. rms_norm) silently falls through to the wrong config and produces a confusing construction error rather than using the correct config class. Minor, test-only.

if norm_type == LayerNorms.pytorch_rms_norm:
return LayerNormWrapperConfig(
norm_type=norm_type, config=PytorchRMSLayerNormConfig(normalized_shape=n_embd)
)
return LayerNormWrapperConfig(norm_type=norm_type, config=LayerNormConfig(normalized_shape=n_embd))

attention_norm_config = _make_norm_config()
ffn_norm_config = _make_norm_config()
lm_head_norm_config = _make_norm_config()

return GPT2LLM(
sample_key="input_ids",
Expand Down Expand Up @@ -140,6 +149,17 @@ def test_has_tied_word_embeddings_requires_model_capability():
has_tied_word_embeddings(nn.Linear(1, 1))


@pytest.mark.parametrize("module_name", ["wte", "lm_head"])
def test_has_tied_word_embeddings_handles_pipeline_stage(module_name: str):
# In pipeline parallelism a stage's transformer ModuleDict only contains the submodules assigned
# to that stage (the transformer container itself is always present), so a stage may lack wte
# and/or lm_head. Such a stage has no tying to report and must not raise.
model = create_gpt2_model(use_weight_tying=True)
del model.transformer[module_name]

assert has_tied_word_embeddings(model) is False


def test_tp_config_rejects_tied_word_embeddings():
model = create_gpt2_model(use_weight_tying=True)
device_mesh = create_device_mesh_stub(ParallelismDegrees.TP.value)
Expand All @@ -148,6 +168,54 @@ def test_tp_config_rejects_tied_word_embeddings():
GPT2ModelTPConfig(model=model, device_mesh=device_mesh)


@pytest.mark.parametrize("use_weight_tying", [True, False])
def test_llama3_init_keeps_output_projection_small(use_weight_tying: bool):
"""Regression test for the weight-tying init bug.

With weight tying, ``transformer.wte.weight`` *is* the output projection
(``lm_head`` shares the same tensor), so it must be initialized with the small
output std ``1 / sqrt(n_embd)`` -- not the embedding std of 1. Otherwise the tied
matrix produces logits ~sqrt(n_embd)x too large at init and the loss/grad norm
explode (observed: initial loss ~1685 instead of ~ln(vocab_size)).
"""
n_embd = EMBEDDING_DIM
expected_output_std = 1 / math.sqrt(n_embd)

# SwiGLU + RMSNorm + no bias so the Llama3Initializer's FQN regexes fully match
# the model and it rejects no parameters.
model = create_gpt2_model(
use_weight_tying=use_weight_tying,
activation_type=ActivationType.SWIGLU,
bias=False,
norm_type=LayerNorms.pytorch_rms_norm,
)
# The initializer infers weight tying from the model itself, so no tying flag is passed.
initializer = Llama3Initializer(num_layers=2, n_embd=n_embd, depth_init=True)
# Mirror the production flow (model_factory applies the initializer under no_grad).
with torch.no_grad():
initializer.initialize_in_place(model)

# The logit-producing matrix must be small regardless of weight tying.
output_proj_std = model.transformer.lm_head.weight.detach().float().std().item()
assert output_proj_std == pytest.approx(expected_output_std, rel=0.15)

if use_weight_tying:
# Tied: embedding and output projection are the same (small) tensor.
assert model.transformer.wte.weight is model.transformer.lm_head.weight
else:
# Untied: the embedding keeps the Llama3/TorchTitan std of 1.
embedding_std = model.transformer.wte.weight.detach().float().std().item()
assert embedding_std == pytest.approx(1.0, rel=0.15)


def test_llama3_init_rejects_non_gpt2_model():
# The FQN regexes are GPT2LLM-specific, so the initializer must reject other model types
# rather than silently leaving everything uninitialized.
initializer = Llama3Initializer(num_layers=2, n_embd=EMBEDDING_DIM, depth_init=True)
with pytest.raises(TypeError, match="only supports GPT2LLM"):
initializer.initialize_in_place(nn.Linear(1, 1))


def test_tp_config_allows_untied_word_embeddings():
model = create_gpt2_model(use_weight_tying=False)
device_mesh = create_device_mesh_stub(ParallelismDegrees.TP.value)
Expand Down
Empty file.
Loading