diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index 701a23850..993221e2c 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -1123,9 +1123,13 @@ def __init__( @property def has_tied_word_embeddings(self) -> bool: - token_embedding_weight = getattr(self.transformer.wte, "weight", None) - lm_head_weight = getattr(self.transformer.lm_head, "weight", None) - return token_embedding_weight is not None and token_embedding_weight is lm_head_weight + # In pipeline parallelism a stage's transformer may not contain the wte/lm_head submodules + # (e.g. a middle stage has neither). Such a stage has no tying to report, so return False when + # either submodule is absent. Whether tied embeddings are allowed at all (they are not, for PP) + # is enforced separately by the pipeline/TP config validators on the whole, unsplit model. + if "wte" not in self.transformer or "lm_head" not in self.transformer: + return False + return self.transformer.wte.weight is self.transformer.lm_head.weight @overload def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: diff --git a/src/modalities/models/gpt2/llama3_like_initialization.py b/src/modalities/models/gpt2/llama3_like_initialization.py index ebbe7fd53..ea6870b16 100644 --- a/src/modalities/models/gpt2/llama3_like_initialization.py +++ b/src/modalities/models/gpt2/llama3_like_initialization.py @@ -6,6 +6,7 @@ import torch.nn as nn from pydantic import BaseModel, Field +from modalities.models.gpt2.gpt2_model import GPT2LLM from modalities.nn.model_initialization.initialization_if import ModelInitializationIF from modalities.utils.logger_utils import get_logger @@ -15,7 +16,6 @@ class Llama3InitializerConfig(BaseModel): num_layers: Annotated[int, Field(strict=True, gt=0)] n_embd: Annotated[int, Field(strict=True, gt=0)] - use_weight_tying: bool depth_init: bool = True @@ -24,7 +24,7 @@ class Llama3Initializer(ModelInitializationIF): Follows weight initialization distributions and parameterization for Llama3 as described in TorchTitan. """ - def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_tying: bool) -> None: + def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None: """ Initializes the Llama3Initializer. Args: @@ -35,11 +35,12 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_ty used for all layers baed on num_layers. """ super().__init__() + self.num_layers = num_layers + self.n_embd = n_embd self.depth_init = depth_init - self.regex_to_init = { - # embedding weights - r"transformer\.wte\.weight": (nn.init.normal_, {"mean": 0.0, "std": 1}), + def _build_regex_to_init(self, use_weight_tying: bool) -> dict[str, tuple[Callable, dict]]: + regex_to_init: dict[str, tuple[Callable, dict]] = { # qkv projections r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.weight": ( trunc_normal_, @@ -57,8 +58,8 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_ty "mean": 0.0, "std": ( (lambda layer_id: 0.02 / math.sqrt(2 * (layer_id + 1))) - if depth_init - else 0.02 / math.sqrt(2 * num_layers) + if self.depth_init + else 0.02 / math.sqrt(2 * self.num_layers) ), "a": -2, "b": 2, @@ -80,28 +81,50 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_ty "mean": 0.0, "std": ( (lambda layer_id: 0.02 / math.sqrt(2 * (layer_id + 1))) - if depth_init - else 0.02 / math.sqrt(2 * num_layers) + if self.depth_init + else 0.02 / math.sqrt(2 * self.num_layers) ), "a": -2, "b": 2, }, ), } - if not use_weight_tying: - # lm head weights - self.regex_to_init[r"transformer\.lm_head\.weight"] = ( - trunc_normal_, - { - "mean": 0.0, - "std": 1 / math.sqrt(n_embd), - "a": -3 / math.sqrt(n_embd), - "b": 3 / math.sqrt(n_embd), - }, - ) + + # Initialization of the output projection (the matrix that produces the logits): small std + # 1/sqrt(n_embd) so the logits are well-scaled at init. + output_projection_init = ( + trunc_normal_, + { + "mean": 0.0, + "std": 1 / math.sqrt(self.n_embd), + "a": -3 / math.sqrt(self.n_embd), + "b": 3 / math.sqrt(self.n_embd), + }, + ) + if use_weight_tying: + # With weight tying, transformer.wte.weight IS the output projection (lm_head shares the + # same tensor), so it must use the small output std instead of the embedding std of 1. + # Otherwise the tied matrix produces logits ~sqrt(n_embd)x too large at init, causing the + # initial loss/grad norm to explode. + regex_to_init[r"transformer\.wte\.weight"] = output_projection_init + else: + # Untied: wte is the embedding (std=1) and lm_head is the separate output projection. + regex_to_init[r"transformer\.wte\.weight"] = (nn.init.normal_, {"mean": 0.0, "std": 1}) + regex_to_init[r"transformer\.lm_head\.weight"] = output_projection_init + return regex_to_init def initialize_in_place(self, model: nn.Module): - self._init_by_fqn_regex(model, self.regex_to_init) + # The FQN regexes are specific to GPT2LLM, which is also the single source of truth for whether + # the word embeddings are tied -- so we infer tying from the model rather than tracking a + # separate flag that could disagree with it (wrong-std tied output projection / uninitialized + # lm_head). Reject model types we cannot initialize. + if not isinstance(model, GPT2LLM): + raise TypeError( + f"Llama3Initializer only supports GPT2LLM (its FQN regexes are specific to it), " + f"but received {type(model).__name__}." + ) + regex_to_init = self._build_regex_to_init(use_weight_tying=model.has_tied_word_embeddings) + self._init_by_fqn_regex(model, regex_to_init) @staticmethod def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, tuple[Callable, dict]]): diff --git a/tests/test_weight_tying.py b/tests/test_weight_tying.py index 4eb81b1f3..885e8bb08 100644 --- a/tests/test_weight_tying.py +++ b/tests/test_weight_tying.py @@ -1,10 +1,13 @@ +import math + import pytest +import torch import torch.nn as nn from pydantic import ValidationError from torch.distributed.device_mesh import DeviceMesh from modalities.config.config import GPT2ModelTPConfig -from modalities.models.components.layer_norms import LayerNormConfig +from modalities.models.components.layer_norms import LayerNormConfig, PytorchRMSLayerNormConfig from modalities.models.gpt2.gpt2_model import ( GPT2LLM, AttentionConfig, @@ -13,6 +16,7 @@ LayerNormWrapperConfig, PositionTypes, ) +from modalities.models.gpt2.llama3_like_initialization import Llama3Initializer from modalities.models.model import ActivationType from modalities.models.parallelism.pipeline_parallelism_configs import StagedPipelineConfig from modalities.models.parallelism.stages_generator import GPT2LLMStagesGenerator @@ -27,7 +31,12 @@ def count_parameters(model: nn.Module) -> int: return sum(p.numel() for p in model.parameters()) -def create_gpt2_model(use_weight_tying: bool) -> GPT2LLM: +def create_gpt2_model( + use_weight_tying: bool, + activation_type: ActivationType = ActivationType.GELU, + bias: bool = True, + norm_type: LayerNorms = LayerNorms.layer_norm, +) -> GPT2LLM: vocab_size = VOCAB_SIZE n_embd = EMBEDDING_DIM sequence_length = 128 @@ -36,9 +45,7 @@ def create_gpt2_model(use_weight_tying: bool) -> GPT2LLM: n_head_kv = 2 ffn_hidden = 256 dropout = 0.1 - bias = True poe_type = PositionTypes.NOPE - activation_type = ActivationType.GELU attention_implementation = AttentionImplementation.PYTORCH_FLASH attention_config = AttentionConfig( qkv_transforms=[ @@ -53,15 +60,17 @@ def create_gpt2_model(use_weight_tying: bool) -> GPT2LLM: ) ] ) - attention_norm_config = LayerNormWrapperConfig( - norm_type=LayerNorms.layer_norm, config=LayerNormConfig(normalized_shape=n_embd) - ) - ffn_norm_config = LayerNormWrapperConfig( - norm_type=LayerNorms.layer_norm, config=LayerNormConfig(normalized_shape=n_embd) - ) - lm_head_norm_config = LayerNormWrapperConfig( - norm_type=LayerNorms.layer_norm, config=LayerNormConfig(normalized_shape=n_embd) - ) + + def _make_norm_config() -> LayerNormWrapperConfig: + if norm_type == LayerNorms.pytorch_rms_norm: + return LayerNormWrapperConfig( + norm_type=norm_type, config=PytorchRMSLayerNormConfig(normalized_shape=n_embd) + ) + return LayerNormWrapperConfig(norm_type=norm_type, config=LayerNormConfig(normalized_shape=n_embd)) + + attention_norm_config = _make_norm_config() + ffn_norm_config = _make_norm_config() + lm_head_norm_config = _make_norm_config() return GPT2LLM( sample_key="input_ids", @@ -140,6 +149,17 @@ def test_has_tied_word_embeddings_requires_model_capability(): has_tied_word_embeddings(nn.Linear(1, 1)) +@pytest.mark.parametrize("module_name", ["wte", "lm_head"]) +def test_has_tied_word_embeddings_handles_pipeline_stage(module_name: str): + # In pipeline parallelism a stage's transformer ModuleDict only contains the submodules assigned + # to that stage (the transformer container itself is always present), so a stage may lack wte + # and/or lm_head. Such a stage has no tying to report and must not raise. + model = create_gpt2_model(use_weight_tying=True) + del model.transformer[module_name] + + assert has_tied_word_embeddings(model) is False + + def test_tp_config_rejects_tied_word_embeddings(): model = create_gpt2_model(use_weight_tying=True) device_mesh = create_device_mesh_stub(ParallelismDegrees.TP.value) @@ -148,6 +168,54 @@ def test_tp_config_rejects_tied_word_embeddings(): GPT2ModelTPConfig(model=model, device_mesh=device_mesh) +@pytest.mark.parametrize("use_weight_tying", [True, False]) +def test_llama3_init_keeps_output_projection_small(use_weight_tying: bool): + """Regression test for the weight-tying init bug. + + With weight tying, ``transformer.wte.weight`` *is* the output projection + (``lm_head`` shares the same tensor), so it must be initialized with the small + output std ``1 / sqrt(n_embd)`` -- not the embedding std of 1. Otherwise the tied + matrix produces logits ~sqrt(n_embd)x too large at init and the loss/grad norm + explode (observed: initial loss ~1685 instead of ~ln(vocab_size)). + """ + n_embd = EMBEDDING_DIM + expected_output_std = 1 / math.sqrt(n_embd) + + # SwiGLU + RMSNorm + no bias so the Llama3Initializer's FQN regexes fully match + # the model and it rejects no parameters. + model = create_gpt2_model( + use_weight_tying=use_weight_tying, + activation_type=ActivationType.SWIGLU, + bias=False, + norm_type=LayerNorms.pytorch_rms_norm, + ) + # The initializer infers weight tying from the model itself, so no tying flag is passed. + initializer = Llama3Initializer(num_layers=2, n_embd=n_embd, depth_init=True) + # Mirror the production flow (model_factory applies the initializer under no_grad). + with torch.no_grad(): + initializer.initialize_in_place(model) + + # The logit-producing matrix must be small regardless of weight tying. + output_proj_std = model.transformer.lm_head.weight.detach().float().std().item() + assert output_proj_std == pytest.approx(expected_output_std, rel=0.15) + + if use_weight_tying: + # Tied: embedding and output projection are the same (small) tensor. + assert model.transformer.wte.weight is model.transformer.lm_head.weight + else: + # Untied: the embedding keeps the Llama3/TorchTitan std of 1. + embedding_std = model.transformer.wte.weight.detach().float().std().item() + assert embedding_std == pytest.approx(1.0, rel=0.15) + + +def test_llama3_init_rejects_non_gpt2_model(): + # The FQN regexes are GPT2LLM-specific, so the initializer must reject other model types + # rather than silently leaving everything uninitialized. + initializer = Llama3Initializer(num_layers=2, n_embd=EMBEDDING_DIM, depth_init=True) + with pytest.raises(TypeError, match="only supports GPT2LLM"): + initializer.initialize_in_place(nn.Linear(1, 1)) + + def test_tp_config_allows_untied_word_embeddings(): model = create_gpt2_model(use_weight_tying=False) device_mesh = create_device_mesh_stub(ParallelismDegrees.TP.value) diff --git a/tutorials/instruction_tuning/experiments/.gitkeep b/tutorials/instruction_tuning/experiments/.gitkeep new file mode 100644 index 000000000..e69de29bb