Merged

Commits (26)
5f9f50e
fix: Initialize different weights across TP ranks
rrutmann Dec 8, 2025
8c8c5ab
feat: Consider pp rank for model seed
rrutmann Dec 9, 2025
ab3daa0
fix: Only consider PP rank for seeding
rrutmann Dec 10, 2025
62a1743
test: Add test for different parameters on tp/pp ranks
rrutmann Dec 12, 2025
00a595b
test: Check for equal parameters across data parallel processes
rrutmann Dec 12, 2025
bf06da7
feat: Integrate seeding to model initialization
rrutmann Dec 19, 2025
b137701
refactor: Move seeding logic to model initialization component
rrutmann Dec 19, 2025
bff99f3
chore: Add seed and device_mesh to ComposedModelInitializationConfig
rrutmann Dec 19, 2025
98ff9db
test: Adapt test to latest changes
rrutmann Dec 19, 2025
2e248ed
chore: Remove old code
rrutmann Dec 19, 2025
093fa33
chore: Merge branch 'main' into seed
rrutmann May 4, 2026
5a9e89e
fix: Use local-generator weight init
rrutmann May 5, 2026
13e7a82
refactor: Do not set seed in NNModel
rrutmann May 5, 2026
dc11bbb
docs: Add documentation and warning for topology-dependent weight ini…
rrutmann May 5, 2026
999cb65
fix: Fix transformers version mismatch
rrutmann May 5, 2026
b02275f
test: Fix test by removing dependency on global RNG state for seed=None
rrutmann May 5, 2026
ddfbe47
test: Adapt test to latest changes in main
rrutmann May 5, 2026
76762d9
chore: Use consistent typing for optional parameters
rrutmann May 5, 2026
dea2eef
chore: Remove outdated seed parameter
rrutmann May 5, 2026
adf11f0
fix: Use correct type for parameter_name_regexes
rrutmann May 7, 2026
4cf0032
test: Add option for reliable vscode debugging
rrutmann May 7, 2026
7541df2
test: Add test for seeded model reproducibility
rrutmann May 7, 2026
ede150e
chore: Change order of model initialization
rrutmann May 7, 2026
67bc596
feat: Add multi_device_generator_policy for handling seeding with mul…
rrutmann May 7, 2026
5172fc4
refactor: Use enum for multi_device_generator_policy
rrutmann May 8, 2026
326823e
chore: Update model seed initialization
rrutmann May 8, 2026
2 changes: 2 additions & 0 deletions docs/components/components.md
@@ -17,6 +17,8 @@
|---------------|--------------------|----------------|---------------|---------------------|-------------|
| model_initialization | composed | [ComposedInitializationRoutines.get_composed_model_initializer](../../src/modalities/nn/model_initialization/composed_initialization.py)| [ComposedModelInitializationConfig](../../src/modalities/nn/model_initialization/composed_initialization.py) | [ModelInitializationIF](../../src/modalities/nn/model_initialization/initialization_if.py) | Component for initializing model weights in place |

The composed initializer supports seeded weight initialization for reproducibility within a fixed topology. When pipeline parallelism is active, Modalities offsets the initialization seed by pipeline stage rank to avoid identical stage-local weights. As a result, the same seed can produce different initialized weights for different pipeline-parallel topologies. For topology-independent reproducibility, create and reuse a distributed checkpoint directly after weight initialization.

## Losses

|Component type | Component Version | Implementation | Configuration | Component Interface | Description |
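To make the topology dependence concrete, here is a small, self-contained Python sketch (not Modalities code) of the seed-offset behaviour described above; the helper name `effective_seed` is purely illustrative.

```python
# Illustrative sketch only: the effective initialization seed is offset by the
# pipeline-parallel (PP) stage rank, so the per-stage seeds depend on the PP degree.
def effective_seed(base_seed: int, pp_rank: int) -> int:
    return base_seed + pp_rank

# With pp_degree=2 the stages draw from seeds [42, 43]; with pp_degree=4 they draw
# from [42, 43, 44, 45]. Layers that end up on a different stage therefore get
# different initial weights even though the configured seed (42) is unchanged.
print([effective_seed(42, rank) for rank in range(2)])  # [42, 43]
print([effective_seed(42, rank) for rank in range(4)])  # [42, 43, 44, 45]
```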
4 changes: 2 additions & 2 deletions src/modalities/config/config.py
@@ -7,8 +7,8 @@
from omegaconf import OmegaConf, Resolver
from pydantic import BaseModel, ConfigDict, Field, FilePath, PositiveInt, field_validator, model_validator
from torch.distributed.fsdp import ShardingStrategy
from transformers import GPT2TokenizerFast
from transformers.models.llama.tokenization_llama_fast import LlamaTokenizerFast
from transformers import GPT2Tokenizer as GPT2TokenizerFast
from transformers import LlamaTokenizer as LlamaTokenizerFast
from typing_extensions import deprecated

from modalities.config.lookup_enum import LookupEnum
9 changes: 8 additions & 1 deletion src/modalities/conversion/gpt2/modeling_gpt2.py
@@ -40,7 +40,14 @@
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from transformers.utils.generic import check_model_inputs

try:
from transformers.utils.generic import check_model_inputs
except ImportError:

def check_model_inputs(func: Callable) -> Callable:
Member:
Was this removed in transformers?
If it is part of a legacy API, I think we should also remove it on our end.
What do you think @BlueCrescent? I think you added it, right?

Collaborator (Author):
The function was removed in transformers version 5.2. In our pyproject.toml we specify the requirement "transformers>=4.57.4,<5.0.0", so I used an unsupported transformers version here. Should we remove it just to be on the safe side?

Member:
Yes, I think we should tackle transformers 5.0.0+ support soon anyway.

return func


from modalities.conversion.gpt2.configuration_gpt2 import GPT2Config

6 changes: 1 addition & 5 deletions src/modalities/models/gpt2/gpt2_model.py
@@ -342,7 +342,6 @@ class GPT2LLMConfig(BaseModel):
ffn_norm_config (LayerNormWrapperConfig): Config for normalization of the feed-forward network.
lm_head_norm_config (LayerNormWrapperConfig): Config for normalization of the language model head.
use_weight_tying (bool): Whether to use weight tying.
seed: Optional[int] = None: The random seed for reproducibility.
enforce_swiglu_hidden_dim_multiple_of (int): If specified, enforces the hidden dimension
in the SwiGLU layer to be a multiple of this value. Note that this is only relevant if the
activation_type is SwiGLU. Defaults to 256.
@@ -370,7 +369,6 @@
ffn_norm_config: LayerNormWrapperConfig
lm_head_norm_config: LayerNormWrapperConfig
use_weight_tying: bool
seed: Optional[int] = None
enforce_swiglu_hidden_dim_multiple_of: int = 256

@model_validator(mode="after")
@@ -837,7 +835,6 @@ def __init__(
ffn_norm_config: LayerNormWrapperConfig,
lm_head_norm_config: LayerNormWrapperConfig,
use_weight_tying: bool,
seed: Optional[int] = None,
enforce_swiglu_hidden_dim_multiple_of: int = 256,
):
"""
@@ -862,7 +859,6 @@
attention_norm_config (LayerNormWrapperConfig): Config for the attention normalization module.
ffn_norm_config (LayerNormWrapperConfig): Config for the feed-forward network normalization module.
lm_head_norm_config (LayerNormWrapperConfig): Config for the language model head normalization module.
seed (int, optional): The random seed. Defaults to None.
use_weight_tying (bool): Whether to use weight tying.
enforce_swiglu_hidden_dim_multiple_of (int): Enforces
the hidden dimension in the SwiGLU layer to be a multiple of this value.
@@ -873,7 +869,7 @@
"embedding": [".wte", ".wpe"],
"layernorm": [".attention_norm", ".ffn_norm", ".lm_head_norm"],
}
super().__init__(weight_decay_groups=weight_decay_groups, seed=seed)
super().__init__(weight_decay_groups=weight_decay_groups)
self.sample_key = sample_key
self.prediction_key = prediction_key
self.sequence_length = sequence_length
5 changes: 1 addition & 4 deletions src/modalities/models/model.py
@@ -26,16 +26,13 @@ class ActivationType(str, Enum):
class NNModel(nn.Module):
"""NNModel class to define a base model."""

def __init__(self, seed: int = None, weight_decay_groups: Optional[WeightDecayGroups] = None):
def __init__(self, weight_decay_groups: Optional[WeightDecayGroups] = None):
"""
Initializes an NNModel object.

Args:
seed (int, optional): The seed value for random number generation. Defaults to None.
weight_decay_groups (Optional[WeightDecayGroups], optional): The weight decay groups. Defaults to None.
"""
if seed is not None:
torch.manual_seed(seed)
self._weight_decay_groups = weight_decay_groups if weight_decay_groups is not None else {}
super(NNModel, self).__init__()

Expand Down
2 changes: 0 additions & 2 deletions src/modalities/models/model_factory.py
@@ -615,7 +615,6 @@ def get_gpt2_model(
lm_head_norm_config: LayerNormWrapperConfig,
use_weight_tying: bool,
use_meta_device: Optional[bool] = False,
seed: Optional[int] = None,
enforce_swiglu_hidden_dim_multiple_of: int = 256,
) -> GPT2LLM:
config = dict(
@@ -637,7 +636,6 @@
attention_norm_config=attention_norm_config,
ffn_norm_config=ffn_norm_config,
lm_head_norm_config=lm_head_norm_config,
seed=seed,
use_weight_tying=use_weight_tying,
enforce_swiglu_hidden_dim_multiple_of=enforce_swiglu_hidden_dim_multiple_of,
)
60 changes: 55 additions & 5 deletions src/modalities/nn/model_initialization/composed_initialization.py
@@ -1,17 +1,23 @@
from typing import Optional

import torch
import torch.nn as nn
from pydantic import BaseModel, ConfigDict, Field, model_validator
from torch.distributed.device_mesh import DeviceMesh
from typing_extensions import Annotated

from modalities.config.pydantic_if_types import PydanticModelInitializationIFType
from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticModelInitializationIFType
from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
from modalities.nn.model_initialization.initialization_routines import InitializationRoutines
from modalities.nn.model_initialization.parameter_name_filters import (
NAMED_PARAMETER_INIT_GROUPS,
SupportWeightInitModels,
WeightInitTypes,
)
from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank, has_parallelism_method
from modalities.utils.logger_utils import get_logger

logger = get_logger(__name__)


class ModelInitializerWrapperConfig(BaseModel):
@@ -30,6 +36,8 @@ class ComposedModelInitializationConfig(BaseModel):
std: Annotated[float, Field(strict=True, ge=0.0)] | str # can be float or "auto"
hidden_dim: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
num_layers: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
seed: int | None = None
device_mesh: Optional[PydanticDeviceMeshIFType] = None

# avoid warning about protected namespace 'model_', see
# https://docs.pydantic.dev/2.7/api/config/#pydantic.config.ConfigDict.protected_namespaces
@@ -87,6 +95,24 @@ def initialize_in_place(self, model: nn.Module):


class ComposedInitializationRoutines:
@staticmethod
def _warn_pp_topology_dependent_seed(device_mesh: Optional[DeviceMesh], seed: Optional[int]) -> None:
if seed is None or not has_parallelism_method(
device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP
):
return

if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0:
return

logger.warning(
"Seeded weight initialization is topology-dependent when pipeline parallelism is active. "
"Modalities offsets the initialization seed by PP rank to avoid identical stage-local weights, "
"so the same seed can produce different initialized weights for different PP configurations. "
"For topology-independent reproducibility, create and reuse a distributed checkpoint directly "
"after weight initialization."
)

@staticmethod
def get_model_initializer_wrapper(model_initializers: list[ModelInitializationIF]) -> ModelInitializationIF:
initializer_wrapper = ModelInitializerWrapper(model_initializers)
@@ -98,8 +124,10 @@ def get_composed_model_initializer(
weight_init_type: WeightInitTypes,
mean: float,
std: float | str,
hidden_dim: Optional[int] = None,
num_layers: int = None,
hidden_dim: int | None = None,
num_layers: int | None = None,
device_mesh: Optional[DeviceMesh] = None,
seed: int | None = None,
) -> ModelInitializationIF:
"""This initialization allows to intialize a model with plain, scaled or scaled_embed initialization.
Note that plain initialization is always performed in the beginning. In case of scaled_embed,
@@ -114,36 +142,58 @@
Defaults to None.
num_layers (int, optional): Number of layers in the model (required for scaled and scaled_embed only).
Defaults to None.
device_mesh (Optional[DeviceMesh], optional): Device mesh used for parallelization.
seed (Optional[int], optional): Seed for random initialization. Defaults to None. When pipeline
parallelism is active, the effective seed is offset by PP rank to avoid identical stage-local
initialization, so the same seed does not guarantee identical initialized weights across different
PP topologies.

Returns:
ModelInitializationIF: The Weight Initializer performing the initialization as specified.
"""
ComposedInitializationRoutines._warn_pp_topology_dependent_seed(device_mesh=device_mesh, seed=seed)

# Set different random seed for each PP rank to ensure diversity
if seed is not None and has_parallelism_method(
device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP
):
assert device_mesh is not None
seed += get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP)

model_initializers = []

# plain
plain_parameter_name_regexes = NAMED_PARAMETER_INIT_GROUPS[model_type][WeightInitTypes.PLAIN]
plain_init = InitializationRoutines.get_plain_initialization(
mean=mean, std=std, hidden_dim=hidden_dim, parameter_name_regexes=plain_parameter_name_regexes
mean=mean,
std=std,
hidden_dim=hidden_dim,
parameter_name_regexes=plain_parameter_name_regexes,
seed=seed,
)
working_std = plain_init.std
model_initializers.append(plain_init)

if weight_init_type in [WeightInitTypes.SCALED, WeightInitTypes.SCALED_EMBED]:
# scaled
assert num_layers is not None
scaled_parameter_name_regexes = NAMED_PARAMETER_INIT_GROUPS[model_type][WeightInitTypes.SCALED]
scaled_init = InitializationRoutines.get_scaled_initialization(
mean=mean,
std=working_std,
num_layers=num_layers,
parameter_name_regexes=scaled_parameter_name_regexes,
seed=seed,
)
model_initializers.append(scaled_init)

if weight_init_type == WeightInitTypes.SCALED_EMBED:
# scaled embed
scaled_embed_parameter_name_regexes = NAMED_PARAMETER_INIT_GROUPS[model_type][WeightInitTypes.SCALED_EMBED]
scaled_embed_init = InitializationRoutines.get_scaled_embed_initialization(
mean=mean, parameter_name_regexes=scaled_embed_parameter_name_regexes
mean=mean,
parameter_name_regexes=scaled_embed_parameter_name_regexes,
seed=seed,
)
model_initializers.append(scaled_embed_init)

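For readers wiring this up programmatically, a hedged usage sketch of the new signature follows; the enum member `SupportWeightInitModels.GPT2`, the concrete dimension values, and `model` (an already-instantiated `nn.Module`) are assumptions, not part of this PR.

```python
from modalities.nn.model_initialization.composed_initialization import ComposedInitializationRoutines
from modalities.nn.model_initialization.parameter_name_filters import SupportWeightInitModels, WeightInitTypes

# Sketch: build a seeded composed initializer and apply it in place.
initializer = ComposedInitializationRoutines.get_composed_model_initializer(
    model_type=SupportWeightInitModels.GPT2,  # assumed enum member name
    weight_init_type=WeightInitTypes.SCALED,
    mean=0.0,
    std="auto",          # std derived from hidden_dim as per https://arxiv.org/abs/2312.16903
    hidden_dim=768,      # example value
    num_layers=12,       # example value
    device_mesh=None,    # without a PP dimension in the mesh, the seed is used unmodified
    seed=42,
)
initializer.initialize_in_place(model)  # `model` is your already-constructed nn.Module
```

Without a device mesh (or without a PP dimension in it), the seed is used exactly as configured; with PP it is offset per stage, as the warning above explains.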
47 changes: 35 additions & 12 deletions src/modalities/nn/model_initialization/initialization_routines.py
@@ -1,7 +1,8 @@
import math
import re
from typing import Annotated, Optional
from typing import Annotated

import torch
import torch.nn as nn
from pydantic import BaseModel, Field, model_validator

@@ -13,7 +14,7 @@ class PlainInitializationConfig(BaseModel):
mean: float
std: Annotated[float, Field(strict=True, ge=0.0)] | str # can be float or "auto"
parameter_name_regexes: list[str] # here we filter for the parameter names, e.g., "c_proj.weight"
hidden_dim: Optional[int] = None
hidden_dim: int | None = None

@model_validator(mode="after")
def check_std_and_hidden_dim(self):
Expand All @@ -39,21 +40,32 @@ class ScaledEmbedInitializationConfig(BaseModel):


class NamedParameterwiseNormalInitialization(ModelInitializationIF):
def __init__(self, mean: float, std: float, parameter_name_regexes: RegexFilter):
def __init__(self, mean: float, std: float, parameter_name_regexes: RegexFilter, seed: int | None = None):
self.mean = mean
self.std = std
self.parameter_name_regexes = parameter_name_regexes
self.seed = torch.initial_seed() if seed is None else seed
self._generators: dict[str, torch.Generator] = {}

def _get_generator(self, parameter: torch.Tensor) -> torch.Generator:
Member:
A few things are not clear to me:

  1. Do we actually have the case where, in a single process, tensors are sitting on different GPUs?
  2. If 1. is the case, then we can end up with tensors that are initialized identically, since we create multiple generators from the same seed.

I'm not sure what the best way to solve this is... It also seems to me that the PyTorch API regarding Generators is rather limited.

Collaborator (Author):
Since we start a single process for each rank via torchrun, this shouldn't happen, right? Or am I missing something?

device_key = str(parameter.device)
generator = self._generators.get(device_key)
if generator is None:
generator = torch.Generator(device=parameter.device)
generator.manual_seed(self.seed)
self._generators[device_key] = generator
return generator

def initialize_in_place(self, model: nn.Module):
weight_regexes = self.parameter_name_regexes.weights
bias_regexes = self.parameter_name_regexes.biases
bias_regexes = self.parameter_name_regexes.biases or []
for parameter_name, p in model.named_parameters():
parameter_name = parameter_name.replace(
"_orig_mod.", ""
) # remove FQN modification from torch.compile if present
for weight_regex in weight_regexes:
if re.fullmatch(weight_regex, parameter_name):
nn.init.normal_(p, mean=self.mean, std=self.std)
nn.init.normal_(p, mean=self.mean, std=self.std, generator=self._get_generator(p))
for bias_regex in bias_regexes:
if re.fullmatch(bias_regex, parameter_name):
nn.init.zeros_(p)
@@ -62,7 +74,11 @@ def initialize_in_place(self, model: nn.Module):
class InitializationRoutines:
@staticmethod
def get_plain_initialization(
mean: float, std: float | str, parameter_name_regexes: list[str], hidden_dim: Optional[int] = None
mean: float,
std: float | str,
parameter_name_regexes: RegexFilter,
hidden_dim: int | None = None,
seed: int | None = None,
) -> NamedParameterwiseNormalInitialization:
"""Initializes the weights of a model by sampling from a normal distribution.
NOTE: This class supports the initialization of nn.Linear and nn.Embedding layers.
@@ -73,23 +89,26 @@ def get_plain_initialization(
std (float): standard deviation of the normal distribution. If set to "auto", appropiate
value selected as per plain initialization described in https://arxiv.org/abs/2312.16903
hidden_dim (Optional[int]): hidden dimension of the attention layer. Defaults to None.
parameter_name_regexes (list[str]): List of parameter name regexes to which the initialization
should be applied
seed (Optional[int]): Random seed for initialization. Defaults to None.
"""

# auto: choose std automatically
if std == "auto":
if hidden_dim is None:
raise ValueError("ERROR! weight_init.std = auto not implemented")
# as per https://arxiv.org/abs/2312.16903
std = math.sqrt(2 / (5 * hidden_dim))
assert isinstance(std, float)

initialization = NamedParameterwiseNormalInitialization(
mean=mean, std=std, parameter_name_regexes=parameter_name_regexes
mean=mean, std=std, parameter_name_regexes=parameter_name_regexes, seed=seed
)
return initialization

@staticmethod
def get_scaled_initialization(
mean: float, std: float, num_layers: int, parameter_name_regexes: list[str]
mean: float, std: float, num_layers: int, parameter_name_regexes: RegexFilter, seed: int | None = None
) -> ModelInitializationIF:
"""Implementation of scaled weight initialization. As defined in https://arxiv.org/abs/2312.16903

@@ -99,6 +118,7 @@ def get_scaled_initialization(
num_layers (int): Number of layers in the model which we use to downscale std with
parameter_name_regexes (list[str]): List of parameter name regexes to which the initialization
should be applied
seed (Optional[int]): Random seed for initialization. Defaults to None.

Returns:
WeightInitializationIF: Weight initialization object
@@ -107,25 +127,28 @@
scaled_std = std / math.sqrt(2 * num_layers)

initialization = NamedParameterwiseNormalInitialization(
mean=mean, std=scaled_std, parameter_name_regexes=parameter_name_regexes
mean=mean, std=scaled_std, parameter_name_regexes=parameter_name_regexes, seed=seed
)
return initialization

@staticmethod
def get_scaled_embed_initialization(mean: float, parameter_name_regexes: list[str]) -> ModelInitializationIF:
def get_scaled_embed_initialization(
mean: float, parameter_name_regexes: RegexFilter, seed: int | None = None
) -> ModelInitializationIF:
"""Implementation of scaled weight initialization for embeddings, see https://arxiv.org/abs/2312.16903
We fix the standard deviation to sqrt(0.4).

Args:
mean (float): Mean of the normal distribution
parameter_name_regexes (list[str], optional): List of parameter name regexes to which the initialization
should be applied Defaults to None.
seed (Optional[int]): Random seed for initialization. Defaults to None.

Returns:
WeightInitializationIF: Weight initialization object
"""
std = math.sqrt(0.4)
initialization = NamedParameterwiseNormalInitialization(
mean=mean, std=std, parameter_name_regexes=parameter_name_regexes
mean=mean, std=std, parameter_name_regexes=parameter_name_regexes, seed=seed
)
return initialization
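As a standalone illustration of the local-generator pattern adopted in `NamedParameterwiseNormalInitialization` (it requires a PyTorch version where `nn.init.normal_` accepts a `generator` argument, as used in this PR), the following sketch shows that the seeded draw is independent of the global RNG state:

```python
import torch
import torch.nn as nn

# Draw weights from a dedicated generator instead of the global RNG.
layer = nn.Linear(16, 16)
generator = torch.Generator(device=layer.weight.device)

torch.manual_seed(0)  # perturb the global RNG ...
generator.manual_seed(42)
nn.init.normal_(layer.weight, mean=0.0, std=0.02, generator=generator)
reference = layer.weight.detach().clone()

torch.manual_seed(123)  # ... differently the second time ...
generator.manual_seed(42)
nn.init.normal_(layer.weight, mean=0.0, std=0.02, generator=generator)

assert torch.equal(reference, layer.weight)  # ... yet the seeded result is identical
```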
@@ -204,7 +204,6 @@ model_raw:
component_key: model
variant_key: gpt2
config:
seed: 42
use_meta_device: true
use_weight_tying: false
sample_key: ${settings.referencing_keys.sample_key}
1 change: 0 additions & 1 deletion tests/end2end_tests/configs/gpt2_train_num_steps_7_pp.yaml
@@ -270,7 +270,6 @@ model_raw:
component_key: model
variant_key: gpt2
config:
seed: 42
use_meta_device: true
use_weight_tying: false
sample_key: ${settings.referencing_keys.sample_key}