From 5f9f50e228584e6b355998ac03c4d33873a8b0ca Mon Sep 17 00:00:00 2001 From: rrutmann Date: Mon, 8 Dec 2025 14:18:31 +0000 Subject: [PATCH 01/25] fix: Initialize different weights across TP ranks --- src/modalities/models/gpt2/gpt2_model.py | 11 +++++++++++ src/modalities/models/model_factory.py | 2 ++ 2 files changed, 13 insertions(+) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index 0a846b38a..66e19376f 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -7,8 +7,10 @@ import torch import torch.nn as nn from pydantic import BaseModel, Field, model_validator, validator +from torch.distributed.device_mesh import DeviceMesh from modalities.config.lookup_enum import LookupEnum +from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType from modalities.config.utils import convert_base_model_config_to_dict from modalities.models.components.layer_norms import ( LayerNormConfig, @@ -17,6 +19,7 @@ RMSLayerNormConfig, ) from modalities.models.model import ActivationType, NNModel, SwiGLU +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank, has_parallelism_method from modalities.util import parse_enum_by_name try: @@ -367,6 +370,7 @@ class GPT2LLMConfig(BaseModel): use_weight_tying: bool seed: Optional[int] = None enforce_swiglu_hidden_dim_multiple_of: int = 256 + device_mesh: Optional[PydanticDeviceMeshIFType] = None @model_validator(mode="after") def check_divisibility(self) -> "GPT2LLMConfig": @@ -834,6 +838,7 @@ def __init__( use_weight_tying: bool, seed: Optional[int] = None, enforce_swiglu_hidden_dim_multiple_of: int = 256, + device_mesh: DeviceMesh | None = None, ): """ Initializes the GPT2LLM object. @@ -862,12 +867,18 @@ def __init__( enforce_swiglu_hidden_dim_multiple_of (int): Enforces the hidden dimension in the SwiGLU layer to be a multiple of this value. Note that this is only relevant if the activation_type is SwiGLU. Defaults to 256. + device_mesh (DeviceMesh | None): The device mesh for parallelism. Defaults to None. """ weight_decay_groups = { "linear": [".attn", ".mlp", ".lm_head.weight"], "embedding": [".wte", ".wpe"], "layernorm": [".attention_norm", ".ffn_norm", ".lm_head_norm"], } + # Set different random seed for each TP rank to ensure diversity + if seed is not None and has_parallelism_method( + device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP + ): + seed += get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP) super().__init__(weight_decay_groups=weight_decay_groups, seed=seed) self.sample_key = sample_key self.prediction_key = prediction_key diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py index 3acb17f95..c4c953eaf 100644 --- a/src/modalities/models/model_factory.py +++ b/src/modalities/models/model_factory.py @@ -578,6 +578,7 @@ def get_gpt2_model( use_meta_device: Optional[bool] = False, seed: Optional[int] = None, enforce_swiglu_hidden_dim_multiple_of: int = 256, + device_mesh: DeviceMesh | None = None, ) -> GPT2LLM: config = dict( sample_key=sample_key, @@ -601,6 +602,7 @@ def get_gpt2_model( seed=seed, use_weight_tying=use_weight_tying, enforce_swiglu_hidden_dim_multiple_of=enforce_swiglu_hidden_dim_multiple_of, + device_mesh=device_mesh, ) if use_meta_device and use_weight_tying: raise ValueError( From 8c8c5abb716e86bfba86429462152944a890864a Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 9 Dec 2025 09:36:00 +0000 Subject: [PATCH 02/25] feat: Consider pp rank for model seed --- src/modalities/models/gpt2/gpt2_model.py | 38 ++++++++++++++++++++---- 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index 66e19376f..dd5dbd3ec 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -19,7 +19,12 @@ RMSLayerNormConfig, ) from modalities.models.model import ActivationType, NNModel, SwiGLU -from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank, has_parallelism_method +from modalities.running_env.fsdp.device_mesh import ( + ParallelismDegrees, + get_parallel_degree, + get_parallel_rank, + has_parallelism_method, +) from modalities.util import parse_enum_by_name try: @@ -874,11 +879,9 @@ def __init__( "embedding": [".wte", ".wpe"], "layernorm": [".attention_norm", ".ffn_norm", ".lm_head_norm"], } - # Set different random seed for each TP rank to ensure diversity - if seed is not None and has_parallelism_method( - device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP - ): - seed += get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP) + # Set different random seed for each TP and PP rank to ensure diversity + if seed is not None and device_mesh is not None: + seed = _offset_seed_by_parallel_ranks(seed=seed, device_mesh=device_mesh) super().__init__(weight_decay_groups=weight_decay_groups, seed=seed) self.sample_key = sample_key self.prediction_key = prediction_key @@ -1069,3 +1072,26 @@ def manual_scaled_dot_product_attention( attn_weight = torch.softmax(attn_weight, dim=-1) attn_weight = torch.dropout(attn_weight, dropout_p, train=True) return attn_weight @ value + + +def _offset_seed_by_parallel_ranks(seed: int, device_mesh: DeviceMesh) -> int: + """ + Return a seed shifted by the TP/PP ranks so each TP/PP pair produces a distinct value. + """ + tp_rank = None + pp_rank = None + pp_degree = 1 + + if has_parallelism_method(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP): + tp_rank = get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP) + if has_parallelism_method(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP): + pp_rank = get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP) + pp_degree = get_parallel_degree(device_mesh=device_mesh, parallelism_methods=[ParallelismDegrees.PP]) + + if tp_rank is not None and pp_rank is not None: + return seed + tp_rank * pp_degree + pp_rank + if tp_rank is not None: + return seed + tp_rank + if pp_rank is not None: + return seed + pp_rank + return seed From ab3daa01a02adff9e20d2aed1b56b28845b1ec1c Mon Sep 17 00:00:00 2001 From: rrutmann Date: Wed, 10 Dec 2025 09:43:15 +0000 Subject: [PATCH 03/25] fix: Only consider PP rank for seeding --- src/modalities/models/gpt2/gpt2_model.py | 38 ++++-------------------- 1 file changed, 6 insertions(+), 32 deletions(-) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index dd5dbd3ec..c8f82ecf6 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -19,12 +19,7 @@ RMSLayerNormConfig, ) from modalities.models.model import ActivationType, NNModel, SwiGLU -from modalities.running_env.fsdp.device_mesh import ( - ParallelismDegrees, - get_parallel_degree, - get_parallel_rank, - has_parallelism_method, -) +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank, has_parallelism_method from modalities.util import parse_enum_by_name try: @@ -879,9 +874,11 @@ def __init__( "embedding": [".wte", ".wpe"], "layernorm": [".attention_norm", ".ffn_norm", ".lm_head_norm"], } - # Set different random seed for each TP and PP rank to ensure diversity - if seed is not None and device_mesh is not None: - seed = _offset_seed_by_parallel_ranks(seed=seed, device_mesh=device_mesh) + # Set different random seed for each PP rank to ensure diversity + if seed is not None and has_parallelism_method( + device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP + ): + seed += get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP) super().__init__(weight_decay_groups=weight_decay_groups, seed=seed) self.sample_key = sample_key self.prediction_key = prediction_key @@ -1072,26 +1069,3 @@ def manual_scaled_dot_product_attention( attn_weight = torch.softmax(attn_weight, dim=-1) attn_weight = torch.dropout(attn_weight, dropout_p, train=True) return attn_weight @ value - - -def _offset_seed_by_parallel_ranks(seed: int, device_mesh: DeviceMesh) -> int: - """ - Return a seed shifted by the TP/PP ranks so each TP/PP pair produces a distinct value. - """ - tp_rank = None - pp_rank = None - pp_degree = 1 - - if has_parallelism_method(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP): - tp_rank = get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP) - if has_parallelism_method(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP): - pp_rank = get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP) - pp_degree = get_parallel_degree(device_mesh=device_mesh, parallelism_methods=[ParallelismDegrees.PP]) - - if tp_rank is not None and pp_rank is not None: - return seed + tp_rank * pp_degree + pp_rank - if tp_rank is not None: - return seed + tp_rank - if pp_rank is not None: - return seed + pp_rank - return seed From 62a1743dcf6561f83c6ff5f545aa79a23bbfd2b0 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Fri, 12 Dec 2025 17:21:56 +0000 Subject: [PATCH 04/25] test: Add test for different parameters on tp/pp ranks --- .../test_parallel_seed_initialization.py | 169 ++++++++++++++++++ 1 file changed, 169 insertions(+) create mode 100644 tests/fsdp2_parallelization/test_parallel_seed_initialization.py diff --git a/tests/fsdp2_parallelization/test_parallel_seed_initialization.py b/tests/fsdp2_parallelization/test_parallel_seed_initialization.py new file mode 100644 index 000000000..58f0d10c5 --- /dev/null +++ b/tests/fsdp2_parallelization/test_parallel_seed_initialization.py @@ -0,0 +1,169 @@ +import logging +import multiprocessing as py_mp +import os +import re +import traceback +from pathlib import Path +from typing import Any + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import yaml +from pydantic import BaseModel + +from modalities.__main__ import Main +from modalities.batch import EvaluationResultBatch +from modalities.config.config import ProcessGroupBackendType +from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticFSDP2ModuleType +from modalities.logging_broker.messages import Message +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_device_mesh, get_parallel_rank +from tests.end2end_tests.custom_components import MultiProcessingCudaEnv +from tests.utility import monitor_child_processes + +working_dir = Path(os.path.dirname(__file__)) +tmp_folder = working_dir / "../tmp/fsdp2_warmstart_pp_tp" +working_dir = working_dir / "configs" + + +@pytest.mark.skipif( + torch.cuda.device_count() < 8, + reason="This e2e test requires 8 GPUs.", +) +class TestParallelSeedInitialization: + WORLD_SIZE = 8 + RDVZ_PORT = 24574 + + def test_parameters_follow_parallelism(self, tmp_path: Path): + manager = py_mp.Manager() + error_queue = manager.Queue() + proc_ctx = mp.spawn( + self._seed_distribution_impl_wrapper, + args=(self.WORLD_SIZE, tmp_path, error_queue), + nprocs=self.WORLD_SIZE, + join=False, + ) + monitor_child_processes(manager, error_queue, proc_ctx) + + def _seed_distribution_impl_wrapper(self, process_id: int, world_size: int, tmp_path: Path, error_queue: Any): + with MultiProcessingCudaEnv( + process_group_backend=ProcessGroupBackendType.nccl, + global_rank=process_id, + local_rank=process_id, + world_size=world_size, + rdvz_port=TestParallelSeedInitialization.RDVZ_PORT, + ): + try: + self._seed_distribution_impl(world_size=world_size, tmp_path=tmp_path) + except Exception as exc: + tb = traceback.format_exc() + logging.error(f"Process {process_id} (seed distribution test) encountered an error:\n{exc}") + logging.error(tb) + try: + error_queue.put((process_id, tb)) + except Exception: + logging.error("Failed to put exception info into error queue (seed distribution test).") + os._exit(1) + + def _seed_distribution_impl(self, world_size: int, tmp_path: Path): + device_mesh = get_device_mesh( + device_type="cuda", + data_parallel_replicate_degree=2, + data_parallel_shard_degree=1, + tensor_parallel_degree=2, + pipeline_parallel_degree=2, + context_parallel_degree=1, + enable_loss_parallel=False, + world_size=world_size, + ) + + # initialize components + class ComponentsInstantiationModel(BaseModel): + fsdp_model: PydanticFSDP2ModuleType + device_mesh: PydanticDeviceMeshIFType + + config_file_path = self._get_tmp_sharding_config_path(dp_degree=2, tp_degree=2, pp_degree=2, tmp_path=tmp_path) + main_obj = Main(config_file_path) + components = main_obj.build_components(components_model_type=ComponentsInstantiationModel) + model = components.fsdp_model + device_mesh = components.device_mesh + # get first transformer block's MLP weight parameter shards + block_key = next(iter(model.transformer.h.keys())) + block = model.transformer.h[block_key] + payload = { + "tensor_shard": block.mlp.W.weight.to_local().cpu(), + "tp_rank": get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP), + "pp_rank": get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP), + "dp_shard_rank": get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.DP_SHARD), + "block_key": block_key, + } + + gather_list: list[dict[str, Any]] | None = [None] * world_size if dist.get_rank() == 0 else None + dist.gather_object(payload, gather_list, dst=0) + + if dist.get_rank() == 0: + assert gather_list is not None + TestParallelSeedInitialization._assert_parameter_distribution(gather_list) + dist.barrier() + + @staticmethod + def _assert_parameter_distribution(records: list[dict[str, Any]]): + combos: dict[tuple[int, int], list[dict[str, Any]]] = {} + for record in records: + key = (record["pp_rank"], record["tp_rank"]) + combos.setdefault(key, []).append(record) + + expected_combo_count = 4 + assert ( + len(combos) == expected_combo_count + ), f"Expected {expected_combo_count} PP/TP combinations, got {len(combos)}" + + combo_tensors: dict[tuple[int, int], torch.Tensor] = {} + for (pp_rank, tp_rank), entries in combos.items(): + shards = sorted(entries, key=lambda e: e["dp_shard_rank"]) + combo_tensors[(pp_rank, tp_rank)] = torch.cat( + [e["tensor_shard"] for e in shards], + dim=0, + ) + + combo_items = list(combo_tensors.items()) + for idx, ((pp_rank, tp_rank), base_tensor) in enumerate(combo_items): + for other_key, other_tensor in combo_items[idx + 1 :]: + tensors_equal = torch.equal(base_tensor, other_tensor) + assert not tensors_equal, ( + "Distinct TP/PP combinations should initialize with different weights; " + f"found match between (PP={pp_rank}, TP={tp_rank}) and (PP={other_key[0]}, TP={other_key[1]})" + ) + + def _get_tmp_sharding_config_path(self, dp_degree: int, tp_degree: int, pp_degree: int, tmp_path: Path) -> Path: + temp_file_path = tmp_path / "pp_tp_sharding_config.yaml" + working_dir = Path(os.path.dirname(__file__)) + config_file_path = ( + working_dir / "pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml" + ) + + with open(config_file_path, "r") as file: + config_string = file.read() + config_dict = yaml.safe_load(config_string) + config_dict["device_mesh"]["config"]["data_parallel_shard_degree"] = dp_degree + config_dict["device_mesh"]["config"]["tensor_parallel_degree"] = tp_degree + config_dict["device_mesh"]["config"]["pipeline_parallel_degree"] = pp_degree + + # save to temporary file + with open(temp_file_path, "w") as file: + yaml.dump(config_dict, file) + + return temp_file_path + + +def _get_loss_scores(messages: list[Message[EvaluationResultBatch]], loss_key: str) -> list[float]: + return [message.payload.losses[loss_key].value.item() for message in messages] + + +def _extract_seen_steps_and_tokens(filename: str) -> tuple[int, int]: + pattern = r"seen_steps_(\d+)-seen_tokens_(\d+)" + match = re.search(pattern, filename) + if match is None: + raise ValueError(f"Filename '{filename}' does not match expected pattern '{pattern}'.") + return int(match.group(1)), int(match.group(2)) From 00a595bfcbf136f11eebc9a808bee09a304aaa51 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Fri, 12 Dec 2025 17:49:32 +0000 Subject: [PATCH 05/25] test: Check for equal parameters across data parallel processes --- .../test_parallel_seed_initialization.py | 33 +++++++++++-------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/tests/fsdp2_parallelization/test_parallel_seed_initialization.py b/tests/fsdp2_parallelization/test_parallel_seed_initialization.py index 58f0d10c5..45b79327f 100644 --- a/tests/fsdp2_parallelization/test_parallel_seed_initialization.py +++ b/tests/fsdp2_parallelization/test_parallel_seed_initialization.py @@ -12,13 +12,14 @@ import torch.multiprocessing as mp import yaml from pydantic import BaseModel +from torch.distributed._tensor.placement_types import Replicate from modalities.__main__ import Main from modalities.batch import EvaluationResultBatch from modalities.config.config import ProcessGroupBackendType from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticFSDP2ModuleType from modalities.logging_broker.messages import Message -from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_device_mesh, get_parallel_rank +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank from tests.end2end_tests.custom_components import MultiProcessingCudaEnv from tests.utility import monitor_child_processes @@ -67,17 +68,6 @@ def _seed_distribution_impl_wrapper(self, process_id: int, world_size: int, tmp_ os._exit(1) def _seed_distribution_impl(self, world_size: int, tmp_path: Path): - device_mesh = get_device_mesh( - device_type="cuda", - data_parallel_replicate_degree=2, - data_parallel_shard_degree=1, - tensor_parallel_degree=2, - pipeline_parallel_degree=2, - context_parallel_degree=1, - enable_loss_parallel=False, - world_size=world_size, - ) - # initialize components class ComponentsInstantiationModel(BaseModel): fsdp_model: PydanticFSDP2ModuleType @@ -88,10 +78,13 @@ class ComponentsInstantiationModel(BaseModel): components = main_obj.build_components(components_model_type=ComponentsInstantiationModel) model = components.fsdp_model device_mesh = components.device_mesh - # get first transformer block's MLP weight parameter shards + # for each pp stage get first transformer block's MLP weight parameter shards and full tensor block_key = next(iter(model.transformer.h.keys())) block = model.transformer.h[block_key] + placements = [Replicate()] * len(block.mlp.W.weight.device_mesh.mesh.shape) + full_weight = block.mlp.W.weight.redistribute(placements=placements).to_local().cpu() payload = { + "tensor_full": full_weight, "tensor_shard": block.mlp.W.weight.to_local().cpu(), "tp_rank": get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP), "pp_rank": get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP), @@ -121,12 +114,24 @@ def _assert_parameter_distribution(records: list[dict[str, Any]]): combo_tensors: dict[tuple[int, int], torch.Tensor] = {} for (pp_rank, tp_rank), entries in combos.items(): + # check that full tensors are the same across data parallel processes + reference = entries[0]["tensor_full"] + seen_dp_ranks: set[int] = set() + for entry in entries: + dp_rank = entry["dp_shard_rank"] + assert dp_rank not in seen_dp_ranks, f"Duplicate DP rank {dp_rank} for combo PP={pp_rank}, TP={tp_rank}" + seen_dp_ranks.add(dp_rank) + assert torch.equal(reference, entry["tensor_full"]), ( + "Tensors within the same TP/PP combo must be identical across DP ranks; " + f"mismatch at DP rank {dp_rank} for (PP={pp_rank}, TP={tp_rank})" + ) + # concatenate all shards for this pp/tp combo shards = sorted(entries, key=lambda e: e["dp_shard_rank"]) combo_tensors[(pp_rank, tp_rank)] = torch.cat( [e["tensor_shard"] for e in shards], dim=0, ) - + # check that tensor shards differ across different pp/tp combos combo_items = list(combo_tensors.items()) for idx, ((pp_rank, tp_rank), base_tensor) in enumerate(combo_items): for other_key, other_tensor in combo_items[idx + 1 :]: From bf06da7bf67997e20f72949cab849597ad0d7508 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Fri, 19 Dec 2025 13:26:49 +0000 Subject: [PATCH 06/25] feat: Integrate seeding to model initialization --- .../composed_initialization.py | 23 ++++++++++++++-- .../initialization_routines.py | 27 ++++++++++++++++--- 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/src/modalities/nn/model_initialization/composed_initialization.py b/src/modalities/nn/model_initialization/composed_initialization.py index 190311cb6..b1b976573 100644 --- a/src/modalities/nn/model_initialization/composed_initialization.py +++ b/src/modalities/nn/model_initialization/composed_initialization.py @@ -2,6 +2,7 @@ import torch.nn as nn from pydantic import BaseModel, ConfigDict, Field, model_validator +from torch.distributed.device_mesh import DeviceMesh from typing_extensions import Annotated from modalities.config.pydantic_if_types import PydanticModelInitializationIFType @@ -12,6 +13,7 @@ SupportWeightInitModels, WeightInitTypes, ) +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank, has_parallelism_method class ModelInitializerWrapperConfig(BaseModel): @@ -100,6 +102,8 @@ def get_composed_model_initializer( std: float | str, hidden_dim: Optional[int] = None, num_layers: int = None, + device_mesh: Optional[DeviceMesh] = None, + seed: Optional[int] = None, ) -> ModelInitializationIF: """This initialization allows to intialize a model with plain, scaled or scaled_embed initialization. Note that plain initialization is always performed in the beginning. In case of scaled_embed, @@ -114,16 +118,28 @@ def get_composed_model_initializer( Defaults to None. num_layers (int, optional): Number of layers in the model (required for scaled and scaled_embed only). Defaults to None. + device_mesh (Optional[DeviceMesh], optional): Device mesh used for parallelization. + seed (Optional[int], optional): Seed for random initialization. Defaults to None. Returns: ModelInitializationIF: The Weight Initializer performing the initialization as specified. """ + # Set different random seed for each PP rank to ensure diversity + if seed is not None and has_parallelism_method( + device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP + ): + seed += get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP) + model_initializers = [] # plain plain_parameter_name_regexes = NAMED_PARAMETER_INIT_GROUPS[model_type][WeightInitTypes.PLAIN] plain_init = InitializationRoutines.get_plain_initialization( - mean=mean, std=std, hidden_dim=hidden_dim, parameter_name_regexes=plain_parameter_name_regexes + mean=mean, + std=std, + hidden_dim=hidden_dim, + parameter_name_regexes=plain_parameter_name_regexes, + seed=seed, ) working_std = plain_init.std model_initializers.append(plain_init) @@ -136,6 +152,7 @@ def get_composed_model_initializer( std=working_std, num_layers=num_layers, parameter_name_regexes=scaled_parameter_name_regexes, + seed=seed, ) model_initializers.append(scaled_init) @@ -143,7 +160,9 @@ def get_composed_model_initializer( # scaled embed scaled_embed_parameter_name_regexes = NAMED_PARAMETER_INIT_GROUPS[model_type][WeightInitTypes.SCALED_EMBED] scaled_embed_init = InitializationRoutines.get_scaled_embed_initialization( - mean=mean, parameter_name_regexes=scaled_embed_parameter_name_regexes + mean=mean, + parameter_name_regexes=scaled_embed_parameter_name_regexes, + seed=seed, ) model_initializers.append(scaled_embed_init) diff --git a/src/modalities/nn/model_initialization/initialization_routines.py b/src/modalities/nn/model_initialization/initialization_routines.py index 5b4515875..36953d646 100644 --- a/src/modalities/nn/model_initialization/initialization_routines.py +++ b/src/modalities/nn/model_initialization/initialization_routines.py @@ -2,6 +2,7 @@ import re from typing import Annotated, Optional +import torch import torch.nn as nn from pydantic import BaseModel, Field, model_validator @@ -59,7 +60,11 @@ def initialize_in_place(self, model: nn.Module): class InitializationRoutines: @staticmethod def get_plain_initialization( - mean: float, std: float | str, parameter_name_regexes: list[str], hidden_dim: Optional[int] = None + mean: float, + std: float | str, + parameter_name_regexes: list[str], + hidden_dim: Optional[int] = None, + seed: Optional[int] = None, ) -> NamedParameterwiseNormalInitialization: """Initializes the weights of a model by sampling from a normal distribution. NOTE: This class supports the initialization of nn.Linear and nn.Embedding layers. @@ -70,8 +75,11 @@ def get_plain_initialization( std (float): standard deviation of the normal distribution. If set to "auto", appropiate value selected as per plain initialization described in https://arxiv.org/abs/2312.16903 hidden_dim (Optional[int]): hidden dimension of the attention layer. Defaults to None. + parameter_name_regexes (list[str]): List of parameter name regexes to which the initialization + should be applied + seed (Optional[int]): Random seed for initialization. Defaults to None. """ - + InitializationRoutines._set_seed(seed) # auto: choose std automatically if std == "auto": if hidden_dim is None: @@ -86,7 +94,7 @@ def get_plain_initialization( @staticmethod def get_scaled_initialization( - mean: float, std: float, num_layers: int, parameter_name_regexes: list[str] + mean: float, std: float, num_layers: int, parameter_name_regexes: list[str], seed: Optional[int] = None ) -> ModelInitializationIF: """Implementation of scaled weight initialization. As defined in https://arxiv.org/abs/2312.16903 @@ -96,10 +104,12 @@ def get_scaled_initialization( num_layers (int): Number of layers in the model which we use to downscale std with parameter_name_regexes (list[str]): List of parameter name regexes to which the initialization should be applied + seed (Optional[int]): Random seed for initialization. Defaults to None. Returns: WeightInitializationIF: Weight initialization object """ + InitializationRoutines._set_seed(seed) # see https://arxiv.org/abs/2312.16903 scaled_std = std / math.sqrt(2 * num_layers) @@ -109,7 +119,9 @@ def get_scaled_initialization( return initialization @staticmethod - def get_scaled_embed_initialization(mean: float, parameter_name_regexes: list[str]) -> ModelInitializationIF: + def get_scaled_embed_initialization( + mean: float, parameter_name_regexes: list[str], seed: Optional[int] = None + ) -> ModelInitializationIF: """Implementation of scaled weight initialization for embeddings, see https://arxiv.org/abs/2312.16903 We fix the standard deviation to sqrt(0.4). @@ -117,12 +129,19 @@ def get_scaled_embed_initialization(mean: float, parameter_name_regexes: list[st mean (float): Mean of the normal distribution parameter_name_regexes (list[str], optional): List of parameter name regexes to which the initialization should be applied Defaults to None. + seed (Optional[int]): Random seed for initialization. Defaults to None. Returns: WeightInitializationIF: Weight initialization object """ + InitializationRoutines._set_seed(seed) std = math.sqrt(0.4) initialization = NamedParameterwiseNormalInitialization( mean=mean, std=std, parameter_name_regexes=parameter_name_regexes ) return initialization + + @staticmethod + def _set_seed(seed: Optional[int]): + if seed is not None: + torch.manual_seed(seed) From b137701774c5e0bae687231865091a3c4a39b01d Mon Sep 17 00:00:00 2001 From: rrutmann Date: Fri, 19 Dec 2025 13:37:01 +0000 Subject: [PATCH 07/25] refactor: Move seeding logic to model initialization component --- src/modalities/models/gpt2/gpt2_model.py | 11 ----------- src/modalities/models/model.py | 2 +- src/modalities/models/model_factory.py | 2 -- 3 files changed, 1 insertion(+), 14 deletions(-) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index c8f82ecf6..0a846b38a 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -7,10 +7,8 @@ import torch import torch.nn as nn from pydantic import BaseModel, Field, model_validator, validator -from torch.distributed.device_mesh import DeviceMesh from modalities.config.lookup_enum import LookupEnum -from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType from modalities.config.utils import convert_base_model_config_to_dict from modalities.models.components.layer_norms import ( LayerNormConfig, @@ -19,7 +17,6 @@ RMSLayerNormConfig, ) from modalities.models.model import ActivationType, NNModel, SwiGLU -from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank, has_parallelism_method from modalities.util import parse_enum_by_name try: @@ -370,7 +367,6 @@ class GPT2LLMConfig(BaseModel): use_weight_tying: bool seed: Optional[int] = None enforce_swiglu_hidden_dim_multiple_of: int = 256 - device_mesh: Optional[PydanticDeviceMeshIFType] = None @model_validator(mode="after") def check_divisibility(self) -> "GPT2LLMConfig": @@ -838,7 +834,6 @@ def __init__( use_weight_tying: bool, seed: Optional[int] = None, enforce_swiglu_hidden_dim_multiple_of: int = 256, - device_mesh: DeviceMesh | None = None, ): """ Initializes the GPT2LLM object. @@ -867,18 +862,12 @@ def __init__( enforce_swiglu_hidden_dim_multiple_of (int): Enforces the hidden dimension in the SwiGLU layer to be a multiple of this value. Note that this is only relevant if the activation_type is SwiGLU. Defaults to 256. - device_mesh (DeviceMesh | None): The device mesh for parallelism. Defaults to None. """ weight_decay_groups = { "linear": [".attn", ".mlp", ".lm_head.weight"], "embedding": [".wte", ".wpe"], "layernorm": [".attention_norm", ".ffn_norm", ".lm_head_norm"], } - # Set different random seed for each PP rank to ensure diversity - if seed is not None and has_parallelism_method( - device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP - ): - seed += get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP) super().__init__(weight_decay_groups=weight_decay_groups, seed=seed) self.sample_key = sample_key self.prediction_key = prediction_key diff --git a/src/modalities/models/model.py b/src/modalities/models/model.py index ac3dca96b..5dc7986b2 100644 --- a/src/modalities/models/model.py +++ b/src/modalities/models/model.py @@ -26,7 +26,7 @@ class ActivationType(str, Enum): class NNModel(nn.Module): """NNModel class to define a base model.""" - def __init__(self, seed: int = None, weight_decay_groups: Optional[WeightDecayGroups] = None): + def __init__(self, seed: Optional[int] = None, weight_decay_groups: Optional[WeightDecayGroups] = None): """ Initializes an NNModel object. diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py index c4c953eaf..3acb17f95 100644 --- a/src/modalities/models/model_factory.py +++ b/src/modalities/models/model_factory.py @@ -578,7 +578,6 @@ def get_gpt2_model( use_meta_device: Optional[bool] = False, seed: Optional[int] = None, enforce_swiglu_hidden_dim_multiple_of: int = 256, - device_mesh: DeviceMesh | None = None, ) -> GPT2LLM: config = dict( sample_key=sample_key, @@ -602,7 +601,6 @@ def get_gpt2_model( seed=seed, use_weight_tying=use_weight_tying, enforce_swiglu_hidden_dim_multiple_of=enforce_swiglu_hidden_dim_multiple_of, - device_mesh=device_mesh, ) if use_meta_device and use_weight_tying: raise ValueError( From bff99f3cb7880e34df303f8764ea8dcad1aaa49b Mon Sep 17 00:00:00 2001 From: rrutmann Date: Fri, 19 Dec 2025 14:01:10 +0000 Subject: [PATCH 08/25] chore: Add seed and device_mesh to ComposedModelInitializationConfig --- .../nn/model_initialization/composed_initialization.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/modalities/nn/model_initialization/composed_initialization.py b/src/modalities/nn/model_initialization/composed_initialization.py index b1b976573..1789011f1 100644 --- a/src/modalities/nn/model_initialization/composed_initialization.py +++ b/src/modalities/nn/model_initialization/composed_initialization.py @@ -5,7 +5,7 @@ from torch.distributed.device_mesh import DeviceMesh from typing_extensions import Annotated -from modalities.config.pydantic_if_types import PydanticModelInitializationIFType +from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticModelInitializationIFType from modalities.nn.model_initialization.initialization_if import ModelInitializationIF from modalities.nn.model_initialization.initialization_routines import InitializationRoutines from modalities.nn.model_initialization.parameter_name_filters import ( @@ -32,6 +32,8 @@ class ComposedModelInitializationConfig(BaseModel): std: Annotated[float, Field(strict=True, ge=0.0)] | str # can be float or "auto" hidden_dim: Optional[Annotated[int, Field(strict=True, gt=0)]] = None num_layers: Optional[Annotated[int, Field(strict=True, gt=0)]] = None + seed: Optional[int] = None + device_mesh: Optional[PydanticDeviceMeshIFType] = None # avoid warning about protected namespace 'model_', see # https://docs.pydantic.dev/2.7/api/config/#pydantic.config.ConfigDict.protected_namespaces From 98ff9db1479c2da02c826bba8acf40946f677c62 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Fri, 19 Dec 2025 14:09:18 +0000 Subject: [PATCH 09/25] test: Adapt test to latest changes --- .../config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml index fb8ee5f7d..8fe1d5472 100644 --- a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml +++ b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml @@ -129,7 +129,11 @@ initialized_model: weight_init_type: scaled mean: 0.0 std: 0.02 + seed: 42 num_layers: ${model_raw.config.n_layer} + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE model_raw: component_key: model From 2e248ed2477bfa64a3e91045da222a0cc4c86f35 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Fri, 19 Dec 2025 14:22:47 +0000 Subject: [PATCH 10/25] chore: Remove old code --- .../test_parallel_seed_initialization.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/tests/fsdp2_parallelization/test_parallel_seed_initialization.py b/tests/fsdp2_parallelization/test_parallel_seed_initialization.py index 45b79327f..b9bb2f7ca 100644 --- a/tests/fsdp2_parallelization/test_parallel_seed_initialization.py +++ b/tests/fsdp2_parallelization/test_parallel_seed_initialization.py @@ -1,7 +1,6 @@ import logging import multiprocessing as py_mp import os -import re import traceback from pathlib import Path from typing import Any @@ -15,10 +14,8 @@ from torch.distributed._tensor.placement_types import Replicate from modalities.__main__ import Main -from modalities.batch import EvaluationResultBatch from modalities.config.config import ProcessGroupBackendType from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticFSDP2ModuleType -from modalities.logging_broker.messages import Message from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank from tests.end2end_tests.custom_components import MultiProcessingCudaEnv from tests.utility import monitor_child_processes @@ -160,15 +157,3 @@ def _get_tmp_sharding_config_path(self, dp_degree: int, tp_degree: int, pp_degre yaml.dump(config_dict, file) return temp_file_path - - -def _get_loss_scores(messages: list[Message[EvaluationResultBatch]], loss_key: str) -> list[float]: - return [message.payload.losses[loss_key].value.item() for message in messages] - - -def _extract_seen_steps_and_tokens(filename: str) -> tuple[int, int]: - pattern = r"seen_steps_(\d+)-seen_tokens_(\d+)" - match = re.search(pattern, filename) - if match is None: - raise ValueError(f"Filename '{filename}' does not match expected pattern '{pattern}'.") - return int(match.group(1)), int(match.group(2)) From 5a9e89e9a7c829b77c8f2d339b4314e6ffaebc47 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 5 May 2026 14:00:55 +0000 Subject: [PATCH 11/25] fix: Use local-generator weight init Co-authored-by: Copilot --- .../initialization_routines.py | 41 +++++++++++-------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/src/modalities/nn/model_initialization/initialization_routines.py b/src/modalities/nn/model_initialization/initialization_routines.py index c927453f6..4c3103fe7 100644 --- a/src/modalities/nn/model_initialization/initialization_routines.py +++ b/src/modalities/nn/model_initialization/initialization_routines.py @@ -40,21 +40,35 @@ class ScaledEmbedInitializationConfig(BaseModel): class NamedParameterwiseNormalInitialization(ModelInitializationIF): - def __init__(self, mean: float, std: float, parameter_name_regexes: RegexFilter): + def __init__(self, mean: float, std: float, parameter_name_regexes: RegexFilter, seed: Optional[int] = None): self.mean = mean self.std = std self.parameter_name_regexes = parameter_name_regexes + self.seed = seed + self._generators: dict[str, torch.Generator] = {} + + def _get_generator(self, parameter: torch.Tensor) -> Optional[torch.Generator]: + if self.seed is None: + return None + + device_key = str(parameter.device) + generator = self._generators.get(device_key) + if generator is None: + generator = torch.Generator(device=parameter.device) + generator.manual_seed(self.seed) + self._generators[device_key] = generator + return generator def initialize_in_place(self, model: nn.Module): weight_regexes = self.parameter_name_regexes.weights - bias_regexes = self.parameter_name_regexes.biases + bias_regexes = self.parameter_name_regexes.biases or [] for parameter_name, p in model.named_parameters(): parameter_name = parameter_name.replace( "_orig_mod.", "" ) # remove FQN modification from torch.compile if present for weight_regex in weight_regexes: if re.fullmatch(weight_regex, parameter_name): - nn.init.normal_(p, mean=self.mean, std=self.std) + nn.init.normal_(p, mean=self.mean, std=self.std, generator=self._get_generator(p)) for bias_regex in bias_regexes: if re.fullmatch(bias_regex, parameter_name): nn.init.zeros_(p) @@ -65,7 +79,7 @@ class InitializationRoutines: def get_plain_initialization( mean: float, std: float | str, - parameter_name_regexes: list[str], + parameter_name_regexes: RegexFilter, hidden_dim: Optional[int] = None, seed: Optional[int] = None, ) -> NamedParameterwiseNormalInitialization: @@ -82,22 +96,22 @@ def get_plain_initialization( should be applied seed (Optional[int]): Random seed for initialization. Defaults to None. """ - InitializationRoutines._set_seed(seed) # auto: choose std automatically if std == "auto": if hidden_dim is None: raise ValueError("ERROR! weight_init.std = auto not implemented") # as per https://arxiv.org/abs/2312.16903 std = math.sqrt(2 / (5 * hidden_dim)) + assert isinstance(std, float) initialization = NamedParameterwiseNormalInitialization( - mean=mean, std=std, parameter_name_regexes=parameter_name_regexes + mean=mean, std=std, parameter_name_regexes=parameter_name_regexes, seed=seed ) return initialization @staticmethod def get_scaled_initialization( - mean: float, std: float, num_layers: int, parameter_name_regexes: list[str], seed: Optional[int] = None + mean: float, std: float, num_layers: int, parameter_name_regexes: RegexFilter, seed: Optional[int] = None ) -> ModelInitializationIF: """Implementation of scaled weight initialization. As defined in https://arxiv.org/abs/2312.16903 @@ -112,18 +126,17 @@ def get_scaled_initialization( Returns: WeightInitializationIF: Weight initialization object """ - InitializationRoutines._set_seed(seed) # see https://arxiv.org/abs/2312.16903 scaled_std = std / math.sqrt(2 * num_layers) initialization = NamedParameterwiseNormalInitialization( - mean=mean, std=scaled_std, parameter_name_regexes=parameter_name_regexes + mean=mean, std=scaled_std, parameter_name_regexes=parameter_name_regexes, seed=seed ) return initialization @staticmethod def get_scaled_embed_initialization( - mean: float, parameter_name_regexes: list[str], seed: Optional[int] = None + mean: float, parameter_name_regexes: RegexFilter, seed: Optional[int] = None ) -> ModelInitializationIF: """Implementation of scaled weight initialization for embeddings, see https://arxiv.org/abs/2312.16903 We fix the standard deviation to sqrt(0.4). @@ -137,14 +150,8 @@ def get_scaled_embed_initialization( Returns: WeightInitializationIF: Weight initialization object """ - InitializationRoutines._set_seed(seed) std = math.sqrt(0.4) initialization = NamedParameterwiseNormalInitialization( - mean=mean, std=std, parameter_name_regexes=parameter_name_regexes + mean=mean, std=std, parameter_name_regexes=parameter_name_regexes, seed=seed ) return initialization - - @staticmethod - def _set_seed(seed: Optional[int]): - if seed is not None: - torch.manual_seed(seed) From 13e7a82952147442da730a9492d4f93c9d2f0175 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 5 May 2026 14:04:19 +0000 Subject: [PATCH 12/25] refactor: Do not set seed in NNModel Co-authored-by: Copilot --- src/modalities/models/gpt2/gpt2_model.py | 6 +----- src/modalities/models/model.py | 5 +---- src/modalities/models/model_factory.py | 2 -- 3 files changed, 2 insertions(+), 11 deletions(-) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index 70f595e67..2da4979c0 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -342,7 +342,6 @@ class GPT2LLMConfig(BaseModel): ffn_norm_config (LayerNormWrapperConfig): Config for normalization of the feed-forward network. lm_head_norm_config (LayerNormWrapperConfig): Config for normalization of the language model head. use_weight_tying (bool): Whether to use weight tying. - seed: Optional[int] = None: The random seed for reproducibility. enforce_swiglu_hidden_dim_multiple_of (int): If specified, enforces the hidden dimension in the SwiGLU layer to be a multiple of this value. Note that this is only relevant if the activation_type is SwiGLU. Defaults to 256. @@ -370,7 +369,6 @@ class GPT2LLMConfig(BaseModel): ffn_norm_config: LayerNormWrapperConfig lm_head_norm_config: LayerNormWrapperConfig use_weight_tying: bool - seed: Optional[int] = None enforce_swiglu_hidden_dim_multiple_of: int = 256 @model_validator(mode="after") @@ -837,7 +835,6 @@ def __init__( ffn_norm_config: LayerNormWrapperConfig, lm_head_norm_config: LayerNormWrapperConfig, use_weight_tying: bool, - seed: Optional[int] = None, enforce_swiglu_hidden_dim_multiple_of: int = 256, ): """ @@ -862,7 +859,6 @@ def __init__( attention_norm_config (LayerNormWrapperConfig): Config for the attention normalization module. ffn_norm_config (LayerNormWrapperConfig): Config for the feed-forward network normalization module. lm_head_norm_config (LayerNormWrapperConfig): Config for the language model head normalization module. - seed (int, optional): The random seed. Defaults to None. use_weight_tying (bool): Whether to use weight tying. enforce_swiglu_hidden_dim_multiple_of (int): Enforces the hidden dimension in the SwiGLU layer to be a multiple of this value. @@ -873,7 +869,7 @@ def __init__( "embedding": [".wte", ".wpe"], "layernorm": [".attention_norm", ".ffn_norm", ".lm_head_norm"], } - super().__init__(weight_decay_groups=weight_decay_groups, seed=seed) + super().__init__(weight_decay_groups=weight_decay_groups) self.sample_key = sample_key self.prediction_key = prediction_key self.sequence_length = sequence_length diff --git a/src/modalities/models/model.py b/src/modalities/models/model.py index 5dc7986b2..f981f6117 100644 --- a/src/modalities/models/model.py +++ b/src/modalities/models/model.py @@ -26,16 +26,13 @@ class ActivationType(str, Enum): class NNModel(nn.Module): """NNModel class to define a base model.""" - def __init__(self, seed: Optional[int] = None, weight_decay_groups: Optional[WeightDecayGroups] = None): + def __init__(self, weight_decay_groups: Optional[WeightDecayGroups] = None): """ Initializes an NNModel object. Args: - seed (int, optional): The seed value for random number generation. Defaults to None. weight_decay_groups (Optional[WeightDecayGroups], optional): The weight decay groups. Defaults to None. """ - if seed is not None: - torch.manual_seed(seed) self._weight_decay_groups = weight_decay_groups if weight_decay_groups is not None else {} super(NNModel, self).__init__() diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py index 142aef920..62933794d 100644 --- a/src/modalities/models/model_factory.py +++ b/src/modalities/models/model_factory.py @@ -615,7 +615,6 @@ def get_gpt2_model( lm_head_norm_config: LayerNormWrapperConfig, use_weight_tying: bool, use_meta_device: Optional[bool] = False, - seed: Optional[int] = None, enforce_swiglu_hidden_dim_multiple_of: int = 256, ) -> GPT2LLM: config = dict( @@ -637,7 +636,6 @@ def get_gpt2_model( attention_norm_config=attention_norm_config, ffn_norm_config=ffn_norm_config, lm_head_norm_config=lm_head_norm_config, - seed=seed, use_weight_tying=use_weight_tying, enforce_swiglu_hidden_dim_multiple_of=enforce_swiglu_hidden_dim_multiple_of, ) From dc11bbb6fd975291d3f5ac3b0b1c41dcac0a37ed Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 5 May 2026 14:33:57 +0000 Subject: [PATCH 13/25] docs: Add documentation and warning for topology-dependent weight initialization Co-authored-by: Copilot --- docs/components/components.md | 2 ++ .../composed_initialization.py | 33 +++++++++++++++++-- .../test_deferred_initialization.py | 1 - 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/docs/components/components.md b/docs/components/components.md index 22f45958a..81af02d34 100644 --- a/docs/components/components.md +++ b/docs/components/components.md @@ -17,6 +17,8 @@ |---------------|--------------------|----------------|---------------|---------------------|-------------| | model_initialization | composed | [ComposedInitializationRoutines.get_composed_model_initializer](../../src/modalities/nn/model_initialization/composed_initialization.py)| [ComposedModelInitializationConfig](../../src/modalities/nn/model_initialization/composed_initialization.py) | [ModelInitializationIF](../../src/modalities/nn/model_initialization/initialization_if.py) | Component for initializing model weights in place | +The composed initializer supports seeded weight initialization for reproducibility within a fixed topology. When pipeline parallelism is active, Modalities offsets the initialization seed by pipeline stage rank to avoid identical stage-local weights. As a result, the same seed can produce different initialized weights for different pipeline-parallel topologies. For topology-independent reproducibility, create and reuse a distributed checkpoint directly after weight initialization. + ## Losses |Component type | Component Version | Implementation | Configuration | Component Interface | Description | diff --git a/src/modalities/nn/model_initialization/composed_initialization.py b/src/modalities/nn/model_initialization/composed_initialization.py index 1789011f1..c2d9cf9ee 100644 --- a/src/modalities/nn/model_initialization/composed_initialization.py +++ b/src/modalities/nn/model_initialization/composed_initialization.py @@ -1,5 +1,6 @@ from typing import Optional +import torch import torch.nn as nn from pydantic import BaseModel, ConfigDict, Field, model_validator from torch.distributed.device_mesh import DeviceMesh @@ -14,6 +15,9 @@ WeightInitTypes, ) from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank, has_parallelism_method +from modalities.utils.logger_utils import get_logger + +logger = get_logger(__name__) class ModelInitializerWrapperConfig(BaseModel): @@ -91,6 +95,24 @@ def initialize_in_place(self, model: nn.Module): class ComposedInitializationRoutines: + @staticmethod + def _warn_pp_topology_dependent_seed(device_mesh: Optional[DeviceMesh], seed: Optional[int]) -> None: + if seed is None or not has_parallelism_method( + device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP + ): + return + + if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0: + return + + logger.warning( + "Seeded weight initialization is topology-dependent when pipeline parallelism is active. " + "Modalities offsets the initialization seed by PP rank to avoid identical stage-local weights, " + "so the same seed can produce different initialized weights for different PP configurations. " + "For topology-independent reproducibility, create and reuse a distributed checkpoint directly " + "after weight initialization." + ) + @staticmethod def get_model_initializer_wrapper(model_initializers: list[ModelInitializationIF]) -> ModelInitializationIF: initializer_wrapper = ModelInitializerWrapper(model_initializers) @@ -103,7 +125,7 @@ def get_composed_model_initializer( mean: float, std: float | str, hidden_dim: Optional[int] = None, - num_layers: int = None, + num_layers: Optional[int] = None, device_mesh: Optional[DeviceMesh] = None, seed: Optional[int] = None, ) -> ModelInitializationIF: @@ -121,15 +143,21 @@ def get_composed_model_initializer( num_layers (int, optional): Number of layers in the model (required for scaled and scaled_embed only). Defaults to None. device_mesh (Optional[DeviceMesh], optional): Device mesh used for parallelization. - seed (Optional[int], optional): Seed for random initialization. Defaults to None. + seed (Optional[int], optional): Seed for random initialization. Defaults to None. When pipeline + parallelism is active, the effective seed is offset by PP rank to avoid identical stage-local + initialization, so the same seed does not guarantee identical initialized weights across different + PP topologies. Returns: ModelInitializationIF: The Weight Initializer performing the initialization as specified. """ + ComposedInitializationRoutines._warn_pp_topology_dependent_seed(device_mesh=device_mesh, seed=seed) + # Set different random seed for each PP rank to ensure diversity if seed is not None and has_parallelism_method( device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP ): + assert device_mesh is not None seed += get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP) model_initializers = [] @@ -148,6 +176,7 @@ def get_composed_model_initializer( if weight_init_type in [WeightInitTypes.SCALED, WeightInitTypes.SCALED_EMBED]: # scaled + assert num_layers is not None scaled_parameter_name_regexes = NAMED_PARAMETER_INIT_GROUPS[model_type][WeightInitTypes.SCALED] scaled_init = InitializationRoutines.get_scaled_initialization( mean=mean, diff --git a/tests/nn/model_initialization/test_deferred_initialization.py b/tests/nn/model_initialization/test_deferred_initialization.py index 1c431abc4..eae9c0686 100644 --- a/tests/nn/model_initialization/test_deferred_initialization.py +++ b/tests/nn/model_initialization/test_deferred_initialization.py @@ -105,7 +105,6 @@ def _build_gpt2_model() -> GPT2LLM: ffn_norm_config=ln_cfg, lm_head_norm_config=ln_cfg, use_weight_tying=False, - seed=42, enforce_swiglu_hidden_dim_multiple_of=256, ) return model From 999cb6536e14029e033f8ed84fe25bebedb5bf08 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 5 May 2026 14:45:11 +0000 Subject: [PATCH 14/25] fix: Fix transformers version mismatch Co-authored-by: Copilot --- src/modalities/config/config.py | 4 ++-- src/modalities/conversion/gpt2/modeling_gpt2.py | 9 ++++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 2f45a5f22..42a19b99a 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -7,8 +7,8 @@ from omegaconf import OmegaConf, Resolver from pydantic import BaseModel, ConfigDict, Field, FilePath, PositiveInt, field_validator, model_validator from torch.distributed.fsdp import ShardingStrategy -from transformers import GPT2TokenizerFast -from transformers.models.llama.tokenization_llama_fast import LlamaTokenizerFast +from transformers import GPT2Tokenizer as GPT2TokenizerFast +from transformers import LlamaTokenizer as LlamaTokenizerFast from typing_extensions import deprecated from modalities.config.lookup_enum import LookupEnum diff --git a/src/modalities/conversion/gpt2/modeling_gpt2.py b/src/modalities/conversion/gpt2/modeling_gpt2.py index dec0bf64c..f6aa77ab1 100644 --- a/src/modalities/conversion/gpt2/modeling_gpt2.py +++ b/src/modalities/conversion/gpt2/modeling_gpt2.py @@ -40,7 +40,14 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from transformers.processing_utils import Unpack from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from transformers.utils.generic import check_model_inputs + +try: + from transformers.utils.generic import check_model_inputs +except ImportError: + + def check_model_inputs(func: Callable) -> Callable: + return func + from modalities.conversion.gpt2.configuration_gpt2 import GPT2Config From b02275f246d741221bff20f2cca0b84008792c41 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 5 May 2026 15:01:51 +0000 Subject: [PATCH 15/25] test: Fix test by removing dependency on global RNG state for seed=None Co-authored-by: Copilot --- .../nn/model_initialization/initialization_routines.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/modalities/nn/model_initialization/initialization_routines.py b/src/modalities/nn/model_initialization/initialization_routines.py index 4c3103fe7..f995606b1 100644 --- a/src/modalities/nn/model_initialization/initialization_routines.py +++ b/src/modalities/nn/model_initialization/initialization_routines.py @@ -44,13 +44,10 @@ def __init__(self, mean: float, std: float, parameter_name_regexes: RegexFilter, self.mean = mean self.std = std self.parameter_name_regexes = parameter_name_regexes - self.seed = seed + self.seed = torch.initial_seed() if seed is None else seed self._generators: dict[str, torch.Generator] = {} - def _get_generator(self, parameter: torch.Tensor) -> Optional[torch.Generator]: - if self.seed is None: - return None - + def _get_generator(self, parameter: torch.Tensor) -> torch.Generator: device_key = str(parameter.device) generator = self._generators.get(device_key) if generator is None: From ddfbe4727702afedd3a0c4c4a04cb65271ea817d Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 5 May 2026 15:16:28 +0000 Subject: [PATCH 16/25] test: Adapt test to latest changes in main Co-authored-by: Copilot --- .../test_parallel_seed_initialization.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/fsdp2_parallelization/test_parallel_seed_initialization.py b/tests/fsdp2_parallelization/test_parallel_seed_initialization.py index b9bb2f7ca..f47b5b938 100644 --- a/tests/fsdp2_parallelization/test_parallel_seed_initialization.py +++ b/tests/fsdp2_parallelization/test_parallel_seed_initialization.py @@ -3,7 +3,7 @@ import os import traceback from pathlib import Path -from typing import Any +from typing import Any, cast import pytest import torch @@ -67,13 +67,18 @@ def _seed_distribution_impl_wrapper(self, process_id: int, world_size: int, tmp_ def _seed_distribution_impl(self, world_size: int, tmp_path: Path): # initialize components class ComponentsInstantiationModel(BaseModel): - fsdp_model: PydanticFSDP2ModuleType + fsdp_model: PydanticFSDP2ModuleType | list[PydanticFSDP2ModuleType] device_mesh: PydanticDeviceMeshIFType config_file_path = self._get_tmp_sharding_config_path(dp_degree=2, tp_degree=2, pp_degree=2, tmp_path=tmp_path) - main_obj = Main(config_file_path) - components = main_obj.build_components(components_model_type=ComponentsInstantiationModel) - model = components.fsdp_model + main_obj = Main(config_file_path, experiments_root_path=tmp_path) + components = cast( + ComponentsInstantiationModel, + main_obj.build_components(components_model_type=ComponentsInstantiationModel), + ) + model = cast( + Any, components.fsdp_model[0] if isinstance(components.fsdp_model, list) else components.fsdp_model + ) device_mesh = components.device_mesh # for each pp stage get first transformer block's MLP weight parameter shards and full tensor block_key = next(iter(model.transformer.h.keys())) @@ -89,12 +94,12 @@ class ComponentsInstantiationModel(BaseModel): "block_key": block_key, } - gather_list: list[dict[str, Any]] | None = [None] * world_size if dist.get_rank() == 0 else None + gather_list = cast(list[dict[str, Any] | None] | None, [None] * world_size if dist.get_rank() == 0 else None) dist.gather_object(payload, gather_list, dst=0) if dist.get_rank() == 0: assert gather_list is not None - TestParallelSeedInitialization._assert_parameter_distribution(gather_list) + TestParallelSeedInitialization._assert_parameter_distribution(cast(list[dict[str, Any]], gather_list)) dist.barrier() @staticmethod From 76762d9abf653e1a4038735e2dafd574a47dd2b6 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 5 May 2026 18:51:00 +0000 Subject: [PATCH 17/25] chore: Use consistent typing for optional parameters --- .../composed_initialization.py | 8 ++++---- .../initialization_routines.py | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/modalities/nn/model_initialization/composed_initialization.py b/src/modalities/nn/model_initialization/composed_initialization.py index c2d9cf9ee..67fb6956d 100644 --- a/src/modalities/nn/model_initialization/composed_initialization.py +++ b/src/modalities/nn/model_initialization/composed_initialization.py @@ -36,7 +36,7 @@ class ComposedModelInitializationConfig(BaseModel): std: Annotated[float, Field(strict=True, ge=0.0)] | str # can be float or "auto" hidden_dim: Optional[Annotated[int, Field(strict=True, gt=0)]] = None num_layers: Optional[Annotated[int, Field(strict=True, gt=0)]] = None - seed: Optional[int] = None + seed: int | None = None device_mesh: Optional[PydanticDeviceMeshIFType] = None # avoid warning about protected namespace 'model_', see @@ -124,10 +124,10 @@ def get_composed_model_initializer( weight_init_type: WeightInitTypes, mean: float, std: float | str, - hidden_dim: Optional[int] = None, - num_layers: Optional[int] = None, + hidden_dim: int | None = None, + num_layers: int | None = None, device_mesh: Optional[DeviceMesh] = None, - seed: Optional[int] = None, + seed: int | None = None, ) -> ModelInitializationIF: """This initialization allows to intialize a model with plain, scaled or scaled_embed initialization. Note that plain initialization is always performed in the beginning. In case of scaled_embed, diff --git a/src/modalities/nn/model_initialization/initialization_routines.py b/src/modalities/nn/model_initialization/initialization_routines.py index f995606b1..2aee5d300 100644 --- a/src/modalities/nn/model_initialization/initialization_routines.py +++ b/src/modalities/nn/model_initialization/initialization_routines.py @@ -1,6 +1,6 @@ import math import re -from typing import Annotated, Optional +from typing import Annotated import torch import torch.nn as nn @@ -14,7 +14,7 @@ class PlainInitializationConfig(BaseModel): mean: float std: Annotated[float, Field(strict=True, ge=0.0)] | str # can be float or "auto" parameter_name_regexes: list[str] # here we filter for the parameter names, e.g., "c_proj.weight" - hidden_dim: Optional[int] = None + hidden_dim: int | None = None @model_validator(mode="after") def check_std_and_hidden_dim(self): @@ -40,7 +40,7 @@ class ScaledEmbedInitializationConfig(BaseModel): class NamedParameterwiseNormalInitialization(ModelInitializationIF): - def __init__(self, mean: float, std: float, parameter_name_regexes: RegexFilter, seed: Optional[int] = None): + def __init__(self, mean: float, std: float, parameter_name_regexes: RegexFilter, seed: int | None = None): self.mean = mean self.std = std self.parameter_name_regexes = parameter_name_regexes @@ -77,8 +77,8 @@ def get_plain_initialization( mean: float, std: float | str, parameter_name_regexes: RegexFilter, - hidden_dim: Optional[int] = None, - seed: Optional[int] = None, + hidden_dim: int | None = None, + seed: int | None = None, ) -> NamedParameterwiseNormalInitialization: """Initializes the weights of a model by sampling from a normal distribution. NOTE: This class supports the initialization of nn.Linear and nn.Embedding layers. @@ -108,7 +108,7 @@ def get_plain_initialization( @staticmethod def get_scaled_initialization( - mean: float, std: float, num_layers: int, parameter_name_regexes: RegexFilter, seed: Optional[int] = None + mean: float, std: float, num_layers: int, parameter_name_regexes: RegexFilter, seed: int | None = None ) -> ModelInitializationIF: """Implementation of scaled weight initialization. As defined in https://arxiv.org/abs/2312.16903 @@ -133,7 +133,7 @@ def get_scaled_initialization( @staticmethod def get_scaled_embed_initialization( - mean: float, parameter_name_regexes: RegexFilter, seed: Optional[int] = None + mean: float, parameter_name_regexes: RegexFilter, seed: int | None = None ) -> ModelInitializationIF: """Implementation of scaled weight initialization for embeddings, see https://arxiv.org/abs/2312.16903 We fix the standard deviation to sqrt(0.4). From dea2eefe1ae6c9dc716979b5bec4e305cac1d12f Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 5 May 2026 19:00:58 +0000 Subject: [PATCH 18/25] chore: Remove outdated seed parameter --- .../end2end_tests/configs/gpt2_train_num_steps_7_grad_accu.yaml | 1 - tests/end2end_tests/configs/gpt2_train_num_steps_7_pp.yaml | 1 - tests/end2end_tests/configs/gpt2_train_num_steps_7_pp_tp.yaml | 1 - tests/end2end_tests/configs/gpt2_train_num_steps_7_tp.yaml | 1 - .../end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2.yaml | 1 - .../configs/gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml | 1 - .../configs/gpt2_warm_start_from_step_4_grad_accu.yaml | 1 - .../end2end_tests/configs/gpt2_warm_start_from_step_4_pp_tp.yaml | 1 - 8 files changed, 8 deletions(-) diff --git a/tests/end2end_tests/configs/gpt2_train_num_steps_7_grad_accu.yaml b/tests/end2end_tests/configs/gpt2_train_num_steps_7_grad_accu.yaml index 395c131a7..9a9c886d4 100644 --- a/tests/end2end_tests/configs/gpt2_train_num_steps_7_grad_accu.yaml +++ b/tests/end2end_tests/configs/gpt2_train_num_steps_7_grad_accu.yaml @@ -204,7 +204,6 @@ model_raw: component_key: model variant_key: gpt2 config: - seed: 42 use_meta_device: true use_weight_tying: false sample_key: ${settings.referencing_keys.sample_key} diff --git a/tests/end2end_tests/configs/gpt2_train_num_steps_7_pp.yaml b/tests/end2end_tests/configs/gpt2_train_num_steps_7_pp.yaml index c784ae6bc..f31503a6f 100644 --- a/tests/end2end_tests/configs/gpt2_train_num_steps_7_pp.yaml +++ b/tests/end2end_tests/configs/gpt2_train_num_steps_7_pp.yaml @@ -270,7 +270,6 @@ model_raw: component_key: model variant_key: gpt2 config: - seed: 42 use_meta_device: true use_weight_tying: false sample_key: ${settings.referencing_keys.sample_key} diff --git a/tests/end2end_tests/configs/gpt2_train_num_steps_7_pp_tp.yaml b/tests/end2end_tests/configs/gpt2_train_num_steps_7_pp_tp.yaml index eb6b5f490..a8f72ac2f 100644 --- a/tests/end2end_tests/configs/gpt2_train_num_steps_7_pp_tp.yaml +++ b/tests/end2end_tests/configs/gpt2_train_num_steps_7_pp_tp.yaml @@ -281,7 +281,6 @@ model_raw: component_key: model variant_key: gpt2 config: - seed: 42 use_meta_device: true use_weight_tying: false sample_key: ${settings.referencing_keys.sample_key} diff --git a/tests/end2end_tests/configs/gpt2_train_num_steps_7_tp.yaml b/tests/end2end_tests/configs/gpt2_train_num_steps_7_tp.yaml index 2d0c8e2b5..579162709 100644 --- a/tests/end2end_tests/configs/gpt2_train_num_steps_7_tp.yaml +++ b/tests/end2end_tests/configs/gpt2_train_num_steps_7_tp.yaml @@ -216,7 +216,6 @@ model_raw: component_key: model variant_key: gpt2 config: - seed: 42 use_meta_device: true use_weight_tying: false sample_key: ${settings.referencing_keys.sample_key} diff --git a/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2.yaml b/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2.yaml index b5378e05d..e85c6e93c 100644 --- a/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2.yaml +++ b/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2.yaml @@ -223,7 +223,6 @@ model_raw: component_key: model variant_key: gpt2 config: - seed: 42 use_meta_device: true use_weight_tying: false sample_key: ${settings.referencing_keys.sample_key} diff --git a/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml b/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml index 8af01a926..f6b553b44 100644 --- a/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml +++ b/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml @@ -223,7 +223,6 @@ model_raw: component_key: model variant_key: gpt2 config: - seed: 42 use_meta_device: true use_weight_tying: false sample_key: ${settings.referencing_keys.sample_key} diff --git a/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_grad_accu.yaml b/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_grad_accu.yaml index c88c80922..4f073ec28 100644 --- a/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_grad_accu.yaml +++ b/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_grad_accu.yaml @@ -223,7 +223,6 @@ model_raw: component_key: model variant_key: gpt2 config: - seed: 42 use_meta_device: true use_weight_tying: false sample_key: ${settings.referencing_keys.sample_key} diff --git a/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_pp_tp.yaml b/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_pp_tp.yaml index 5b687e2e4..2029c0323 100644 --- a/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_pp_tp.yaml +++ b/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_pp_tp.yaml @@ -300,7 +300,6 @@ model_raw: component_key: model variant_key: gpt2 config: - seed: 42 use_meta_device: true use_weight_tying: false sample_key: ${settings.referencing_keys.sample_key} From adf11f0c1b8ce6b2fe25cac521b5d77bc2546300 Mon Sep 17 00:00:00 2001 From: Richard Rutmann <97447451+rrutmann@users.noreply.github.com> Date: Thu, 7 May 2026 11:08:01 +0200 Subject: [PATCH 19/25] fix: Use correct type for parameter_name_regexes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Max Lübbering <2804731+le1nux@users.noreply.github.com> --- .../nn/model_initialization/initialization_routines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modalities/nn/model_initialization/initialization_routines.py b/src/modalities/nn/model_initialization/initialization_routines.py index 2aee5d300..4e4cd1713 100644 --- a/src/modalities/nn/model_initialization/initialization_routines.py +++ b/src/modalities/nn/model_initialization/initialization_routines.py @@ -116,7 +116,7 @@ def get_scaled_initialization( mean (float): Mean of the normal distribution std (float): Standard deviation of the normal distribution used to initialize the other weights num_layers (int): Number of layers in the model which we use to downscale std with - parameter_name_regexes (list[str]): List of parameter name regexes to which the initialization + parameter_name_regexes (RegexFilter): List of parameter name regexes to which the initialization should be applied seed (Optional[int]): Random seed for initialization. Defaults to None. From 4cf0032829cc965854e9ad6026e5e82964b74b2d Mon Sep 17 00:00:00 2001 From: rrutmann Date: Thu, 7 May 2026 12:58:25 +0000 Subject: [PATCH 20/25] test: Add option for reliable vscode debugging --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 9ea55a4d4..2ddc23935 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -124,6 +124,7 @@ line-length = 120 [tool.pytest.ini_options] addopts = "--cov=src --cov-report term --cov-report html" +#addopts = "-ra" # Enable this instead of line above for reliable VS Code test debugging (without coverage) [tool.coverage.run] branch = true From 7541df2931e9999ac829c8cced476fe1a8e2f6e8 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Thu, 7 May 2026 13:03:38 +0000 Subject: [PATCH 21/25] test: Add test for seeded model reproducibility Co-authored-by: Copilot --- .../test_parallel_seed_initialization.py | 70 +++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/tests/fsdp2_parallelization/test_parallel_seed_initialization.py b/tests/fsdp2_parallelization/test_parallel_seed_initialization.py index f47b5b938..b75b1d78e 100644 --- a/tests/fsdp2_parallelization/test_parallel_seed_initialization.py +++ b/tests/fsdp2_parallelization/test_parallel_seed_initialization.py @@ -12,6 +12,7 @@ import yaml from pydantic import BaseModel from torch.distributed._tensor.placement_types import Replicate +from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict from modalities.__main__ import Main from modalities.config.config import ProcessGroupBackendType @@ -162,3 +163,72 @@ def _get_tmp_sharding_config_path(self, dp_degree: int, tp_degree: int, pp_degre yaml.dump(config_dict, file) return temp_file_path + + +@pytest.mark.skipif( + torch.cuda.device_count() < 1, + reason="This test requires at least 1 GPU.", +) +class TestSeededModelReproducibility: + RDVZ_PORT = 24575 + + def test_same_seed_same_weights(self, tmp_path: Path): + with MultiProcessingCudaEnv( + process_group_backend=ProcessGroupBackendType.nccl, + global_rank=0, + local_rank=0, + world_size=1, + rdvz_port=TestSeededModelReproducibility.RDVZ_PORT, + ): + self._same_seed_same_weights_impl(tmp_path=tmp_path) + + def _same_seed_same_weights_impl(self, tmp_path: Path): + class ComponentsInstantiationModel(BaseModel): + initialized_model: PydanticFSDP2ModuleType + + config_file_path = self._get_tmp_seeded_config_path(tmp_path=tmp_path, seed=1234) + + main_obj_1 = Main(config_file_path, experiments_root_path=tmp_path) + components_1 = cast( + ComponentsInstantiationModel, + main_obj_1.build_components(components_model_type=ComponentsInstantiationModel), + ) + state_dict_1 = get_state_dict( + model=cast(Any, components_1.initialized_model), + optimizers=[], + options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True), + )[0] + + main_obj_2 = Main(config_file_path, experiments_root_path=tmp_path) + components_2 = cast( + ComponentsInstantiationModel, + main_obj_2.build_components(components_model_type=ComponentsInstantiationModel), + ) + state_dict_2 = get_state_dict( + model=cast(Any, components_2.initialized_model), + optimizers=[], + options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True), + )[0] + + assert set(state_dict_1.keys()) == set(state_dict_2.keys()), "State dict keys differ between initializations" + for key in state_dict_1: + tensor_1 = state_dict_1[key] + tensor_2 = state_dict_2[key] + assert isinstance(tensor_1, torch.Tensor), f"Expected Tensor in first state dict for key {key}" + assert isinstance(tensor_2, torch.Tensor), f"Expected Tensor in second state dict for key {key}" + assert torch.equal(tensor_1, tensor_2), f"Mismatch for parameter {key}" + + dist.barrier() + + def _get_tmp_seeded_config_path(self, tmp_path: Path, seed: int) -> Path: + temp_file_path = tmp_path / "seeded_reproducibility.yaml" + config_file_path = Path(os.path.dirname(__file__)) / "../checkpointing/fsdp2_gpt2_config.yaml" + + with open(config_file_path, "r") as file: + config_dict = yaml.safe_load(file.read()) + config_dict["initialized_model"]["config"]["model_initializer"]["config"]["seed"] = seed + + with open(temp_file_path, "w") as file: + yaml.dump(config_dict, file) + + return temp_file_path From ede150ee11c6a17011503c244ac0ccdd7126730e Mon Sep 17 00:00:00 2001 From: rrutmann Date: Thu, 7 May 2026 15:35:05 +0000 Subject: [PATCH 22/25] chore: Change order of model initialization Co-authored-by: Copilot --- ...g_fsdp2_pp_tp_fwd_bwd_pass_defer_init.yaml | 186 ++++++++++++++++++ .../test_parallel_seed_initialization.py | 10 +- 2 files changed, 193 insertions(+), 3 deletions(-) create mode 100644 tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass_defer_init.yaml diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass_defer_init.yaml b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass_defer_init.yaml new file mode 100644 index 000000000..ad6ed5954 --- /dev/null +++ b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass_defer_init.yaml @@ -0,0 +1,186 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 4 + sequence_length: 256 + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + pipeline_parallel_degree: 2 + data_parallel_shard_degree: -1 + world_size: ${settings.cuda_env.world_size} + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: scheduled_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + seed: 42 + num_layers: ${model_raw.config.n_layer} + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +scheduled_pipeline: + component_key: pipeline + variant_key: scheduled + config: + loss_fn: + instance_key: loss_fn + pass_type: BY_REFERENCE + pp_schedule_name: Interleaved1F1B + batch_size: ${settings.step_profile.local_train_micro_batch_size} + microbatch_size: 2 + pp_degree: ${device_mesh.config.pipeline_parallel_degree} + pipeline: + component_key: pipeline + variant_key: builder + config: + pp_stage: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: PP_STAGE + model_part: + instance_key: fsdp_model + pass_type: BY_REFERENCE + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: gpt2_tp_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +gpt2_tp_model: + component_key: model + variant_key: gpt2_tp + config: + model: + instance_key: model_part + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +model_part: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + +staged_pipeline: + component_key: pipeline + variant_key: staged + config: + whole_model: + instance_key: model_raw + pass_type: BY_REFERENCE + stages_generator: + component_key: stages_generator + variant_key: gpt2_stages_generator + config: + num_model_layers: ${model_raw.config.n_layer} + input_layer_equivalence: 1 + output_layer_equivalence: 1 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + local_rank: ${settings.cuda_env.local_rank} + pp_schedule_name: ${scheduled_pipeline.config.pp_schedule_name} + num_layers_per_stage: 4 + +model_raw: + component_key: model + variant_key: gpt2 + config: + use_meta_device: true + use_weight_tying: false + sample_key: ${settings.referencing_keys.sample_key} + poe_type: NOPE + sequence_length: ${settings.step_profile.sequence_length} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 6 + n_head_q: 8 + n_head_kv: 4 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} #it has to be head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: manual + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + diff --git a/tests/fsdp2_parallelization/test_parallel_seed_initialization.py b/tests/fsdp2_parallelization/test_parallel_seed_initialization.py index b75b1d78e..6a3333ace 100644 --- a/tests/fsdp2_parallelization/test_parallel_seed_initialization.py +++ b/tests/fsdp2_parallelization/test_parallel_seed_initialization.py @@ -68,7 +68,7 @@ def _seed_distribution_impl_wrapper(self, process_id: int, world_size: int, tmp_ def _seed_distribution_impl(self, world_size: int, tmp_path: Path): # initialize components class ComponentsInstantiationModel(BaseModel): - fsdp_model: PydanticFSDP2ModuleType | list[PydanticFSDP2ModuleType] + initialized_model: PydanticFSDP2ModuleType | list[PydanticFSDP2ModuleType] device_mesh: PydanticDeviceMeshIFType config_file_path = self._get_tmp_sharding_config_path(dp_degree=2, tp_degree=2, pp_degree=2, tmp_path=tmp_path) @@ -78,7 +78,10 @@ class ComponentsInstantiationModel(BaseModel): main_obj.build_components(components_model_type=ComponentsInstantiationModel), ) model = cast( - Any, components.fsdp_model[0] if isinstance(components.fsdp_model, list) else components.fsdp_model + Any, + components.initialized_model[0] + if isinstance(components.initialized_model, list) + else components.initialized_model, ) device_mesh = components.device_mesh # for each pp stage get first transformer block's MLP weight parameter shards and full tensor @@ -148,7 +151,8 @@ def _get_tmp_sharding_config_path(self, dp_degree: int, tp_degree: int, pp_degre temp_file_path = tmp_path / "pp_tp_sharding_config.yaml" working_dir = Path(os.path.dirname(__file__)) config_file_path = ( - working_dir / "pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml" + working_dir + / "pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass_defer_init.yaml" ) with open(config_file_path, "r") as file: From 67bc5969c2f0ce82d9a96160ac966cdb561fd53c Mon Sep 17 00:00:00 2001 From: rrutmann Date: Thu, 7 May 2026 16:00:05 +0000 Subject: [PATCH 23/25] feat: Add multi_device_generator_policy for handling seeding with multiple devices per rank Co-authored-by: Copilot --- .../config_lorem_ipsum_long_fsdp2.yaml | 2 + .../composed_initialization.py | 10 +++- .../initialization_routines.py | 59 ++++++++++++++++--- 3 files changed, 63 insertions(+), 8 deletions(-) diff --git a/config_files/training/config_lorem_ipsum_long_fsdp2.yaml b/config_files/training/config_lorem_ipsum_long_fsdp2.yaml index 87db96381..a51a58f32 100644 --- a/config_files/training/config_lorem_ipsum_long_fsdp2.yaml +++ b/config_files/training/config_lorem_ipsum_long_fsdp2.yaml @@ -13,6 +13,7 @@ settings: checkpoint_saving_path: data/checkpoints train_dataset_path: ./data/lorem_ipsum_long.pbin test_dataset_path: ./data/lorem_ipsum.pbin + experiments_root_path: ${modalities_env:experiments_root_path} intervals: training_log_interval_in_steps: 1 checkpointing_interval_in_steps: 32 @@ -221,6 +222,7 @@ initialized_model: mean: 0.0 std: 0.02 num_layers: ${model_raw.config.n_layer} + multi_device_generator_policy: error fsdp_model: component_key: model diff --git a/src/modalities/nn/model_initialization/composed_initialization.py b/src/modalities/nn/model_initialization/composed_initialization.py index 67fb6956d..bd720d220 100644 --- a/src/modalities/nn/model_initialization/composed_initialization.py +++ b/src/modalities/nn/model_initialization/composed_initialization.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Literal, Optional import torch import torch.nn as nn @@ -37,6 +37,7 @@ class ComposedModelInitializationConfig(BaseModel): hidden_dim: Optional[Annotated[int, Field(strict=True, gt=0)]] = None num_layers: Optional[Annotated[int, Field(strict=True, gt=0)]] = None seed: int | None = None + multi_device_generator_policy: Literal["ignore", "warn", "error"] = "warn" device_mesh: Optional[PydanticDeviceMeshIFType] = None # avoid warning about protected namespace 'model_', see @@ -128,6 +129,7 @@ def get_composed_model_initializer( num_layers: int | None = None, device_mesh: Optional[DeviceMesh] = None, seed: int | None = None, + multi_device_generator_policy: Literal["ignore", "warn", "error"] = "warn", ) -> ModelInitializationIF: """This initialization allows to intialize a model with plain, scaled or scaled_embed initialization. Note that plain initialization is always performed in the beginning. In case of scaled_embed, @@ -147,6 +149,9 @@ def get_composed_model_initializer( parallelism is active, the effective seed is offset by PP rank to avoid identical stage-local initialization, so the same seed does not guarantee identical initialized weights across different PP topologies. + multi_device_generator_policy (Literal["ignore", "warn", "error"], optional): Behavior when + initialization creates per-device RNG generators for more than one device in the same process. + Defaults to "warn". Returns: ModelInitializationIF: The Weight Initializer performing the initialization as specified. @@ -170,6 +175,7 @@ def get_composed_model_initializer( hidden_dim=hidden_dim, parameter_name_regexes=plain_parameter_name_regexes, seed=seed, + multi_device_generator_policy=multi_device_generator_policy, ) working_std = plain_init.std model_initializers.append(plain_init) @@ -184,6 +190,7 @@ def get_composed_model_initializer( num_layers=num_layers, parameter_name_regexes=scaled_parameter_name_regexes, seed=seed, + multi_device_generator_policy=multi_device_generator_policy, ) model_initializers.append(scaled_init) @@ -194,6 +201,7 @@ def get_composed_model_initializer( mean=mean, parameter_name_regexes=scaled_embed_parameter_name_regexes, seed=seed, + multi_device_generator_policy=multi_device_generator_policy, ) model_initializers.append(scaled_embed_init) diff --git a/src/modalities/nn/model_initialization/initialization_routines.py b/src/modalities/nn/model_initialization/initialization_routines.py index 4e4cd1713..e71dc1be2 100644 --- a/src/modalities/nn/model_initialization/initialization_routines.py +++ b/src/modalities/nn/model_initialization/initialization_routines.py @@ -1,6 +1,7 @@ import math import re -from typing import Annotated +import warnings +from typing import Annotated, Literal import torch import torch.nn as nn @@ -40,17 +41,34 @@ class ScaledEmbedInitializationConfig(BaseModel): class NamedParameterwiseNormalInitialization(ModelInitializationIF): - def __init__(self, mean: float, std: float, parameter_name_regexes: RegexFilter, seed: int | None = None): + def __init__( + self, + mean: float, + std: float, + parameter_name_regexes: RegexFilter, + seed: int | None = None, + multi_device_generator_policy: Literal["ignore", "warn", "error"] = "warn", + ): self.mean = mean self.std = std self.parameter_name_regexes = parameter_name_regexes self.seed = torch.initial_seed() if seed is None else seed + self.multi_device_generator_policy = multi_device_generator_policy self._generators: dict[str, torch.Generator] = {} def _get_generator(self, parameter: torch.Tensor) -> torch.Generator: device_key = str(parameter.device) generator = self._generators.get(device_key) if generator is None: + if len(self._generators) > 0: + message = ( + "NamedParameterwiseNormalInitialization created generators for multiple devices in one process " + f"(existing={list(self._generators.keys())}, new={device_key})." + ) + if self.multi_device_generator_policy == "error": + raise RuntimeError(message) + if self.multi_device_generator_policy == "warn": + warnings.warn(message, stacklevel=2) generator = torch.Generator(device=parameter.device) generator.manual_seed(self.seed) self._generators[device_key] = generator @@ -79,6 +97,7 @@ def get_plain_initialization( parameter_name_regexes: RegexFilter, hidden_dim: int | None = None, seed: int | None = None, + multi_device_generator_policy: Literal["ignore", "warn", "error"] = "warn", ) -> NamedParameterwiseNormalInitialization: """Initializes the weights of a model by sampling from a normal distribution. NOTE: This class supports the initialization of nn.Linear and nn.Embedding layers. @@ -92,6 +111,8 @@ def get_plain_initialization( parameter_name_regexes (list[str]): List of parameter name regexes to which the initialization should be applied seed (Optional[int]): Random seed for initialization. Defaults to None. + multi_device_generator_policy (Literal["ignore", "warn", "error"]): Behavior when more than one + device-local RNG generator is created in the same process. """ # auto: choose std automatically if std == "auto": @@ -102,13 +123,22 @@ def get_plain_initialization( assert isinstance(std, float) initialization = NamedParameterwiseNormalInitialization( - mean=mean, std=std, parameter_name_regexes=parameter_name_regexes, seed=seed + mean=mean, + std=std, + parameter_name_regexes=parameter_name_regexes, + seed=seed, + multi_device_generator_policy=multi_device_generator_policy, ) return initialization @staticmethod def get_scaled_initialization( - mean: float, std: float, num_layers: int, parameter_name_regexes: RegexFilter, seed: int | None = None + mean: float, + std: float, + num_layers: int, + parameter_name_regexes: RegexFilter, + seed: int | None = None, + multi_device_generator_policy: Literal["ignore", "warn", "error"] = "warn", ) -> ModelInitializationIF: """Implementation of scaled weight initialization. As defined in https://arxiv.org/abs/2312.16903 @@ -119,6 +149,8 @@ def get_scaled_initialization( parameter_name_regexes (RegexFilter): List of parameter name regexes to which the initialization should be applied seed (Optional[int]): Random seed for initialization. Defaults to None. + multi_device_generator_policy (Literal["ignore", "warn", "error"]): Behavior when more than one + device-local RNG generator is created in the same process. Returns: WeightInitializationIF: Weight initialization object @@ -127,13 +159,20 @@ def get_scaled_initialization( scaled_std = std / math.sqrt(2 * num_layers) initialization = NamedParameterwiseNormalInitialization( - mean=mean, std=scaled_std, parameter_name_regexes=parameter_name_regexes, seed=seed + mean=mean, + std=scaled_std, + parameter_name_regexes=parameter_name_regexes, + seed=seed, + multi_device_generator_policy=multi_device_generator_policy, ) return initialization @staticmethod def get_scaled_embed_initialization( - mean: float, parameter_name_regexes: RegexFilter, seed: int | None = None + mean: float, + parameter_name_regexes: RegexFilter, + seed: int | None = None, + multi_device_generator_policy: Literal["ignore", "warn", "error"] = "warn", ) -> ModelInitializationIF: """Implementation of scaled weight initialization for embeddings, see https://arxiv.org/abs/2312.16903 We fix the standard deviation to sqrt(0.4). @@ -143,12 +182,18 @@ def get_scaled_embed_initialization( parameter_name_regexes (list[str], optional): List of parameter name regexes to which the initialization should be applied Defaults to None. seed (Optional[int]): Random seed for initialization. Defaults to None. + multi_device_generator_policy (Literal["ignore", "warn", "error"]): Behavior when more than one + device-local RNG generator is created in the same process. Returns: WeightInitializationIF: Weight initialization object """ std = math.sqrt(0.4) initialization = NamedParameterwiseNormalInitialization( - mean=mean, std=std, parameter_name_regexes=parameter_name_regexes, seed=seed + mean=mean, + std=std, + parameter_name_regexes=parameter_name_regexes, + seed=seed, + multi_device_generator_policy=multi_device_generator_policy, ) return initialization From 5172fc47e52d86ac1e6c87c865858e76428ccc15 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Fri, 8 May 2026 10:43:32 +0000 Subject: [PATCH 24/25] refactor: Use enum for multi_device_generator_policy Co-authored-by: Copilot --- .../composed_initialization.py | 15 ++++++----- .../initialization_routines.py | 27 ++++++++++++------- 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/src/modalities/nn/model_initialization/composed_initialization.py b/src/modalities/nn/model_initialization/composed_initialization.py index bd720d220..e8e3e7114 100644 --- a/src/modalities/nn/model_initialization/composed_initialization.py +++ b/src/modalities/nn/model_initialization/composed_initialization.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional +from typing import Optional import torch import torch.nn as nn @@ -8,7 +8,10 @@ from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticModelInitializationIFType from modalities.nn.model_initialization.initialization_if import ModelInitializationIF -from modalities.nn.model_initialization.initialization_routines import InitializationRoutines +from modalities.nn.model_initialization.initialization_routines import ( + InitializationRoutines, + MultiDeviceGeneratorPolicy, +) from modalities.nn.model_initialization.parameter_name_filters import ( NAMED_PARAMETER_INIT_GROUPS, SupportWeightInitModels, @@ -37,7 +40,7 @@ class ComposedModelInitializationConfig(BaseModel): hidden_dim: Optional[Annotated[int, Field(strict=True, gt=0)]] = None num_layers: Optional[Annotated[int, Field(strict=True, gt=0)]] = None seed: int | None = None - multi_device_generator_policy: Literal["ignore", "warn", "error"] = "warn" + multi_device_generator_policy: MultiDeviceGeneratorPolicy = MultiDeviceGeneratorPolicy.WARN device_mesh: Optional[PydanticDeviceMeshIFType] = None # avoid warning about protected namespace 'model_', see @@ -129,7 +132,7 @@ def get_composed_model_initializer( num_layers: int | None = None, device_mesh: Optional[DeviceMesh] = None, seed: int | None = None, - multi_device_generator_policy: Literal["ignore", "warn", "error"] = "warn", + multi_device_generator_policy: MultiDeviceGeneratorPolicy = MultiDeviceGeneratorPolicy.WARN, ) -> ModelInitializationIF: """This initialization allows to intialize a model with plain, scaled or scaled_embed initialization. Note that plain initialization is always performed in the beginning. In case of scaled_embed, @@ -149,9 +152,9 @@ def get_composed_model_initializer( parallelism is active, the effective seed is offset by PP rank to avoid identical stage-local initialization, so the same seed does not guarantee identical initialized weights across different PP topologies. - multi_device_generator_policy (Literal["ignore", "warn", "error"], optional): Behavior when + multi_device_generator_policy (MultiDeviceGeneratorPolicy, optional): Behavior when initialization creates per-device RNG generators for more than one device in the same process. - Defaults to "warn". + Defaults to MultiDeviceGeneratorPolicy.WARN. Returns: ModelInitializationIF: The Weight Initializer performing the initialization as specified. diff --git a/src/modalities/nn/model_initialization/initialization_routines.py b/src/modalities/nn/model_initialization/initialization_routines.py index e71dc1be2..1f785f562 100644 --- a/src/modalities/nn/model_initialization/initialization_routines.py +++ b/src/modalities/nn/model_initialization/initialization_routines.py @@ -1,7 +1,8 @@ import math import re import warnings -from typing import Annotated, Literal +from enum import Enum +from typing import Annotated import torch import torch.nn as nn @@ -11,6 +12,12 @@ from modalities.nn.model_initialization.parameter_name_filters import RegexFilter +class MultiDeviceGeneratorPolicy(str, Enum): + IGNORE = "ignore" + WARN = "warn" + ERROR = "error" + + class PlainInitializationConfig(BaseModel): mean: float std: Annotated[float, Field(strict=True, ge=0.0)] | str # can be float or "auto" @@ -47,7 +54,7 @@ def __init__( std: float, parameter_name_regexes: RegexFilter, seed: int | None = None, - multi_device_generator_policy: Literal["ignore", "warn", "error"] = "warn", + multi_device_generator_policy: MultiDeviceGeneratorPolicy = MultiDeviceGeneratorPolicy.WARN, ): self.mean = mean self.std = std @@ -65,9 +72,9 @@ def _get_generator(self, parameter: torch.Tensor) -> torch.Generator: "NamedParameterwiseNormalInitialization created generators for multiple devices in one process " f"(existing={list(self._generators.keys())}, new={device_key})." ) - if self.multi_device_generator_policy == "error": + if self.multi_device_generator_policy == MultiDeviceGeneratorPolicy.ERROR: raise RuntimeError(message) - if self.multi_device_generator_policy == "warn": + if self.multi_device_generator_policy == MultiDeviceGeneratorPolicy.WARN: warnings.warn(message, stacklevel=2) generator = torch.Generator(device=parameter.device) generator.manual_seed(self.seed) @@ -97,7 +104,7 @@ def get_plain_initialization( parameter_name_regexes: RegexFilter, hidden_dim: int | None = None, seed: int | None = None, - multi_device_generator_policy: Literal["ignore", "warn", "error"] = "warn", + multi_device_generator_policy: MultiDeviceGeneratorPolicy = MultiDeviceGeneratorPolicy.WARN, ) -> NamedParameterwiseNormalInitialization: """Initializes the weights of a model by sampling from a normal distribution. NOTE: This class supports the initialization of nn.Linear and nn.Embedding layers. @@ -111,7 +118,7 @@ def get_plain_initialization( parameter_name_regexes (list[str]): List of parameter name regexes to which the initialization should be applied seed (Optional[int]): Random seed for initialization. Defaults to None. - multi_device_generator_policy (Literal["ignore", "warn", "error"]): Behavior when more than one + multi_device_generator_policy (MultiDeviceGeneratorPolicy): Behavior when more than one device-local RNG generator is created in the same process. """ # auto: choose std automatically @@ -138,7 +145,7 @@ def get_scaled_initialization( num_layers: int, parameter_name_regexes: RegexFilter, seed: int | None = None, - multi_device_generator_policy: Literal["ignore", "warn", "error"] = "warn", + multi_device_generator_policy: MultiDeviceGeneratorPolicy = MultiDeviceGeneratorPolicy.WARN, ) -> ModelInitializationIF: """Implementation of scaled weight initialization. As defined in https://arxiv.org/abs/2312.16903 @@ -149,7 +156,7 @@ def get_scaled_initialization( parameter_name_regexes (RegexFilter): List of parameter name regexes to which the initialization should be applied seed (Optional[int]): Random seed for initialization. Defaults to None. - multi_device_generator_policy (Literal["ignore", "warn", "error"]): Behavior when more than one + multi_device_generator_policy (MultiDeviceGeneratorPolicy): Behavior when more than one device-local RNG generator is created in the same process. Returns: @@ -172,7 +179,7 @@ def get_scaled_embed_initialization( mean: float, parameter_name_regexes: RegexFilter, seed: int | None = None, - multi_device_generator_policy: Literal["ignore", "warn", "error"] = "warn", + multi_device_generator_policy: MultiDeviceGeneratorPolicy = MultiDeviceGeneratorPolicy.WARN, ) -> ModelInitializationIF: """Implementation of scaled weight initialization for embeddings, see https://arxiv.org/abs/2312.16903 We fix the standard deviation to sqrt(0.4). @@ -182,7 +189,7 @@ def get_scaled_embed_initialization( parameter_name_regexes (list[str], optional): List of parameter name regexes to which the initialization should be applied Defaults to None. seed (Optional[int]): Random seed for initialization. Defaults to None. - multi_device_generator_policy (Literal["ignore", "warn", "error"]): Behavior when more than one + multi_device_generator_policy (MultiDeviceGeneratorPolicy): Behavior when more than one device-local RNG generator is created in the same process. Returns: From 326823ec10d62785d1510161728c231e865c4046 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Fri, 8 May 2026 12:11:31 +0000 Subject: [PATCH 25/25] chore: Update model seed initialization Co-authored-by: Copilot --- .../training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml b/config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml index b4982044c..8e44e38b8 100644 --- a/config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml +++ b/config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml @@ -223,6 +223,10 @@ initialized_model: mean: 0.0 std: 0.02 num_layers: ${model_raw.config.n_layer} + seed: 42 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE scheduled_pipeline: component_key: pipeline @@ -315,7 +319,6 @@ model_raw: component_key: model variant_key: gpt2 config: - seed: 42 use_meta_device: true use_weight_tying: false sample_key: ${settings.referencing_keys.sample_key}