Skip to content

Commit aeed0de

Browse files
authored
Make sure we pass device="meta" to TE layers during meta device init (#1397)
TransformerEngine requires that we pass `device="meta"` (and importantly _not_ `device=torch.device("meta")`) to layer constructors to initialize parameters on the meta device. This makes sure we pass the right device to the layer constructor and adds tests to ensure the parameters are actually being placed on the right devices.

For TransformerEngine layers, we want to be moving parameters from the meta device to cuda with

```python
for module in model.modules():
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()
```

while for HF layers (e.g., nn.Embedding), we want to be doing

```python
model.to_empty(device="cuda")
model.apply(model._init_weights)
```

to ensure that we pick up the config.initializer_range initialization correctly. The issue is that we can't do `to_empty(device="cuda")` or `_init_weights` on TE layers, nor can we do `reset_parameters()` on the HF layers without the preceding `to_empty`, and even then `reset_parameters()` on an HF layer doesn't use the HF config when creating initial values.

---------

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent bf566b9 commit aeed0de

19 files changed

Lines changed: 888 additions & 534 deletions

File tree

.devcontainer/recipes/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ accelerate
22
datasets
33
deepspeed
44
hydra-core
5+
lm-eval
56
# TOT megatron-mfsdp until NVIDIA/Megatron-LM#2575 is in a release.
67
megatron-fsdp @ git+https://github.com/NVIDIA/Megatron-LM.git@main#subdirectory=megatron/core/distributed/fsdp/src
78
peft

bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py

Lines changed: 50 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# coding=utf-8
21
# noqa: license-check
32
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
43
# SPDX-License-Identifier: LicenseRef-Apache2
@@ -38,9 +37,8 @@
3837
MaskedLMOutput,
3938
TokenClassifierOutput,
4039
)
41-
from transformers.modeling_utils import PreTrainedModel
4240
from transformers.models.esm.configuration_esm import EsmConfig
43-
from transformers.models.esm.modeling_esm import EsmPooler
41+
from transformers.models.esm.modeling_esm import EsmPooler, EsmPreTrainedModel
4442
from transformers.utils import logging
4543
from transformers.utils.generic import TransformersKwargs
4644

@@ -135,6 +133,10 @@ def __init__(self, config: NVEsmConfig):
135133
"""
136134
super().__init__()
137135
self.config = config
136+
137+
def _init_method(x):
138+
torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range)
139+
138140
self.layers = nn.ModuleList(
139141
[
140142
transformer_engine.pytorch.TransformerLayer(
@@ -156,12 +158,18 @@ def __init__(self, config: NVEsmConfig):
156158
fuse_qkv_params=config.fuse_qkv_params,
157159
params_dtype=config.dtype,
158160
window_size=(-1, -1),
161+
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
162+
init_method=_init_method,
163+
output_layer_init_method=_init_method,
159164
)
160165
for i in range(config.num_hidden_layers)
161166
]
162167
)
163168
self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm(
164-
config.hidden_size, eps=config.layer_norm_eps, params_dtype=config.dtype
169+
config.hidden_size,
170+
eps=config.layer_norm_eps,
171+
params_dtype=config.dtype,
172+
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
165173
)
166174
if config.position_embedding_type == "rotary":
167175
self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
@@ -247,7 +255,7 @@ def forward(
247255
)
248256

249257

250-
class NVEsmPreTrainedModel(PreTrainedModel):
258+
class NVEsmPreTrainedModel(EsmPreTrainedModel):
251259
"""An abstract class to handle weights initialization and pretrained model loading."""
252260

253261
config_class = NVEsmConfig
@@ -259,61 +267,22 @@ class NVEsmPreTrainedModel(PreTrainedModel):
259267
"EsmEmbeddings",
260268
)
261269

262-
def _init_weights(self, module: nn.Module):
263-
"""Initialize model weights.
270+
def init_empty_weights(self):
271+
"""Handles moving the model from the meta device to the cuda device and initializing the weights."""
272+
# For TE layers, calling `reset_parameters` is sufficient to move them to the cuda device and apply the weight
273+
# initialization we passed them during module creation.
274+
for module in self.modules():
275+
if hasattr(module, "reset_parameters"):
276+
module.reset_parameters()
264277

265-
This method ensures that models with randomly-initialized weights get the correct initial value distribution,
266-
which can be critical for training stability. We also call this method directly when using meta-device init, as
267-
the `to_empty` method does not initialize the weights. While the base Transformers model has a similar method,
268-
we need to extend it to handle TE-specific modules.
278+
# The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use
279+
# `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard
280+
# deviation.
281+
self.esm.embeddings.word_embeddings.to_empty(device="cuda")
282+
self.esm.embeddings.apply(self._init_weights)
269283

270-
Args:
271-
module (nn.Module): The module to initialize the weights for.
272-
"""
273-
if isinstance(
274-
module, (nn.Linear, transformer_engine.pytorch.Linear, transformer_engine.pytorch.LayerNormLinear)
275-
):
276-
# Slightly different from the TF version which uses truncated_normal for initialization
277-
# cf https://github.com/pytorch/pytorch/pull/5617
278-
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
279-
if module.bias is not None:
280-
module.bias.data.zero_()
281-
if isinstance(module, nn.Embedding):
282-
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
283-
if module.padding_idx is not None:
284-
module.weight.data[module.padding_idx].zero_()
285-
if isinstance(module, (nn.LayerNorm, transformer_engine.pytorch.LayerNorm)):
286-
module.bias.data.zero_()
287-
module.weight.data.fill_(1.0)
288-
if isinstance(module, transformer_engine.pytorch.LayerNormLinear):
289-
if module.layer_norm_bias is not None:
290-
module.layer_norm_bias.data.zero_()
291-
module.layer_norm_weight.data.fill_(1.0)
292-
if module.layer_norm_bias is not None:
293-
module.layer_norm_bias.data.zero_()
294-
if isinstance(module, transformer_engine.pytorch.LayerNormMLP):
295-
if module.layer_norm_bias is not None:
296-
module.layer_norm_bias.data.zero_()
297-
module.layer_norm_weight.data.fill_(1.0)
298-
if hasattr(module, "fc1_weight") and module.fc1_weight is not None:
299-
module.fc1_weight.data.normal_(mean=0.0, std=self.config.initializer_range)
300-
if hasattr(module, "fc2_weight") and module.fc2_weight is not None:
301-
module.fc2_weight.data.normal_(mean=0.0, std=self.config.initializer_range)
302-
if hasattr(module, "fc1_bias") and module.fc1_bias is not None and module.fc1_bias.numel() > 0:
303-
module.fc1_bias.data.zero_()
304-
if hasattr(module, "fc2_bias") and module.fc2_bias is not None and module.fc2_bias.numel() > 0:
305-
module.fc2_bias.data.zero_()
306-
if isinstance(module, RotaryPositionEmbedding) and hasattr(module, "inv_freq"):
307-
# When we initialize the model with `to_empty`, the `inv_freq` attribute is not initialized, so we need to
308-
# re-initialize it here with the correct values.
309-
module.inv_freq = RotaryPositionEmbedding(
310-
self.config.hidden_size // self.config.num_attention_heads
311-
).inv_freq.to(module.inv_freq.device)
312-
313-
@classmethod
314-
def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
315-
"""Override the default get_init_context method to allow for fp8 model initialization."""
316-
return []
284+
# Meta-device init seems to break weight tying, so we re-tie the weights here.
285+
self.tie_weights()
317286

318287

319288
class NVEsmModel(NVEsmPreTrainedModel):
@@ -516,15 +485,20 @@ def __init__(self, config: NVEsmConfig):
516485
config.hidden_size,
517486
config.hidden_size,
518487
params_dtype=config.dtype,
488+
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
489+
init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
519490
)
520491

521-
self.decoder = transformer_engine.pytorch.LayerNormLinear(
522-
config.hidden_size,
523-
config.padded_vocab_size if config.padded_vocab_size is not None else config.vocab_size,
524-
bias=True,
525-
eps=config.layer_norm_eps,
526-
params_dtype=config.dtype,
527-
)
492+
with transformer_engine.pytorch.fp8_model_init(enabled=False):
493+
self.decoder = transformer_engine.pytorch.LayerNormLinear(
494+
config.hidden_size,
495+
config.padded_vocab_size if config.padded_vocab_size is not None else config.vocab_size,
496+
bias=True,
497+
eps=config.layer_norm_eps,
498+
params_dtype=config.dtype,
499+
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
500+
init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
501+
)
528502

529503
def forward(self, features, **kwargs):
530504
"""Forward pass of the NVEsmLMHead.
@@ -553,7 +527,12 @@ def __init__(self, config):
553527
)
554528

555529
self.layer_norm = (
556-
transformer_engine.pytorch.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
530+
transformer_engine.pytorch.LayerNorm(
531+
config.hidden_size,
532+
eps=config.layer_norm_eps,
533+
params_dtype=config.dtype,
534+
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
535+
)
557536
if config.emb_layer_norm_before
558537
else None
559538
)
@@ -648,7 +627,11 @@ def __init__(self, config):
648627
self.esm = NVEsmModel(config, add_pooling_layer=False)
649628
self.dropout = nn.Dropout(config.hidden_dropout_prob)
650629
self.classifier = transformer_engine.pytorch.Linear(
651-
config.hidden_size, config.num_labels, params_dtype=config.dtype
630+
config.hidden_size,
631+
config.num_labels,
632+
params_dtype=config.dtype,
633+
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
634+
init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
652635
)
653636

654637
self.init_weights()

bionemo-recipes/models/esm2/tests/conftest.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
import pytest
2020
import transformer_engine.pytorch
21+
from transformer_engine.common import recipe as recipe_module
22+
from transformer_engine.pytorch import fp8
2123
from transformers import AutoModelForMaskedLM, AutoTokenizer, DataCollatorForLanguageModeling
2224

2325
from esm.convert import convert_esm_hf_to_te
@@ -88,3 +90,63 @@ def te_model_checkpoint(tmp_path):
8890
model_te = convert_esm_hf_to_te(model_hf)
8991
model_te.save_pretrained(tmp_path / "te_model_checkpoint")
9092
return tmp_path / "te_model_checkpoint"
93+
94+
95+
# TransformerEngine quantization recipes used to parametrize the `fp8_recipe` fixture.
# Each entry is a default-constructed recipe instance; the NVFP4 recipe is currently left disabled.
ALL_RECIPES = [
    recipe_module.DelayedScaling(),
    recipe_module.Float8CurrentScaling(),
    recipe_module.Float8BlockScaling(),
    recipe_module.MXFP8BlockScaling(),
    # recipe_module.NVFP4BlockScaling(disable_rht=True, disable_stochastic_rounding=True),
]
102+
103+
104+
def _check_recipe_support(recipe: recipe_module.Recipe):
    """Look up whether ``recipe`` is usable, via the matching ``fp8.check_*_support`` helper.

    Args:
        recipe: The TransformerEngine recipe instance to check.

    Returns:
        The ``(supported, reason)`` pair reported by the checker that matches the recipe's type,
        or ``(False, "Unsupported recipe")`` when the recipe type is not one of the known classes.
    """
    # Ordered (recipe class name, checker name) pairs; the first isinstance match wins.
    # Both names are resolved lazily inside the loop so that, exactly like an if/elif chain,
    # attributes belonging to later entries are never touched once an earlier entry matches.
    known_recipes = (
        ("DelayedScaling", "check_fp8_support"),
        ("Float8CurrentScaling", "check_fp8_support"),
        ("Float8BlockScaling", "check_fp8_block_scaling_support"),
        ("MXFP8BlockScaling", "check_mxfp8_support"),
        ("NVFP4BlockScaling", "check_nvfp4_support"),
    )
    for recipe_cls_name, checker_name in known_recipes:
        if isinstance(recipe, getattr(recipe_module, recipe_cls_name)):
            recipe_supported, reason = getattr(fp8, checker_name)()
            return recipe_supported, reason

    return False, "Unsupported recipe"
120+
121+
122+
def requires_recipe_support(recipe: recipe_module.Recipe):
    """Build a decorator that skips the decorated test when ``recipe`` is not supported.

    Args:
        recipe: The TransformerEngine recipe the decorated test depends on.

    Returns:
        A decorator that wraps its target in a ``pytest.mark.skipif`` mark. The support check
        runs when the decorator is applied, and its reason string becomes the skip reason.
    """

    def requires_recipe_support_inner(func):
        supported, skip_reason = _check_recipe_support(recipe)
        skip_unless_supported = pytest.mark.skipif(not supported, reason=skip_reason)
        return skip_unless_supported(func)

    return requires_recipe_support_inner
130+
131+
132+
def parametrize_recipes_with_support(recipes):
    """Wrap each recipe in a ``pytest.param`` that is skipped when the recipe is unsupported.

    Args:
        recipes: Iterable of TransformerEngine recipe instances.

    Returns:
        A list with one ``pytest.param`` per recipe, in input order. Each param uses the recipe's
        class name as its id and carries a ``skipif`` mark built from the support check.
    """

    def _to_param(recipe):
        # Convert one recipe into its (possibly skipped) pytest parameter.
        supported, skip_reason = _check_recipe_support(recipe)
        return pytest.param(
            recipe,
            id=recipe.__class__.__name__,
            marks=pytest.mark.skipif(not supported, reason=skip_reason),
        )

    return [_to_param(recipe) for recipe in recipes]
148+
149+
150+
@pytest.fixture(params=parametrize_recipes_with_support(ALL_RECIPES))
def fp8_recipe(request):
    """Provide one recipe from ``ALL_RECIPES`` per parametrized run; unsupported ones are skipped."""
    selected_recipe = request.param
    return selected_recipe

bionemo-recipes/models/esm2/tests/test_convert.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,45 @@ def test_padding_unpadding_operations():
135135
if te_embeddings.shape[0] > original_embeddings.shape[0]:
136136
padding_rows = te_embeddings[original_embeddings.shape[0] :]
137137
torch.testing.assert_close(padding_rows, torch.zeros_like(padding_rows), atol=1e-6, rtol=1e-6)
138+
139+
140+
def test_weight_initialization_matches_hf():
    """Check that a freshly built TE model is initialized like a converted, freshly built HF model.

    A randomly initialized HF ESM-2 model is converted to the TE format and used as the reference.
    A second TE model is then constructed directly from the same config. For every state-dict entry
    of the reference (except TE ``_extra_state`` entries), the mean and standard deviation of the
    directly constructed model's tensor must agree with the reference within tolerance.
    """
    from transformers import AutoConfig, set_seed
    from transformers.models.esm.modeling_esm import EsmForMaskedLM

    from esm.convert import convert_esm_hf_to_te
    from esm.modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM

    set_seed(42)

    # Reference: random-init HF model, converted to the TE layout.
    config_hf = AutoConfig.from_pretrained("facebook/esm2_t6_8M_UR50D", vocab_size=64)
    model_hf = EsmForMaskedLM(config_hf)
    model_te_converted = convert_esm_hf_to_te(model_hf)

    # Candidate: TE model constructed directly from the equivalent config.
    config = NVEsmConfig(**model_hf.config.to_dict())
    model_te = NVEsmForMaskedLM(config)
    model_te.to("cuda")
    model_te_converted.to("cuda")

    reference_state = model_te_converted.state_dict()
    candidate_state = model_te.state_dict()

    tolerances = {"atol": 1e-3, "rtol": 1e-4}

    for name, reference_tensor in reference_state.items():
        # `_extra_state` entries are skipped rather than compared statistically.
        if not name.endswith("_extra_state"):
            candidate_tensor = candidate_state[name]

            torch.testing.assert_close(
                candidate_tensor.mean(),
                reference_tensor.mean(),
                msg=lambda x: f"Mean mismatch for parameter {name}: {x}",
                **tolerances,
            )

            torch.testing.assert_close(
                candidate_tensor.std(),
                reference_tensor.std(),
                msg=lambda x: f"Std mismatch for parameter {name}: {x}",
                **tolerances,
            )

0 commit comments

Comments
 (0)