From 49c185c7dd9d72c86cbc2a58178c5101c393d177 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Sat, 20 Jun 2026 17:56:54 +0200 Subject: [PATCH 1/4] fix: fixed initialization of tied weights in Llama3Initializer --- .../models/gpt2/llama3_like_initialization.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/modalities/models/gpt2/llama3_like_initialization.py b/src/modalities/models/gpt2/llama3_like_initialization.py index ebbe7fd53..200feca60 100644 --- a/src/modalities/models/gpt2/llama3_like_initialization.py +++ b/src/modalities/models/gpt2/llama3_like_initialization.py @@ -15,7 +15,7 @@ class Llama3InitializerConfig(BaseModel): num_layers: Annotated[int, Field(strict=True, gt=0)] n_embd: Annotated[int, Field(strict=True, gt=0)] - use_weight_tying: bool + use_weight_tying: bool = False depth_init: bool = True @@ -89,7 +89,7 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_ty ), } if not use_weight_tying: - # lm head weights + # lm head weights (separate output projection matrix) self.regex_to_init[r"transformer\.lm_head\.weight"] = ( trunc_normal_, { @@ -99,6 +99,21 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_ty "b": 3 / math.sqrt(n_embd), }, ) + else: + # With weight tying, transformer.wte.weight IS the output projection + # (lm_head shares the same tensor), so it must be initialized with the + # small output std (1/sqrt(n_embd)) instead of the embedding std of 1. + # Otherwise the tied matrix produces logits that are ~sqrt(n_embd)x too + # large at init, causing the initial loss/grad norm to explode. + self.regex_to_init[r"transformer\.wte\.weight"] = ( + trunc_normal_, + { + "mean": 0.0, + "std": 1 / math.sqrt(n_embd), + "a": -3 / math.sqrt(n_embd), + "b": 3 / math.sqrt(n_embd), + }, + ) def initialize_in_place(self, model: nn.Module): self._init_by_fqn_regex(model, self.regex_to_init) From 0927a2f539e15b251640f05376d718d224b1d6fb Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Fri, 26 Jun 2026 14:51:21 +0200 Subject: [PATCH 2/4] chore: added weight tying tests and Llama3 initialization checks --- .../test_tensor_parallelism.py | 3 +- tests/test_weight_tying.py | 85 ++++++++++++++++--- 2 files changed, 74 insertions(+), 14 deletions(-) diff --git a/tests/fsdp2_parallelization/test_tensor_parallelism.py b/tests/fsdp2_parallelization/test_tensor_parallelism.py index d3ccd46c2..f5d33aae7 100644 --- a/tests/fsdp2_parallelization/test_tensor_parallelism.py +++ b/tests/fsdp2_parallelization/test_tensor_parallelism.py @@ -1,3 +1,4 @@ +import os from pathlib import Path from typing import Tuple @@ -27,7 +28,7 @@ def patch_config_file(original_config_path: Path, activation_type: str, tmp_dir: config_dict["model_raw"]["config"]["activation_type"] = activation_type - tmp_file_path = tmp_dir / original_config_path.name + tmp_file_path = tmp_dir / f"{activation_type}_{os.getpid()}_{original_config_path.name}" with tmp_file_path.open("w", encoding="utf-8") as f: yaml.safe_dump(config_dict, f) diff --git a/tests/test_weight_tying.py b/tests/test_weight_tying.py index 4eb81b1f3..a7c06cb3b 100644 --- a/tests/test_weight_tying.py +++ b/tests/test_weight_tying.py @@ -1,10 +1,13 @@ +import math + import pytest +import torch import torch.nn as nn from pydantic import ValidationError from torch.distributed.device_mesh import DeviceMesh from modalities.config.config import GPT2ModelTPConfig -from modalities.models.components.layer_norms import LayerNormConfig +from modalities.models.components.layer_norms import LayerNormConfig, PytorchRMSLayerNormConfig from modalities.models.gpt2.gpt2_model import ( GPT2LLM, AttentionConfig, @@ -13,6 +16,7 @@ LayerNormWrapperConfig, PositionTypes, ) +from modalities.models.gpt2.llama3_like_initialization import Llama3Initializer from modalities.models.model import ActivationType from modalities.models.parallelism.pipeline_parallelism_configs import StagedPipelineConfig from modalities.models.parallelism.stages_generator import GPT2LLMStagesGenerator @@ -27,7 +31,12 @@ def count_parameters(model: nn.Module) -> int: return sum(p.numel() for p in model.parameters()) -def create_gpt2_model(use_weight_tying: bool) -> GPT2LLM: +def create_gpt2_model( + use_weight_tying: bool, + activation_type: ActivationType = ActivationType.GELU, + bias: bool = True, + norm_type: LayerNorms = LayerNorms.layer_norm, +) -> GPT2LLM: vocab_size = VOCAB_SIZE n_embd = EMBEDDING_DIM sequence_length = 128 @@ -36,9 +45,7 @@ def create_gpt2_model(use_weight_tying: bool) -> GPT2LLM: n_head_kv = 2 ffn_hidden = 256 dropout = 0.1 - bias = True poe_type = PositionTypes.NOPE - activation_type = ActivationType.GELU attention_implementation = AttentionImplementation.PYTORCH_FLASH attention_config = AttentionConfig( qkv_transforms=[ @@ -53,15 +60,17 @@ def create_gpt2_model(use_weight_tying: bool) -> GPT2LLM: ) ] ) - attention_norm_config = LayerNormWrapperConfig( - norm_type=LayerNorms.layer_norm, config=LayerNormConfig(normalized_shape=n_embd) - ) - ffn_norm_config = LayerNormWrapperConfig( - norm_type=LayerNorms.layer_norm, config=LayerNormConfig(normalized_shape=n_embd) - ) - lm_head_norm_config = LayerNormWrapperConfig( - norm_type=LayerNorms.layer_norm, config=LayerNormConfig(normalized_shape=n_embd) - ) + + def _make_norm_config() -> LayerNormWrapperConfig: + if norm_type == LayerNorms.pytorch_rms_norm: + return LayerNormWrapperConfig( + norm_type=norm_type, config=PytorchRMSLayerNormConfig(normalized_shape=n_embd) + ) + return LayerNormWrapperConfig(norm_type=norm_type, config=LayerNormConfig(normalized_shape=n_embd)) + + attention_norm_config = _make_norm_config() + ffn_norm_config = _make_norm_config() + lm_head_norm_config = _make_norm_config() return GPT2LLM( sample_key="input_ids", @@ -140,6 +149,17 @@ def test_has_tied_word_embeddings_requires_model_capability(): has_tied_word_embeddings(nn.Linear(1, 1)) +@pytest.mark.parametrize("module_name", ["transformer", "wte", "lm_head"]) +def test_has_tied_word_embeddings_handles_pipeline_stage(module_name: str): + model = create_gpt2_model(use_weight_tying=True) + if module_name == "transformer": + del model.transformer + else: + del model.transformer[module_name] + + assert has_tied_word_embeddings(model) is False + + def test_tp_config_rejects_tied_word_embeddings(): model = create_gpt2_model(use_weight_tying=True) device_mesh = create_device_mesh_stub(ParallelismDegrees.TP.value) @@ -148,6 +168,45 @@ def test_tp_config_rejects_tied_word_embeddings(): GPT2ModelTPConfig(model=model, device_mesh=device_mesh) +@pytest.mark.parametrize("use_weight_tying", [True, False]) +def test_llama3_init_keeps_output_projection_small(use_weight_tying: bool): + """Regression test for the weight-tying init bug. + + With weight tying, ``transformer.wte.weight`` *is* the output projection + (``lm_head`` shares the same tensor), so it must be initialized with the small + output std ``1 / sqrt(n_embd)`` -- not the embedding std of 1. Otherwise the tied + matrix produces logits ~sqrt(n_embd)x too large at init and the loss/grad norm + explode (observed: initial loss ~1685 instead of ~ln(vocab_size)). + """ + n_embd = EMBEDDING_DIM + expected_output_std = 1 / math.sqrt(n_embd) + + # SwiGLU + RMSNorm + no bias so the Llama3Initializer's FQN regexes fully match + # the model and it rejects no parameters. + model = create_gpt2_model( + use_weight_tying=use_weight_tying, + activation_type=ActivationType.SWIGLU, + bias=False, + norm_type=LayerNorms.pytorch_rms_norm, + ) + initializer = Llama3Initializer(num_layers=2, n_embd=n_embd, depth_init=True, use_weight_tying=use_weight_tying) + # Mirror the production flow (model_factory applies the initializer under no_grad). + with torch.no_grad(): + initializer.initialize_in_place(model) + + # The logit-producing matrix must be small regardless of weight tying. + output_proj_std = model.transformer.lm_head.weight.detach().float().std().item() + assert output_proj_std == pytest.approx(expected_output_std, rel=0.15) + + if use_weight_tying: + # Tied: embedding and output projection are the same (small) tensor. + assert model.transformer.wte.weight is model.transformer.lm_head.weight + else: + # Untied: the embedding keeps the Llama3/TorchTitan std of 1. + embedding_std = model.transformer.wte.weight.detach().float().std().item() + assert embedding_std == pytest.approx(1.0, rel=0.15) + + def test_tp_config_allows_untied_word_embeddings(): model = create_gpt2_model(use_weight_tying=False) device_mesh = create_device_mesh_stub(ParallelismDegrees.TP.value) From 6d4d3ca89d3c2e111f86a5d82ee8b2bc681216a0 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Sun, 28 Jun 2026 00:06:02 +0200 Subject: [PATCH 3/4] chore: fix has_tied_word_embeddings for pipeline parallelism --- src/modalities/models/gpt2/gpt2_model.py | 10 +++++++--- tests/test_weight_tying.py | 10 +++++----- tutorials/instruction_tuning/experiments/.gitkeep | 0 3 files changed, 12 insertions(+), 8 deletions(-) create mode 100644 tutorials/instruction_tuning/experiments/.gitkeep diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index eb8db53c2..5ebda1502 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -940,9 +940,13 @@ def __init__( @property def has_tied_word_embeddings(self) -> bool: - token_embedding_weight = getattr(self.transformer.wte, "weight", None) - lm_head_weight = getattr(self.transformer.lm_head, "weight", None) - return token_embedding_weight is not None and token_embedding_weight is lm_head_weight + # In pipeline parallelism a stage's transformer may not contain the wte/lm_head submodules + # (e.g. a middle stage has neither). Such a stage has no tying to report, so return False when + # either submodule is absent. Whether tied embeddings are allowed at all (they are not, for PP) + # is enforced separately by the pipeline/TP config validators on the whole, unsplit model. + if "wte" not in self.transformer or "lm_head" not in self.transformer: + return False + return self.transformer.wte.weight is self.transformer.lm_head.weight @overload def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: diff --git a/tests/test_weight_tying.py b/tests/test_weight_tying.py index a7c06cb3b..866196b02 100644 --- a/tests/test_weight_tying.py +++ b/tests/test_weight_tying.py @@ -149,13 +149,13 @@ def test_has_tied_word_embeddings_requires_model_capability(): has_tied_word_embeddings(nn.Linear(1, 1)) -@pytest.mark.parametrize("module_name", ["transformer", "wte", "lm_head"]) +@pytest.mark.parametrize("module_name", ["wte", "lm_head"]) def test_has_tied_word_embeddings_handles_pipeline_stage(module_name: str): + # In pipeline parallelism a stage's transformer ModuleDict only contains the submodules assigned + # to that stage (the transformer container itself is always present), so a stage may lack wte + # and/or lm_head. Such a stage has no tying to report and must not raise. model = create_gpt2_model(use_weight_tying=True) - if module_name == "transformer": - del model.transformer - else: - del model.transformer[module_name] + del model.transformer[module_name] assert has_tied_word_embeddings(model) is False diff --git a/tutorials/instruction_tuning/experiments/.gitkeep b/tutorials/instruction_tuning/experiments/.gitkeep new file mode 100644 index 000000000..e69de29bb From 392fe39fb69bdb5033d53172078722ebdc790896 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Sun, 28 Jun 2026 16:32:24 +0200 Subject: [PATCH 4/4] fix: update Llama3Initializer to infer weight tying from model and reject non-GPT2 models --- .../models/gpt2/llama3_like_initialization.py | 78 ++++++++++--------- tests/test_weight_tying.py | 11 ++- 2 files changed, 53 insertions(+), 36 deletions(-) diff --git a/src/modalities/models/gpt2/llama3_like_initialization.py b/src/modalities/models/gpt2/llama3_like_initialization.py index 200feca60..ea6870b16 100644 --- a/src/modalities/models/gpt2/llama3_like_initialization.py +++ b/src/modalities/models/gpt2/llama3_like_initialization.py @@ -6,6 +6,7 @@ import torch.nn as nn from pydantic import BaseModel, Field +from modalities.models.gpt2.gpt2_model import GPT2LLM from modalities.nn.model_initialization.initialization_if import ModelInitializationIF from modalities.utils.logger_utils import get_logger @@ -15,7 +16,6 @@ class Llama3InitializerConfig(BaseModel): num_layers: Annotated[int, Field(strict=True, gt=0)] n_embd: Annotated[int, Field(strict=True, gt=0)] - use_weight_tying: bool = False depth_init: bool = True @@ -24,7 +24,7 @@ class Llama3Initializer(ModelInitializationIF): Follows weight initialization distributions and parameterization for Llama3 as described in TorchTitan. """ - def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_tying: bool) -> None: + def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None: """ Initializes the Llama3Initializer. Args: @@ -35,11 +35,12 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_ty used for all layers baed on num_layers. """ super().__init__() + self.num_layers = num_layers + self.n_embd = n_embd self.depth_init = depth_init - self.regex_to_init = { - # embedding weights - r"transformer\.wte\.weight": (nn.init.normal_, {"mean": 0.0, "std": 1}), + def _build_regex_to_init(self, use_weight_tying: bool) -> dict[str, tuple[Callable, dict]]: + regex_to_init: dict[str, tuple[Callable, dict]] = { # qkv projections r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.weight": ( trunc_normal_, @@ -57,8 +58,8 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_ty "mean": 0.0, "std": ( (lambda layer_id: 0.02 / math.sqrt(2 * (layer_id + 1))) - if depth_init - else 0.02 / math.sqrt(2 * num_layers) + if self.depth_init + else 0.02 / math.sqrt(2 * self.num_layers) ), "a": -2, "b": 2, @@ -80,43 +81,50 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_ty "mean": 0.0, "std": ( (lambda layer_id: 0.02 / math.sqrt(2 * (layer_id + 1))) - if depth_init - else 0.02 / math.sqrt(2 * num_layers) + if self.depth_init + else 0.02 / math.sqrt(2 * self.num_layers) ), "a": -2, "b": 2, }, ), } - if not use_weight_tying: - # lm head weights (separate output projection matrix) - self.regex_to_init[r"transformer\.lm_head\.weight"] = ( - trunc_normal_, - { - "mean": 0.0, - "std": 1 / math.sqrt(n_embd), - "a": -3 / math.sqrt(n_embd), - "b": 3 / math.sqrt(n_embd), - }, - ) + + # Initialization of the output projection (the matrix that produces the logits): small std + # 1/sqrt(n_embd) so the logits are well-scaled at init. + output_projection_init = ( + trunc_normal_, + { + "mean": 0.0, + "std": 1 / math.sqrt(self.n_embd), + "a": -3 / math.sqrt(self.n_embd), + "b": 3 / math.sqrt(self.n_embd), + }, + ) + if use_weight_tying: + # With weight tying, transformer.wte.weight IS the output projection (lm_head shares the + # same tensor), so it must use the small output std instead of the embedding std of 1. + # Otherwise the tied matrix produces logits ~sqrt(n_embd)x too large at init, causing the + # initial loss/grad norm to explode. + regex_to_init[r"transformer\.wte\.weight"] = output_projection_init else: - # With weight tying, transformer.wte.weight IS the output projection - # (lm_head shares the same tensor), so it must be initialized with the - # small output std (1/sqrt(n_embd)) instead of the embedding std of 1. - # Otherwise the tied matrix produces logits that are ~sqrt(n_embd)x too - # large at init, causing the initial loss/grad norm to explode. - self.regex_to_init[r"transformer\.wte\.weight"] = ( - trunc_normal_, - { - "mean": 0.0, - "std": 1 / math.sqrt(n_embd), - "a": -3 / math.sqrt(n_embd), - "b": 3 / math.sqrt(n_embd), - }, - ) + # Untied: wte is the embedding (std=1) and lm_head is the separate output projection. + regex_to_init[r"transformer\.wte\.weight"] = (nn.init.normal_, {"mean": 0.0, "std": 1}) + regex_to_init[r"transformer\.lm_head\.weight"] = output_projection_init + return regex_to_init def initialize_in_place(self, model: nn.Module): - self._init_by_fqn_regex(model, self.regex_to_init) + # The FQN regexes are specific to GPT2LLM, which is also the single source of truth for whether + # the word embeddings are tied -- so we infer tying from the model rather than tracking a + # separate flag that could disagree with it (wrong-std tied output projection / uninitialized + # lm_head). Reject model types we cannot initialize. + if not isinstance(model, GPT2LLM): + raise TypeError( + f"Llama3Initializer only supports GPT2LLM (its FQN regexes are specific to it), " + f"but received {type(model).__name__}." + ) + regex_to_init = self._build_regex_to_init(use_weight_tying=model.has_tied_word_embeddings) + self._init_by_fqn_regex(model, regex_to_init) @staticmethod def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, tuple[Callable, dict]]): diff --git a/tests/test_weight_tying.py b/tests/test_weight_tying.py index 866196b02..885e8bb08 100644 --- a/tests/test_weight_tying.py +++ b/tests/test_weight_tying.py @@ -189,7 +189,8 @@ def test_llama3_init_keeps_output_projection_small(use_weight_tying: bool): bias=False, norm_type=LayerNorms.pytorch_rms_norm, ) - initializer = Llama3Initializer(num_layers=2, n_embd=n_embd, depth_init=True, use_weight_tying=use_weight_tying) + # The initializer infers weight tying from the model itself, so no tying flag is passed. + initializer = Llama3Initializer(num_layers=2, n_embd=n_embd, depth_init=True) # Mirror the production flow (model_factory applies the initializer under no_grad). with torch.no_grad(): initializer.initialize_in_place(model) @@ -207,6 +208,14 @@ def test_llama3_init_keeps_output_projection_small(use_weight_tying: bool): assert embedding_std == pytest.approx(1.0, rel=0.15) +def test_llama3_init_rejects_non_gpt2_model(): + # The FQN regexes are GPT2LLM-specific, so the initializer must reject other model types + # rather than silently leaving everything uninitialized. + initializer = Llama3Initializer(num_layers=2, n_embd=EMBEDDING_DIM, depth_init=True) + with pytest.raises(TypeError, match="only supports GPT2LLM"): + initializer.initialize_in_place(nn.Linear(1, 1)) + + def test_tp_config_allows_untied_word_embeddings(): model = create_gpt2_model(use_weight_tying=False) device_mesh = create_device_mesh_stub(ParallelismDegrees.TP.value)