
Commit 4705675

Merge pull request #426 from Modalities/seed
fix: Diverse model seeding across PP ranks
2 parents: b856127 + 326823e

23 files changed: 604 additions and 43 deletions

config_files/training/config_lorem_ipsum_long_fsdp2.yaml

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@ settings:
   checkpoint_saving_path: data/checkpoints
   train_dataset_path: ./data/lorem_ipsum_long.pbin
   test_dataset_path: ./data/lorem_ipsum.pbin
+  experiments_root_path: ${modalities_env:experiments_root_path}
   intervals:
     training_log_interval_in_steps: 1
     checkpointing_interval_in_steps: 32

@@ -221,6 +222,7 @@ initialized_model:
     mean: 0.0
     std: 0.02
     num_layers: ${model_raw.config.n_layer}
+    multi_device_generator_policy: error

 fsdp_model:
   component_key: model

config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml

Lines changed: 4 additions & 1 deletion
@@ -223,6 +223,10 @@ initialized_model:
     mean: 0.0
     std: 0.02
     num_layers: ${model_raw.config.n_layer}
+    seed: 42
+    device_mesh:
+      instance_key: device_mesh
+      pass_type: BY_REFERENCE

 scheduled_pipeline:
   component_key: pipeline

@@ -315,7 +319,6 @@ model_raw:
   component_key: model
   variant_key: gpt2
   config:
-    seed: 42
     use_meta_device: true
     use_weight_tying: false
     sample_key: ${settings.referencing_keys.sample_key}

docs/components/components.md

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,8 @@
 |---------------|--------------------|----------------|---------------|---------------------|-------------|
 | model_initialization | composed | [ComposedInitializationRoutines.get_composed_model_initializer](../../src/modalities/nn/model_initialization/composed_initialization.py)| [ComposedModelInitializationConfig](../../src/modalities/nn/model_initialization/composed_initialization.py) | [ModelInitializationIF](../../src/modalities/nn/model_initialization/initialization_if.py) | Component for initializing model weights in place |

+The composed initializer supports seeded weight initialization for reproducibility within a fixed topology. When pipeline parallelism is active, Modalities offsets the initialization seed by pipeline stage rank to avoid identical stage-local weights. As a result, the same seed can produce different initialized weights for different pipeline-parallel topologies. For topology-independent reproducibility, create and reuse a distributed checkpoint directly after weight initialization.
+
 ## Losses

 |Component type | Component Version | Implementation | Configuration | Component Interface | Description |
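
To make the documented offsetting concrete, here is a minimal standalone sketch of the idea of shifting a base seed by the pipeline-parallel stage rank. It is an illustration, not the Modalities implementation; `make_stage_generator`, `base_seed`, and `pp_rank` are hypothetical names chosen for this example.

```python
import torch


def make_stage_generator(base_seed: int, pp_rank: int) -> torch.Generator:
    # Each pipeline stage seeds its own RNG with base_seed + stage rank, so
    # stage-local parameters are drawn from different random streams.
    generator = torch.Generator()
    generator.manual_seed(base_seed + pp_rank)
    return generator


# Two stages sharing base seed 42 draw different initial weights:
g0 = make_stage_generator(base_seed=42, pp_rank=0)
g1 = make_stage_generator(base_seed=42, pp_rank=1)
w_stage0 = torch.empty(4, 4).normal_(mean=0.0, std=0.02, generator=g0)
w_stage1 = torch.empty(4, 4).normal_(mean=0.0, std=0.02, generator=g1)
assert not torch.equal(w_stage0, w_stage1)
```

The same reasoning explains why a given seed is only reproducible within a fixed topology: changing the number of pipeline stages changes which offsets are used.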

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -124,6 +124,7 @@ line-length = 120

 [tool.pytest.ini_options]
 addopts = "--cov=src --cov-report term --cov-report html"
+#addopts = "-ra" # Enable this instead of line above for reliable VS Code test debugging (without coverage)

 [tool.coverage.run]
 branch = true

src/modalities/config/config.py

Lines changed: 2 additions & 2 deletions
@@ -7,8 +7,8 @@
 from omegaconf import OmegaConf, Resolver
 from pydantic import BaseModel, ConfigDict, Field, FilePath, PositiveInt, field_validator, model_validator
 from torch.distributed.fsdp import ShardingStrategy
-from transformers import GPT2TokenizerFast
-from transformers.models.llama.tokenization_llama_fast import LlamaTokenizerFast
+from transformers import GPT2Tokenizer as GPT2TokenizerFast
+from transformers import LlamaTokenizer as LlamaTokenizerFast
 from typing_extensions import deprecated

 from modalities.config.lookup_enum import LookupEnum

src/modalities/conversion/gpt2/modeling_gpt2.py

Lines changed: 8 additions & 1 deletion
@@ -40,7 +40,14 @@
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from transformers.processing_utils import Unpack
 from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
-from transformers.utils.generic import check_model_inputs
+
+try:
+    from transformers.utils.generic import check_model_inputs
+except ImportError:
+
+    def check_model_inputs(func: Callable) -> Callable:
+        return func
+

 from modalities.conversion.gpt2.configuration_gpt2 import GPT2Config

src/modalities/models/gpt2/gpt2_model.py

Lines changed: 1 addition & 5 deletions
@@ -342,7 +342,6 @@ class GPT2LLMConfig(BaseModel):
         ffn_norm_config (LayerNormWrapperConfig): Config for normalization of the feed-forward network.
         lm_head_norm_config (LayerNormWrapperConfig): Config for normalization of the language model head.
         use_weight_tying (bool): Whether to use weight tying.
-        seed: Optional[int] = None: The random seed for reproducibility.
         enforce_swiglu_hidden_dim_multiple_of (int): If specified, enforces the hidden dimension
             in the SwiGLU layer to be a multiple of this value. Note that this is only relevant if the
             activation_type is SwiGLU. Defaults to 256.

@@ -370,7 +369,6 @@ class GPT2LLMConfig(BaseModel):
     ffn_norm_config: LayerNormWrapperConfig
     lm_head_norm_config: LayerNormWrapperConfig
     use_weight_tying: bool
-    seed: Optional[int] = None
     enforce_swiglu_hidden_dim_multiple_of: int = 256

     @model_validator(mode="after")

@@ -837,7 +835,6 @@ def __init__(
         ffn_norm_config: LayerNormWrapperConfig,
         lm_head_norm_config: LayerNormWrapperConfig,
         use_weight_tying: bool,
-        seed: Optional[int] = None,
         enforce_swiglu_hidden_dim_multiple_of: int = 256,
     ):
         """

@@ -862,7 +859,6 @@ def __init__(
             attention_norm_config (LayerNormWrapperConfig): Config for the attention normalization module.
             ffn_norm_config (LayerNormWrapperConfig): Config for the feed-forward network normalization module.
             lm_head_norm_config (LayerNormWrapperConfig): Config for the language model head normalization module.
-            seed (int, optional): The random seed. Defaults to None.
             use_weight_tying (bool): Whether to use weight tying.
             enforce_swiglu_hidden_dim_multiple_of (int): Enforces
                 the hidden dimension in the SwiGLU layer to be a multiple of this value.

@@ -873,7 +869,7 @@ def __init__(
             "embedding": [".wte", ".wpe"],
             "layernorm": [".attention_norm", ".ffn_norm", ".lm_head_norm"],
         }
-        super().__init__(weight_decay_groups=weight_decay_groups, seed=seed)
+        super().__init__(weight_decay_groups=weight_decay_groups)
         self.sample_key = sample_key
         self.prediction_key = prediction_key
         self.sequence_length = sequence_length

src/modalities/models/model.py

Lines changed: 1 addition & 4 deletions
@@ -26,16 +26,13 @@ class ActivationType(str, Enum):
 class NNModel(nn.Module):
     """NNModel class to define a base model."""

-    def __init__(self, seed: int = None, weight_decay_groups: Optional[WeightDecayGroups] = None):
+    def __init__(self, weight_decay_groups: Optional[WeightDecayGroups] = None):
         """
         Initializes an NNModel object.

         Args:
-            seed (int, optional): The seed value for random number generation. Defaults to None.
             weight_decay_groups (Optional[WeightDecayGroups], optional): The weight decay groups. Defaults to None.
         """
-        if seed is not None:
-            torch.manual_seed(seed)
         self._weight_decay_groups = weight_decay_groups if weight_decay_groups is not None else {}
         super(NNModel, self).__init__()

src/modalities/models/model_factory.py

Lines changed: 0 additions & 2 deletions
@@ -615,7 +615,6 @@ def get_gpt2_model(
         lm_head_norm_config: LayerNormWrapperConfig,
         use_weight_tying: bool,
         use_meta_device: Optional[bool] = False,
-        seed: Optional[int] = None,
         enforce_swiglu_hidden_dim_multiple_of: int = 256,
     ) -> GPT2LLM:
         config = dict(

@@ -637,7 +636,6 @@ def get_gpt2_model(
             attention_norm_config=attention_norm_config,
             ffn_norm_config=ffn_norm_config,
             lm_head_norm_config=lm_head_norm_config,
-            seed=seed,
             use_weight_tying=use_weight_tying,
             enforce_swiglu_hidden_dim_multiple_of=enforce_swiglu_hidden_dim_multiple_of,
         )

src/modalities/nn/model_initialization/composed_initialization.py

Lines changed: 67 additions & 6 deletions
@@ -1,17 +1,26 @@
 from typing import Optional

+import torch
 import torch.nn as nn
 from pydantic import BaseModel, ConfigDict, Field, model_validator
+from torch.distributed.device_mesh import DeviceMesh
 from typing_extensions import Annotated

-from modalities.config.pydantic_if_types import PydanticModelInitializationIFType
+from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticModelInitializationIFType
 from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
-from modalities.nn.model_initialization.initialization_routines import InitializationRoutines
+from modalities.nn.model_initialization.initialization_routines import (
+    InitializationRoutines,
+    MultiDeviceGeneratorPolicy,
+)
 from modalities.nn.model_initialization.parameter_name_filters import (
     NAMED_PARAMETER_INIT_GROUPS,
     SupportWeightInitModels,
     WeightInitTypes,
 )
+from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank, has_parallelism_method
+from modalities.utils.logger_utils import get_logger
+
+logger = get_logger(__name__)


 class ModelInitializerWrapperConfig(BaseModel):

@@ -30,6 +39,9 @@ class ComposedModelInitializationConfig(BaseModel):
     std: Annotated[float, Field(strict=True, ge=0.0)] | str  # can be float or "auto"
     hidden_dim: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
     num_layers: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
+    seed: int | None = None
+    multi_device_generator_policy: MultiDeviceGeneratorPolicy = MultiDeviceGeneratorPolicy.WARN
+    device_mesh: Optional[PydanticDeviceMeshIFType] = None

     # avoid warning about protected namespace 'model_', see
     # https://docs.pydantic.dev/2.7/api/config/#pydantic.config.ConfigDict.protected_namespaces

@@ -87,6 +99,24 @@ def initialize_in_place(self, model: nn.Module):


 class ComposedInitializationRoutines:
+    @staticmethod
+    def _warn_pp_topology_dependent_seed(device_mesh: Optional[DeviceMesh], seed: Optional[int]) -> None:
+        if seed is None or not has_parallelism_method(
+            device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP
+        ):
+            return
+
+        if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0:
+            return
+
+        logger.warning(
+            "Seeded weight initialization is topology-dependent when pipeline parallelism is active. "
+            "Modalities offsets the initialization seed by PP rank to avoid identical stage-local weights, "
+            "so the same seed can produce different initialized weights for different PP configurations. "
+            "For topology-independent reproducibility, create and reuse a distributed checkpoint directly "
+            "after weight initialization."
+        )
+
     @staticmethod
     def get_model_initializer_wrapper(model_initializers: list[ModelInitializationIF]) -> ModelInitializationIF:
         initializer_wrapper = ModelInitializerWrapper(model_initializers)

@@ -98,8 +128,11 @@ def get_composed_model_initializer(
         weight_init_type: WeightInitTypes,
         mean: float,
         std: float | str,
-        hidden_dim: Optional[int] = None,
-        num_layers: int = None,
+        hidden_dim: int | None = None,
+        num_layers: int | None = None,
+        device_mesh: Optional[DeviceMesh] = None,
+        seed: int | None = None,
+        multi_device_generator_policy: MultiDeviceGeneratorPolicy = MultiDeviceGeneratorPolicy.WARN,
     ) -> ModelInitializationIF:
         """This initialization allows to intialize a model with plain, scaled or scaled_embed initialization.
         Note that plain initialization is always performed in the beginning. In case of scaled_embed,

@@ -114,36 +147,64 @@
                 Defaults to None.
             num_layers (int, optional): Number of layers in the model (required for scaled and scaled_embed only).
                 Defaults to None.
+            device_mesh (Optional[DeviceMesh], optional): Device mesh used for parallelization.
+            seed (Optional[int], optional): Seed for random initialization. Defaults to None. When pipeline
+                parallelism is active, the effective seed is offset by PP rank to avoid identical stage-local
+                initialization, so the same seed does not guarantee identical initialized weights across different
+                PP topologies.
+            multi_device_generator_policy (MultiDeviceGeneratorPolicy, optional): Behavior when
+                initialization creates per-device RNG generators for more than one device in the same process.
+                Defaults to MultiDeviceGeneratorPolicy.WARN.

         Returns:
             ModelInitializationIF: The Weight Initializer performing the initialization as specified.
         """
+        ComposedInitializationRoutines._warn_pp_topology_dependent_seed(device_mesh=device_mesh, seed=seed)
+
+        # Set different random seed for each PP rank to ensure diversity
+        if seed is not None and has_parallelism_method(
+            device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP
+        ):
+            assert device_mesh is not None
+            seed += get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP)
+
         model_initializers = []

         # plain
         plain_parameter_name_regexes = NAMED_PARAMETER_INIT_GROUPS[model_type][WeightInitTypes.PLAIN]
         plain_init = InitializationRoutines.get_plain_initialization(
-            mean=mean, std=std, hidden_dim=hidden_dim, parameter_name_regexes=plain_parameter_name_regexes
+            mean=mean,
+            std=std,
+            hidden_dim=hidden_dim,
+            parameter_name_regexes=plain_parameter_name_regexes,
+            seed=seed,
+            multi_device_generator_policy=multi_device_generator_policy,
         )
         working_std = plain_init.std
         model_initializers.append(plain_init)

         if weight_init_type in [WeightInitTypes.SCALED, WeightInitTypes.SCALED_EMBED]:
             # scaled
+            assert num_layers is not None
             scaled_parameter_name_regexes = NAMED_PARAMETER_INIT_GROUPS[model_type][WeightInitTypes.SCALED]
             scaled_init = InitializationRoutines.get_scaled_initialization(
                 mean=mean,
                 std=working_std,
                 num_layers=num_layers,
                 parameter_name_regexes=scaled_parameter_name_regexes,
+                seed=seed,
+                multi_device_generator_policy=multi_device_generator_policy,
             )
             model_initializers.append(scaled_init)

         if weight_init_type == WeightInitTypes.SCALED_EMBED:
             # scaled embed
             scaled_embed_parameter_name_regexes = NAMED_PARAMETER_INIT_GROUPS[model_type][WeightInitTypes.SCALED_EMBED]
             scaled_embed_init = InitializationRoutines.get_scaled_embed_initialization(
-                mean=mean, parameter_name_regexes=scaled_embed_parameter_name_regexes
+                mean=mean,
+                parameter_name_regexes=scaled_embed_parameter_name_regexes,
+                seed=seed,
+                multi_device_generator_policy=multi_device_generator_policy,
            )
             model_initializers.append(scaled_embed_init)