 from typing import Optional
 
+import torch
 import torch.nn as nn
 from pydantic import BaseModel, ConfigDict, Field, model_validator
+from torch.distributed.device_mesh import DeviceMesh
 from typing_extensions import Annotated
 
-from modalities.config.pydantic_if_types import PydanticModelInitializationIFType
+from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticModelInitializationIFType
 from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
-from modalities.nn.model_initialization.initialization_routines import InitializationRoutines
+from modalities.nn.model_initialization.initialization_routines import (
+    InitializationRoutines,
+    MultiDeviceGeneratorPolicy,
+)
 from modalities.nn.model_initialization.parameter_name_filters import (
     NAMED_PARAMETER_INIT_GROUPS,
     SupportWeightInitModels,
     WeightInitTypes,
 )
+from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank, has_parallelism_method
+from modalities.utils.logger_utils import get_logger
+
+logger = get_logger(__name__)
 
 
 class ModelInitializerWrapperConfig(BaseModel):
@@ -30,6 +39,9 @@ class ComposedModelInitializationConfig(BaseModel):
     std: Annotated[float, Field(strict=True, ge=0.0)] | str  # can be float or "auto"
     hidden_dim: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
     num_layers: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
+    seed: int | None = None
+    multi_device_generator_policy: MultiDeviceGeneratorPolicy = MultiDeviceGeneratorPolicy.WARN
+    device_mesh: Optional[PydanticDeviceMeshIFType] = None
 
     # avoid warning about protected namespace 'model_', see
     # https://docs.pydantic.dev/2.7/api/config/#pydantic.config.ConfigDict.protected_namespaces
@@ -87,6 +99,24 @@ def initialize_in_place(self, model: nn.Module):
 
 
 class ComposedInitializationRoutines:
+    @staticmethod
+    def _warn_pp_topology_dependent_seed(device_mesh: Optional[DeviceMesh], seed: Optional[int]) -> None:
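+        # Advisory helper: no-op unless a seed is set and pipeline parallelism is active.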
+        if seed is None or not has_parallelism_method(
+            device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP
+        ):
+            return
+
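+        # Warn only once, from global rank 0, to avoid duplicated log output across processes.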
+        if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0:
+            return
+
+        logger.warning(
+            "Seeded weight initialization is topology-dependent when pipeline parallelism is active. "
+            "Modalities offsets the initialization seed by PP rank to avoid identical stage-local weights, "
+            "so the same seed can produce different initialized weights for different PP configurations. "
+            "For topology-independent reproducibility, create and reuse a distributed checkpoint directly "
+            "after weight initialization."
+        )
+
     @staticmethod
     def get_model_initializer_wrapper(model_initializers: list[ModelInitializationIF]) -> ModelInitializationIF:
         initializer_wrapper = ModelInitializerWrapper(model_initializers)
@@ -98,8 +128,11 @@ def get_composed_model_initializer(
         weight_init_type: WeightInitTypes,
         mean: float,
         std: float | str,
-        hidden_dim: Optional[int] = None,
-        num_layers: int = None,
+        hidden_dim: int | None = None,
+        num_layers: int | None = None,
+        device_mesh: Optional[DeviceMesh] = None,
+        seed: int | None = None,
+        multi_device_generator_policy: MultiDeviceGeneratorPolicy = MultiDeviceGeneratorPolicy.WARN,
     ) -> ModelInitializationIF:
104137 """This initialization allows to intialize a model with plain, scaled or scaled_embed initialization.
105138 Note that plain initialization is always performed in the beginning. In case of scaled_embed,
@@ -114,36 +147,64 @@ def get_composed_model_initializer(
                 Defaults to None.
             num_layers (int, optional): Number of layers in the model (required for scaled and scaled_embed only).
                 Defaults to None.
+            device_mesh (Optional[DeviceMesh], optional): Device mesh used for parallelization. Defaults to None.
+            seed (Optional[int], optional): Seed for random initialization. Defaults to None. When pipeline
+                parallelism is active, the effective seed is offset by PP rank to avoid identical stage-local
+                initialization, so the same seed does not guarantee identical initialized weights across different
+                PP topologies.
+            multi_device_generator_policy (MultiDeviceGeneratorPolicy, optional): Behavior when
+                initialization creates per-device RNG generators for more than one device in the same process.
+                Defaults to MultiDeviceGeneratorPolicy.WARN.
 
         Returns:
             ModelInitializationIF: The Weight Initializer performing the initialization as specified.
         """
+        ComposedInitializationRoutines._warn_pp_topology_dependent_seed(device_mesh=device_mesh, seed=seed)
+
+        # Set a different random seed for each PP rank to ensure diversity
+        if seed is not None and has_parallelism_method(
+            device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP
+        ):
+            assert device_mesh is not None
+            seed += get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP)
+
         model_initializers = []
 
         # plain
         plain_parameter_name_regexes = NAMED_PARAMETER_INIT_GROUPS[model_type][WeightInitTypes.PLAIN]
         plain_init = InitializationRoutines.get_plain_initialization(
-            mean=mean, std=std, hidden_dim=hidden_dim, parameter_name_regexes=plain_parameter_name_regexes
+            mean=mean,
+            std=std,
+            hidden_dim=hidden_dim,
+            parameter_name_regexes=plain_parameter_name_regexes,
+            seed=seed,
+            multi_device_generator_policy=multi_device_generator_policy,
         )
         working_std = plain_init.std
         model_initializers.append(plain_init)
 
         if weight_init_type in [WeightInitTypes.SCALED, WeightInitTypes.SCALED_EMBED]:
             # scaled
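+            # num_layers is required for scaled initialization (see docstring above); fail fast if missing.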
+            assert num_layers is not None
             scaled_parameter_name_regexes = NAMED_PARAMETER_INIT_GROUPS[model_type][WeightInitTypes.SCALED]
             scaled_init = InitializationRoutines.get_scaled_initialization(
                 mean=mean,
                 std=working_std,
                 num_layers=num_layers,
                 parameter_name_regexes=scaled_parameter_name_regexes,
+                seed=seed,
+                multi_device_generator_policy=multi_device_generator_policy,
             )
             model_initializers.append(scaled_init)
 
         if weight_init_type == WeightInitTypes.SCALED_EMBED:
             # scaled embed
             scaled_embed_parameter_name_regexes = NAMED_PARAMETER_INIT_GROUPS[model_type][WeightInitTypes.SCALED_EMBED]
             scaled_embed_init = InitializationRoutines.get_scaled_embed_initialization(
-                mean=mean, parameter_name_regexes=scaled_embed_parameter_name_regexes
+                mean=mean,
+                parameter_name_regexes=scaled_embed_parameter_name_regexes,
+                seed=seed,
+                multi_device_generator_policy=multi_device_generator_policy,
             )
             model_initializers.append(scaled_embed_init)
 
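For orientation, a minimal usage sketch of the extended initializer API follows, as if written in the module shown above. It is an illustration only: `SupportWeightInitModels.GPT2`, the pre-built `device_mesh` and `model` objects, and the `initialize_in_place` call are assumptions inferred from names visible in this diff, not code from this commit.

# Hypothetical usage sketch; see the caveats above. Assumes it runs next to
# ComposedInitializationRoutines, with `device_mesh` (a DeviceMesh with a PP
# dimension) and `model` (an nn.Module) created during the usual distributed setup.
from modalities.nn.model_initialization.initialization_routines import MultiDeviceGeneratorPolicy
from modalities.nn.model_initialization.parameter_name_filters import SupportWeightInitModels, WeightInitTypes

initializer = ComposedInitializationRoutines.get_composed_model_initializer(
    model_type=SupportWeightInitModels.GPT2,  # assumed enum member
    weight_init_type=WeightInitTypes.SCALED,  # plain init runs first, then scaled
    mean=0.0,
    std="auto",  # per the config above, std may be a float or "auto"
    hidden_dim=768,
    num_layers=12,  # required for scaled / scaled_embed
    device_mesh=device_mesh,
    seed=42,  # with PP active, each PP rank effectively uses 42 + its PP rank
    multi_device_generator_policy=MultiDeviceGeneratorPolicy.WARN,
)
initializer.initialize_in_place(model)  # assumed interface method, per the initialize_in_place hunk header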