Skip to content

Commit 5f5e616

Browse files
committed
feat: implemented Llama3-like initialization for GPT2 models
1 parent e97578d commit 5f5e616

2 files changed

Lines changed: 113 additions & 0 deletions

File tree

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import math
import re
from collections.abc import Callable
from functools import partial
from typing import Annotated

import torch.nn as nn
from pydantic import BaseModel, Field

from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
from modalities.utils.logger_utils import get_logger
11+
12+
logger = get_logger(name="llama3 initialization")
13+
14+
15+
class Llama3InitializerConfig(BaseModel):
    """Validated constructor arguments for ``Llama3Initializer``.

    Both fields are strict integers (no coercion from other types) and must
    be greater than zero.
    """

    # Number of transformer blocks; scales the std of selected projections.
    num_layers: Annotated[int, Field(gt=0, strict=True)]
    # Embedding width; determines the std used for the lm head weights.
    n_embd: Annotated[int, Field(gt=0, strict=True)]
18+
19+
20+
class Llama3Initializer(ModelInitializationIF):
    """
    Follows weight initialization distributions and parameterization for Llama3 as described in TorchTitan.

    The initializer holds a mapping from regular expressions over fully qualified
    parameter names to in-place initialization callables. Every model parameter
    must match at most one expression; parameters that match none are left
    untouched and reported with a warning.
    """

    def __init__(self, num_layers: int, n_embd: int) -> None:
        """Build the regex -> initializer table.

        Args:
            num_layers: Number of transformer blocks. Selected projections use
                a std of ``0.02 / sqrt(2 * num_layers)``.
            n_embd: Embedding dimension. The lm head uses a std of
                ``1 / sqrt(n_embd)`` truncated at three standard deviations.
        """
        super().__init__()

        sqrt_n_embd = math.sqrt(n_embd)

        # NOTE: the bounds ``a``/``b`` of ``trunc_normal_`` are absolute values,
        # not multiples of ``std``. With std around 0.02 the [-2, 2] cutoff is
        # therefore far out in the tails.
        unit_normal = partial(nn.init.normal_, mean=0.0, std=1)
        base_trunc_normal = partial(
            nn.init.trunc_normal_,
            mean=0.0,
            std=0.02,
            a=-2,
            b=2,
        )
        depth_scaled_trunc_normal = partial(
            nn.init.trunc_normal_,
            mean=0.0,
            std=0.02 / math.sqrt(2 * num_layers),
            a=-2,
            b=2,
        )
        lm_head_trunc_normal = partial(
            nn.init.trunc_normal_,
            mean=0.0,
            std=1 / sqrt_n_embd,
            a=-3 / sqrt_n_embd,
            b=3 / sqrt_n_embd,
        )

        # NOTE(review): the MLP patterns use ``\w+`` for the layer index while
        # the attention patterns use ``\d+``. Kept as-is so the set of matched
        # parameter names does not change - confirm whether ``\d+`` is intended.
        self.regex_to_init: dict[str, Callable[..., object]] = {
            # embedding weights
            r"transformer\.wte\.weight": unit_normal,
            r"transformer\.wpe\.weight": unit_normal,
            # lm head weights
            r"transformer\.lm_head\.weight": lm_head_trunc_normal,
            # qkv projections
            r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.weight": base_trunc_normal,
            r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.bias": base_trunc_normal,
            # final attention projection in attention block
            r"transformer\.h\.\d+\.attn\.c_proj\.weight": depth_scaled_trunc_normal,
            r"transformer\.h\.\d+\.attn\.c_proj\.bias": depth_scaled_trunc_normal,
            # SwiGLU
            r"transformer\.h\.\w+\.mlp\.(W)\.weight": base_trunc_normal,
            r"transformer\.h\.\w+\.mlp\.(W)\.bias": nn.init.zeros_,
            r"transformer\.h\.\w+\.mlp\.(V|W_2)\.weight": depth_scaled_trunc_normal,
            r"transformer\.h\.\w+\.mlp\.(V|W_2)\.bias": nn.init.zeros_,
        }

    def initialize_in_place(self, model: nn.Module):
        """Initialize the parameters of ``model`` in place using the regex table."""
        self._init_by_fqn_regex(model, self.regex_to_init)

    @staticmethod
    def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, Callable[..., object]]) -> None:
        """Apply the initializer whose regex fully matches each parameter name.

        Args:
            model: Module whose ``named_parameters()`` are initialized in place.
            regex_to_init: Maps a regular expression (matched with
                ``re.fullmatch`` against the parameter name) to a callable that
                takes the parameter and initializes it in place.

        Raises:
            ValueError: If a parameter name matches more than one expression.
                The offending parameter is not modified in that case.
        """
        for parameter_name, p in model.named_parameters():
            matching_init_fns = [
                init_fn
                for weight_regex, init_fn in regex_to_init.items()
                if re.fullmatch(weight_regex, parameter_name)
            ]

            # Detect ambiguity before touching the parameter, so a failing call
            # never leaves it overwritten by several competing initializers.
            if len(matching_init_fns) > 1:
                raise ValueError(
                    f"Parameter {parameter_name} matched multiple regexes for initialization, which is not allowed"
                )
            if not matching_init_fns:
                logger.warning(f"Parameter {parameter_name} did not match any regex for initialization")
                continue

            matching_init_fns[0](p)

src/modalities/registry/components.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
)
9393
from modalities.models.gpt2.collator import GPT2LLMCollateFn
9494
from modalities.models.gpt2.gpt2_model import GPT2LLMConfig
95+
from modalities.models.gpt2.llama3_like_initialization import Llama3Initializer, Llama3InitializerConfig
9596
from modalities.models.huggingface.huggingface_model import HuggingFacePretrainedModel, HuggingFacePretrainedModelConfig
9697
from modalities.models.model_factory import GPT2ModelFactory, ModelFactory
9798
from modalities.models.parallelism.pipeline_parallelism import ComponentSelectorFromPipeline, PipelineFactory
@@ -240,6 +241,12 @@ class ComponentEntity:
240241
ComposedInitializationRoutines.get_composed_model_initializer,
241242
ComposedModelInitializationConfig,
242243
),
244+
ComponentEntity(
245+
"model_initialization",
246+
"llama3_like",
247+
Llama3Initializer,
248+
Llama3InitializerConfig,
249+
),
243250
# losses
244251
ComponentEntity("loss", "clm_cross_entropy_loss", CLMCrossEntropyLoss, CLMCrossEntropyLossConfig),
245252
# optimizers

0 commit comments

Comments
 (0)