Skip to content

Commit 5ff89f0

Browse files
authored
Ensure parameters are initialized correctly on the meta device (#1368)
It looks like `reset_parameters` isn't enough to ensure that newly initialized models using the `torch.device("meta")` semantics are identical to those created without it. This PR adds tests to ensure that the parameter means and standard deviations are identical, and uses

```python
if args.use_meta_device:
    model.to_empty(device=device)
    model.apply(model._init_weights)
```

to ensure the weights are correctly initialized following `to_empty()`.

---------

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 47bc53e commit 5ff89f0

17 files changed

Lines changed: 547 additions & 113 deletions

File tree

bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,9 +259,13 @@ class NVEsmPreTrainedModel(PreTrainedModel):
259259
"EsmEmbeddings",
260260
)
261261

262-
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
263262
def _init_weights(self, module: nn.Module):
264-
"""Initialize the weights.
263+
"""Initialize model weights.
264+
265+
This method ensures that models with randomly-initialized weights get the correct initial value distribution,
266+
which can be critical for training stability. We also call this method directly when using meta-device init, as
267+
the `to_empty` method does not initialize the weights. While the base Transformers model has a similar method,
268+
we need to extend it to handle TE-specific modules.
265269
266270
Args:
267271
module (nn.Module): The module to initialize the weights for.
@@ -282,9 +286,29 @@ def _init_weights(self, module: nn.Module):
282286
module.bias.data.zero_()
283287
module.weight.data.fill_(1.0)
284288
if isinstance(module, transformer_engine.pytorch.LayerNormLinear):
289+
if module.layer_norm_bias is not None:
290+
module.layer_norm_bias.data.zero_()
285291
module.layer_norm_weight.data.fill_(1.0)
286292
if module.layer_norm_bias is not None:
287293
module.layer_norm_bias.data.zero_()
294+
if isinstance(module, transformer_engine.pytorch.LayerNormMLP):
295+
if module.layer_norm_bias is not None:
296+
module.layer_norm_bias.data.zero_()
297+
module.layer_norm_weight.data.fill_(1.0)
298+
if hasattr(module, "fc1_weight") and module.fc1_weight is not None:
299+
module.fc1_weight.data.normal_(mean=0.0, std=self.config.initializer_range)
300+
if hasattr(module, "fc2_weight") and module.fc2_weight is not None:
301+
module.fc2_weight.data.normal_(mean=0.0, std=self.config.initializer_range)
302+
if hasattr(module, "fc1_bias") and module.fc1_bias is not None and module.fc1_bias.numel() > 0:
303+
module.fc1_bias.data.zero_()
304+
if hasattr(module, "fc2_bias") and module.fc2_bias is not None and module.fc2_bias.numel() > 0:
305+
module.fc2_bias.data.zero_()
306+
if isinstance(module, RotaryPositionEmbedding) and hasattr(module, "inv_freq"):
307+
# When we initialize the model with `to_empty`, the `inv_freq` attribute is not initialized, so we need to
308+
# re-initialize it here with the correct values.
309+
module.inv_freq = RotaryPositionEmbedding(
310+
self.config.hidden_size // self.config.num_attention_heads
311+
).inv_freq.to(module.inv_freq.device)
288312

289313
@classmethod
290314
def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""
17+
Test that parameter distributions are identical with and without meta device initialization.
18+
19+
These tests verify that when using meta device initialization (creating the model on meta device, then calling
20+
`to_empty` and `_init_weights`), the resulting parameter distributions (mean and std) match those from normal
21+
initialization. This is important because we previously observed differences in convergence between meta-device-init and
22+
non-meta-device-init training, which suggested that the initialization was not being applied correctly after `to_empty`.
23+
By explicitly calling `_init_weights` after `to_empty`, we ensure that parameters are properly initialized, leading to
24+
consistent training behavior regardless of whether meta device initialization is used.
25+
"""
26+
27+
import os
28+
import subprocess
29+
30+
import pytest
31+
import torch
32+
from torch.distributed.fsdp import fully_shard
33+
from torch.distributed.tensor import DTensor
34+
from transformers import AutoConfig, set_seed
35+
36+
from esm.modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM
37+
38+
39+
# Skip marker for tests that can only run when CUDA is usable and at least two devices are visible.
requires_multi_gpu = pytest.mark.skipif(
    not (torch.cuda.is_available() and torch.cuda.device_count() >= 2),
    reason="Test requires at least 2 GPUs",
)
43+
44+
45+
def test_meta_device_init():
    """Compare parameter statistics of a meta-device-initialized model against a normally initialized one.

    One model is built under `torch.device("meta")`, materialized on CUDA with `to_empty`, and initialized by
    applying `_init_weights` to every submodule. A second model is built the regular way and moved to CUDA. Both
    constructions are preceded by `set_seed(42)`. For every floating-point or complex entry of the state dict, the
    mean and standard deviation of the two tensors must agree within `atol=1e-3` / `rtol=1e-4`.
    """
    config = NVEsmConfig(**AutoConfig.from_pretrained("facebook/esm2_t6_8M_UR50D").to_dict())

    # Meta-device path: allocate without data, materialize on the GPU, then explicitly initialize the weights.
    set_seed(42)
    with torch.device("meta"):
        model_meta_init = NVEsmForMaskedLM(config)
    model_meta_init.to_empty(device="cuda")
    model_meta_init.apply(model_meta_init._init_weights)

    # Reference path: regular construction followed by a move to the GPU.
    set_seed(42)
    model_normal_init = NVEsmForMaskedLM(config)
    model_normal_init.to("cuda")

    # Only these dtypes support mean/std; everything else (e.g., Byte/uint8 tensors like _extra_state) is skipped.
    comparable_dtypes = (
        torch.float16,
        torch.float32,
        torch.float64,
        torch.bfloat16,
        torch.complex64,
        torch.complex128,
    )
    tolerances = {"atol": 1e-3, "rtol": 1e-4}

    reference_state = model_normal_init.state_dict()
    for key, meta_tensor in model_meta_init.state_dict().items():
        normal_tensor = reference_state[key]
        if meta_tensor.dtype not in comparable_dtypes:
            continue

        torch.testing.assert_close(
            normal_tensor.mean(),
            meta_tensor.mean(),
            msg=lambda detail: f"Mean mismatch for parameter {key}: {detail}",
            **tolerances,
        )
        torch.testing.assert_close(
            normal_tensor.std(),
            meta_tensor.std(),
            msg=lambda detail: f"Std mismatch for parameter {key}: {detail}",
            **tolerances,
        )
89+
90+
91+
@pytest.mark.parametrize("num_gpus", [1, pytest.param(2, marks=requires_multi_gpu)])
def test_meta_device_init_after_fully_shard(num_gpus: int):
    """Run this file as a `torchrun` script and fail if any worker exits with a non-zero status.

    The distributed comparison lives in the `__main__` section of this module; this test only launches it with the
    requested number of processes and reports the captured output when the launch fails.

    Args:
        num_gpus: Value passed to `torchrun --nproc_per_node`.
    """
    script_path = os.path.relpath(__file__)

    completed = subprocess.run(
        ["torchrun", f"--nproc_per_node={num_gpus}", script_path],
        capture_output=True,
        text=True,
        check=False,
        timeout=240,
    )

    if completed.returncode == 0:
        return

    # Surface the worker output so the failure can be diagnosed from the pytest report.
    print(f"STDOUT:\n{completed.stdout}")
    print(f"STDERR:\n{completed.stderr}")
    pytest.fail(f"Command failed with exit code {completed.returncode}")
112+
113+
114+
if __name__ == "__main__":
    # Worker entry point. `test_meta_device_init_after_fully_shard` launches this file with `torchrun`; each process
    # builds a meta-device-initialized model and a normally initialized model, wraps both with `fully_shard`, and
    # checks that the per-parameter mean and standard deviation agree.
    torch.distributed.init_process_group(backend="cuda:nccl")

    # Pick the GPU by node-local rank. `torchrun` exports LOCAL_RANK; the global rank is only a valid device index on
    # a single node, so it is kept purely as a fallback for launchers that do not set the variable.
    local_rank = int(os.environ.get("LOCAL_RANK", torch.distributed.get_rank()))
    torch.cuda.set_device(local_rank)

    config = NVEsmConfig(**AutoConfig.from_pretrained("facebook/esm2_t6_8M_UR50D").to_dict())

    set_seed(42)

    with torch.device("meta"):
        model_meta_init = NVEsmForMaskedLM(config)

    for layer in model_meta_init.esm.encoder.layers:
        fully_shard(layer)
    fully_shard(model_meta_init)

    # `to_empty` only allocates storage, so the weights must be initialized explicitly afterwards.
    model_meta_init.to_empty(device="cuda")
    model_meta_init.apply(model_meta_init._init_weights)

    set_seed(42)
    model_normal_init = NVEsmForMaskedLM(config)

    for layer in model_normal_init.esm.encoder.layers:
        fully_shard(layer)
    fully_shard(model_normal_init)

    state_dict_meta_init = model_meta_init.state_dict()
    state_dict_normal_init = model_normal_init.state_dict()

    # Dtypes for which mean/std are defined; hoisted so the tuple is not rebuilt for every parameter.
    comparable_dtypes = (
        torch.float16,
        torch.float32,
        torch.float64,
        torch.bfloat16,
        torch.complex64,
        torch.complex128,
    )

    for key in state_dict_meta_init:
        meta_tensor = state_dict_meta_init[key]
        normal_tensor = state_dict_normal_init[key]
        # Skip non-numeric tensors (e.g., Byte/uint8 tensors like _extra_state)
        if meta_tensor.dtype not in comparable_dtypes:
            continue

        # Gather sharded parameters once so both statistics are computed on the complete tensor. Every rank walks
        # the keys in the same order, so the collective calls stay aligned across processes.
        if isinstance(normal_tensor, DTensor) and isinstance(meta_tensor, DTensor):
            normal_tensor = normal_tensor.full_tensor()
            meta_tensor = meta_tensor.full_tensor()

        torch.testing.assert_close(
            normal_tensor.mean(),
            meta_tensor.mean(),
            atol=1e-3,
            rtol=1e-4,
            msg=lambda x: f"Mean mismatch for parameter {key}: {x}",
        )
        torch.testing.assert_close(
            normal_tensor.std(),
            meta_tensor.std(),
            atol=1e-3,
            rtol=1e-4,
            msg=lambda x: f"Std mismatch for parameter {key}: {x}",
        )

    # Tear the process group down on the success path so its resources are released before interpreter exit.
    # Deliberately not in a `finally`: on an assertion failure the rank should exit right away and let `torchrun`
    # terminate the remaining workers rather than risk blocking in teardown.
    torch.distributed.destroy_process_group()

bionemo-recipes/models/llama3/modeling_llama_te.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,22 @@ class NVLlamaConfig(LlamaConfig):
4848
class NVLlamaPreTrainedModel(PreTrainedModel):
4949
"""Base class for NVLlama models."""
5050

51-
config: NVLlamaConfig
51+
config_class = NVLlamaConfig
5252
base_model_prefix = "model"
5353
_no_split_modules = ("TransformerLayer",)
5454
_skip_keys_device_placement = ("past_key_values",)
5555

5656
def _init_weights(self, module):
57-
"""TE-specific weight initialization."""
57+
"""Initialize module weights.
58+
59+
This method ensures that models with randomly-initialized weights get the correct initial value distribution,
60+
which can be critical for training stability. We also call this method directly when using meta-device init, as
61+
the `to_empty` method does not initialize the weights. While the base Transformers model has a similar method,
62+
we need to extend it to handle TE-specific modules.
63+
64+
Args:
65+
module (nn.Module): The module to initialize the weights for.
66+
"""
5867
super()._init_weights(module)
5968

6069
# Copied from transformers.modeling_utils.PreTrainedModel._init_weights
@@ -75,10 +84,25 @@ def _init_weights(self, module):
7584
module.weight.data.fill_(1.0)
7685
if hasattr(module, "bias") and module.bias is not None:
7786
module.bias.data.zero_()
87+
if isinstance(module, transformer_engine.pytorch.RMSNorm):
88+
if hasattr(module, "weight") and module.weight is not None:
89+
module.weight.data.fill_(1.0)
7890
if isinstance(module, transformer_engine.pytorch.LayerNormLinear):
7991
module.layer_norm_weight.data.fill_(1.0)
8092
if module.layer_norm_bias is not None:
8193
module.layer_norm_bias.data.zero_()
94+
if isinstance(module, transformer_engine.pytorch.LayerNormMLP):
95+
module.layer_norm_weight.data.fill_(1.0)
96+
if hasattr(module, "fc1_weight") and module.fc1_weight is not None:
97+
module.fc1_weight.data.normal_(mean=0.0, std=std)
98+
if hasattr(module, "fc2_weight") and module.fc2_weight is not None:
99+
module.fc2_weight.data.normal_(mean=0.0, std=std)
100+
if hasattr(module, "fc1_bias") and module.fc1_bias is not None and module.fc1_bias.numel() > 0:
101+
module.fc1_bias.data.zero_()
102+
if hasattr(module, "fc2_bias") and module.fc2_bias is not None and module.fc2_bias.numel() > 0:
103+
module.fc2_bias.data.zero_()
104+
if isinstance(module, RotaryPositionEmbedding) and hasattr(module, "inv_freq"):
105+
module.inv_freq = LlamaRotaryEmbedding(config=self.config).inv_freq.to(module.inv_freq.device)
82106

83107

84108
class NVLlamaModel(NVLlamaPreTrainedModel):

0 commit comments

Comments
 (0)