fix: Diverse model seeding across PP ranks #426
@@ -1,7 +1,8 @@
 import math
 import re
-from typing import Annotated, Optional
+from typing import Annotated

 import torch
 import torch.nn as nn
 from pydantic import BaseModel, Field, model_validator
@@ -13,7 +14,7 @@ class PlainInitializationConfig(BaseModel):
     mean: float
     std: Annotated[float, Field(strict=True, ge=0.0)] | str  # can be float or "auto"
     parameter_name_regexes: list[str]  # here we filter for the parameter names, e.g., "c_proj.weight"
-    hidden_dim: Optional[int] = None
+    hidden_dim: int | None = None

     @model_validator(mode="after")
     def check_std_and_hidden_dim(self):
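The body of check_std_and_hidden_dim is not part of this hunk. As a rough sketch of the constraint it plausibly enforces (std="auto" is only meaningful when hidden_dim is set), the validator below is an assumption for illustration, not the repository code:

```python
# Sketch only: the validator body is assumed, based on the constraint implied by
# get_plain_initialization further down (std="auto" requires hidden_dim).
from typing import Annotated

from pydantic import BaseModel, Field, model_validator


class PlainInitializationConfigSketch(BaseModel):
    mean: float
    std: Annotated[float, Field(strict=True, ge=0.0)] | str  # can be a float or "auto"
    parameter_name_regexes: list[str]
    hidden_dim: int | None = None

    @model_validator(mode="after")
    def check_std_and_hidden_dim(self):
        # "auto" derives std from hidden_dim, so hidden_dim must then be provided
        if self.std == "auto" and self.hidden_dim is None:
            raise ValueError("std='auto' requires hidden_dim to be set")
        return self
```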
@@ -39,21 +40,32 @@ class ScaledEmbedInitializationConfig(BaseModel):


 class NamedParameterwiseNormalInitialization(ModelInitializationIF):
-    def __init__(self, mean: float, std: float, parameter_name_regexes: RegexFilter):
+    def __init__(self, mean: float, std: float, parameter_name_regexes: RegexFilter, seed: int | None = None):
         self.mean = mean
         self.std = std
         self.parameter_name_regexes = parameter_name_regexes
+        self.seed = torch.initial_seed() if seed is None else seed
+        self._generators: dict[str, torch.Generator] = {}
+
+    def _get_generator(self, parameter: torch.Tensor) -> torch.Generator:
Member:
A few things are not clear to me. I'm not sure what the best way to solve this ... it also seems to me that the PyTorch API regarding Generators is kinda limited.

Collaborator (Author):
Since we start a single process for each rank via torchrun, this shouldn't happen, right? Or am I missing something?
+        device_key = str(parameter.device)
+        generator = self._generators.get(device_key)
+        if generator is None:
+            generator = torch.Generator(device=parameter.device)
+            generator.manual_seed(self.seed)
+            self._generators[device_key] = generator
+        return generator

     def initialize_in_place(self, model: nn.Module):
         weight_regexes = self.parameter_name_regexes.weights
-        bias_regexes = self.parameter_name_regexes.biases
+        bias_regexes = self.parameter_name_regexes.biases or []
         for parameter_name, p in model.named_parameters():
             parameter_name = parameter_name.replace(
                 "_orig_mod.", ""
             )  # remove FQN modification from torch.compile if present
             for weight_regex in weight_regexes:
                 if re.fullmatch(weight_regex, parameter_name):
-                    nn.init.normal_(p, mean=self.mean, std=self.std)
+                    nn.init.normal_(p, mean=self.mean, std=self.std, generator=self._get_generator(p))
             for bias_regex in bias_regexes:
                 if re.fullmatch(bias_regex, parameter_name):
                     nn.init.zeros_(p)
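The review discussion above concerns how the seeded generators behave across ranks and devices. Below is a hedged sketch, not taken from this PR, of how a per-rank seed could be derived so that each pipeline-parallel rank draws its initial weights from a distinct but reproducible stream; the helper name and the base_seed + rank scheme are assumptions:

```python
# Sketch only: per-rank seed derivation for diverse-but-reproducible initialization.
import torch
import torch.distributed as dist


def derive_pp_rank_seed(base_seed: int) -> int:
    # torchrun launches one process per rank, so a single per-process offset suffices.
    rank = dist.get_rank() if dist.is_initialized() else 0
    return base_seed + rank


# The derived seed would be passed as the new `seed` argument; internally one
# torch.Generator per device is seeded once and reused for every matching parameter.
generator = torch.Generator()
generator.manual_seed(derive_pp_rank_seed(1234))
weight = torch.empty(8, 8)
torch.nn.init.normal_(weight, mean=0.0, std=0.02, generator=generator)
```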
@@ -62,7 +74,11 @@ def initialize_in_place(self, model: nn.Module):
 class InitializationRoutines:
     @staticmethod
     def get_plain_initialization(
-        mean: float, std: float | str, parameter_name_regexes: list[str], hidden_dim: Optional[int] = None
+        mean: float,
+        std: float | str,
+        parameter_name_regexes: RegexFilter,
+        hidden_dim: int | None = None,
+        seed: int | None = None,
     ) -> NamedParameterwiseNormalInitialization:
         """Initializes the weights of a model by sampling from a normal distribution.
         NOTE: This class supports the initialization of nn.Linear and nn.Embedding layers.
@@ -73,23 +89,26 @@ def get_plain_initialization(
             std (float): standard deviation of the normal distribution. If set to "auto", appropriate
                 value selected as per plain initialization described in https://arxiv.org/abs/2312.16903
             hidden_dim (Optional[int]): hidden dimension of the attention layer. Defaults to None.
             parameter_name_regexes (list[str]): List of parameter name regexes to which the initialization
                 should be applied
+            seed (Optional[int]): Random seed for initialization. Defaults to None.
         """

         # auto: choose std automatically
         if std == "auto":
             if hidden_dim is None:
                 raise ValueError("ERROR! weight_init.std = auto not implemented")
             # as per https://arxiv.org/abs/2312.16903
             std = math.sqrt(2 / (5 * hidden_dim))
         assert isinstance(std, float)

         initialization = NamedParameterwiseNormalInitialization(
-            mean=mean, std=std, parameter_name_regexes=parameter_name_regexes
+            mean=mean, std=std, parameter_name_regexes=parameter_name_regexes, seed=seed
         )
         return initialization

     @staticmethod
     def get_scaled_initialization(
-        mean: float, std: float, num_layers: int, parameter_name_regexes: list[str]
+        mean: float, std: float, num_layers: int, parameter_name_regexes: RegexFilter, seed: int | None = None
     ) -> ModelInitializationIF:
         """Implementation of scaled weight initialization. As defined in https://arxiv.org/abs/2312.16903
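A minimal usage sketch of the updated factory: the RegexFilter construction, the regex patterns, and `model` are assumptions, while the std="auto" resolution matches the code above (std = sqrt(2 / (5 * hidden_dim))):

```python
# Usage sketch with assumed RegexFilter fields and an assumed `model` (an nn.Module).
import math

hidden_dim = 768
expected_std = math.sqrt(2 / (5 * hidden_dim))  # ~0.0228 for hidden_dim=768

initialization = InitializationRoutines.get_plain_initialization(
    mean=0.0,
    std="auto",
    parameter_name_regexes=RegexFilter(weights=[r".*\.weight"], biases=[r".*\.bias"]),
    hidden_dim=hidden_dim,
    seed=42,
)
initialization.initialize_in_place(model)
```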
@@ -99,6 +118,7 @@ def get_scaled_initialization(
             num_layers (int): Number of layers in the model which we use to downscale std with
             parameter_name_regexes (list[str]): List of parameter name regexes to which the initialization
                 should be applied
+            seed (Optional[int]): Random seed for initialization. Defaults to None.

         Returns:
             WeightInitializationIF: Weight initialization object
@@ -107,25 +127,28 @@
         scaled_std = std / math.sqrt(2 * num_layers)

         initialization = NamedParameterwiseNormalInitialization(
-            mean=mean, std=scaled_std, parameter_name_regexes=parameter_name_regexes
+            mean=mean, std=scaled_std, parameter_name_regexes=parameter_name_regexes, seed=seed
         )
         return initialization

     @staticmethod
-    def get_scaled_embed_initialization(mean: float, parameter_name_regexes: list[str]) -> ModelInitializationIF:
+    def get_scaled_embed_initialization(
+        mean: float, parameter_name_regexes: RegexFilter, seed: int | None = None
+    ) -> ModelInitializationIF:
         """Implementation of scaled weight initialization for embeddings, see https://arxiv.org/abs/2312.16903
         We fix the standard deviation to sqrt(0.4).

         Args:
             mean (float): Mean of the normal distribution
             parameter_name_regexes (list[str], optional): List of parameter name regexes to which the initialization
                 should be applied. Defaults to None.
+            seed (Optional[int]): Random seed for initialization. Defaults to None.

         Returns:
             WeightInitializationIF: Weight initialization object
         """
         std = math.sqrt(0.4)
         initialization = NamedParameterwiseNormalInitialization(
-            mean=mean, std=std, parameter_name_regexes=parameter_name_regexes
+            mean=mean, std=std, parameter_name_regexes=parameter_name_regexes, seed=seed
         )
         return initialization
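For reference, a short sketch of the scaling applied by the two routines above; the numbers are illustrative and the RegexFilter construction is again an assumption:

```python
# Sketch of the std scaling used by get_scaled_initialization and get_scaled_embed_initialization.
import math

std, num_layers = 0.02, 24
scaled_std = std / math.sqrt(2 * num_layers)  # ~0.00289: residual-branch std from the paper
embed_std = math.sqrt(0.4)                    # ~0.632: fixed std for the embedding initialization

scaled_init = InitializationRoutines.get_scaled_initialization(
    mean=0.0,
    std=std,
    num_layers=num_layers,
    parameter_name_regexes=RegexFilter(weights=[r".*c_proj\.weight"]),
    seed=42,
)
```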
Was this removed in transformers?
If it is part of a legacy API, I think we should also remove it on our end.
What do you think @BlueCrescent? I think you added it, right?

The function was removed in transformers version 5.2. In our pyproject.toml we specify the requirement "transformers>=4.57.4,<5.0.0", so I used an unsupported transformers version here. Should we remove it just to be on the safe side?

Yes, I think we should tackle transformers 5.0.0+ support soon anyway.