1- # coding=utf-8
21# noqa: license-check
32# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
43# SPDX-License-Identifier: LicenseRef-Apache2
3837 MaskedLMOutput ,
3938 TokenClassifierOutput ,
4039)
41- from transformers .modeling_utils import PreTrainedModel
4240from transformers .models .esm .configuration_esm import EsmConfig
43- from transformers .models .esm .modeling_esm import EsmPooler
41+ from transformers .models .esm .modeling_esm import EsmPooler , EsmPreTrainedModel
4442from transformers .utils import logging
4543from transformers .utils .generic import TransformersKwargs
4644
@@ -135,6 +133,10 @@ def __init__(self, config: NVEsmConfig):
135133 """
136134 super ().__init__ ()
137135 self .config = config
136+
137+ def _init_method (x ):
138+ torch .nn .init .normal_ (x , mean = 0.0 , std = config .initializer_range )
139+
138140 self .layers = nn .ModuleList (
139141 [
140142 transformer_engine .pytorch .TransformerLayer (
@@ -156,12 +158,18 @@ def __init__(self, config: NVEsmConfig):
156158 fuse_qkv_params = config .fuse_qkv_params ,
157159 params_dtype = config .dtype ,
158160 window_size = (- 1 , - 1 ),
161+ device = "meta" if torch .get_default_device () == torch .device ("meta" ) else "cuda" ,
162+ init_method = _init_method ,
163+ output_layer_init_method = _init_method ,
159164 )
160165 for i in range (config .num_hidden_layers )
161166 ]
162167 )
163168 self .emb_layer_norm_after = transformer_engine .pytorch .LayerNorm (
164- config .hidden_size , eps = config .layer_norm_eps , params_dtype = config .dtype
169+ config .hidden_size ,
170+ eps = config .layer_norm_eps ,
171+ params_dtype = config .dtype ,
172+ device = "meta" if torch .get_default_device () == torch .device ("meta" ) else "cuda" ,
165173 )
166174 if config .position_embedding_type == "rotary" :
167175 self .rotary_embeddings = RotaryPositionEmbedding (config .hidden_size // config .num_attention_heads )
@@ -247,7 +255,7 @@ def forward(
247255 )
248256
249257
250- class NVEsmPreTrainedModel (PreTrainedModel ):
258+ class NVEsmPreTrainedModel (EsmPreTrainedModel ):
251259 """An abstract class to handle weights initialization and pretrained model loading."""
252260
253261 config_class = NVEsmConfig
@@ -259,61 +267,22 @@ class NVEsmPreTrainedModel(PreTrainedModel):
259267 "EsmEmbeddings" ,
260268 )
261269
262- def _init_weights (self , module : nn .Module ):
263- """Initialize model weights.
270+ def init_empty_weights (self ):
271+ """Handles moving the model from the meta device to the cuda device and initializing the weights."""
272+ # For TE layers, calling `reset_parameters` is sufficient to move them to the cuda device and apply the weight
273+ # initialization we passed them during module creation.
274+ for module in self .modules ():
275+ if hasattr (module , "reset_parameters" ):
276+ module .reset_parameters ()
264277
265- This method ensures that models with randomly-initialized weights get the correct initial value distribution,
266- which can be critical for training stability. We also call this method directly when using meta-device init, as
267- the `to_empty` method does not initialize the weights. While the base Transformers model has a similar method,
268- we need to extend it to handle TE-specific modules.
278+ # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use
279+ # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard
280+ # deviation.
281+ self .esm .embeddings .word_embeddings .to_empty (device = "cuda" )
282+ self .esm .embeddings .apply (self ._init_weights )
269283
270- Args:
271- module (nn.Module): The module to initialize the weights for.
272- """
273- if isinstance (
274- module , (nn .Linear , transformer_engine .pytorch .Linear , transformer_engine .pytorch .LayerNormLinear )
275- ):
276- # Slightly different from the TF version which uses truncated_normal for initialization
277- # cf https://github.com/pytorch/pytorch/pull/5617
278- module .weight .data .normal_ (mean = 0.0 , std = self .config .initializer_range )
279- if module .bias is not None :
280- module .bias .data .zero_ ()
281- if isinstance (module , nn .Embedding ):
282- module .weight .data .normal_ (mean = 0.0 , std = self .config .initializer_range )
283- if module .padding_idx is not None :
284- module .weight .data [module .padding_idx ].zero_ ()
285- if isinstance (module , (nn .LayerNorm , transformer_engine .pytorch .LayerNorm )):
286- module .bias .data .zero_ ()
287- module .weight .data .fill_ (1.0 )
288- if isinstance (module , transformer_engine .pytorch .LayerNormLinear ):
289- if module .layer_norm_bias is not None :
290- module .layer_norm_bias .data .zero_ ()
291- module .layer_norm_weight .data .fill_ (1.0 )
292- if module .layer_norm_bias is not None :
293- module .layer_norm_bias .data .zero_ ()
294- if isinstance (module , transformer_engine .pytorch .LayerNormMLP ):
295- if module .layer_norm_bias is not None :
296- module .layer_norm_bias .data .zero_ ()
297- module .layer_norm_weight .data .fill_ (1.0 )
298- if hasattr (module , "fc1_weight" ) and module .fc1_weight is not None :
299- module .fc1_weight .data .normal_ (mean = 0.0 , std = self .config .initializer_range )
300- if hasattr (module , "fc2_weight" ) and module .fc2_weight is not None :
301- module .fc2_weight .data .normal_ (mean = 0.0 , std = self .config .initializer_range )
302- if hasattr (module , "fc1_bias" ) and module .fc1_bias is not None and module .fc1_bias .numel () > 0 :
303- module .fc1_bias .data .zero_ ()
304- if hasattr (module , "fc2_bias" ) and module .fc2_bias is not None and module .fc2_bias .numel () > 0 :
305- module .fc2_bias .data .zero_ ()
306- if isinstance (module , RotaryPositionEmbedding ) and hasattr (module , "inv_freq" ):
307- # When we initialize the model with `to_empty`, the `inv_freq` attribute is not initialized, so we need to
308- # re-initialize it here with the correct values.
309- module .inv_freq = RotaryPositionEmbedding (
310- self .config .hidden_size // self .config .num_attention_heads
311- ).inv_freq .to (module .inv_freq .device )
312-
313- @classmethod
314- def get_init_context (cls , is_quantized : bool , _is_ds_init_called : bool ):
315- """Override the default get_init_context method to allow for fp8 model initialization."""
316- return []
284+ # Meta-device init seems to break weight tying, so we re-tie the weights here.
285+ self .tie_weights ()
317286
318287
319288class NVEsmModel (NVEsmPreTrainedModel ):
@@ -516,15 +485,20 @@ def __init__(self, config: NVEsmConfig):
516485 config .hidden_size ,
517486 config .hidden_size ,
518487 params_dtype = config .dtype ,
488+ device = "meta" if torch .get_default_device () == torch .device ("meta" ) else "cuda" ,
489+ init_method = lambda x : torch .nn .init .normal_ (x , mean = 0.0 , std = config .initializer_range ),
519490 )
520491
521- self .decoder = transformer_engine .pytorch .LayerNormLinear (
522- config .hidden_size ,
523- config .padded_vocab_size if config .padded_vocab_size is not None else config .vocab_size ,
524- bias = True ,
525- eps = config .layer_norm_eps ,
526- params_dtype = config .dtype ,
527- )
492+ with transformer_engine .pytorch .fp8_model_init (enabled = False ):
493+ self .decoder = transformer_engine .pytorch .LayerNormLinear (
494+ config .hidden_size ,
495+ config .padded_vocab_size if config .padded_vocab_size is not None else config .vocab_size ,
496+ bias = True ,
497+ eps = config .layer_norm_eps ,
498+ params_dtype = config .dtype ,
499+ device = "meta" if torch .get_default_device () == torch .device ("meta" ) else "cuda" ,
500+ init_method = lambda x : torch .nn .init .normal_ (x , mean = 0.0 , std = config .initializer_range ),
501+ )
528502
529503 def forward (self , features , ** kwargs ):
530504 """Forward pass of the NVEsmLMHead.
@@ -553,7 +527,12 @@ def __init__(self, config):
553527 )
554528
555529 self .layer_norm = (
556- transformer_engine .pytorch .LayerNorm (config .hidden_size , eps = config .layer_norm_eps )
530+ transformer_engine .pytorch .LayerNorm (
531+ config .hidden_size ,
532+ eps = config .layer_norm_eps ,
533+ params_dtype = config .dtype ,
534+ device = "meta" if torch .get_default_device () == torch .device ("meta" ) else "cuda" ,
535+ )
557536 if config .emb_layer_norm_before
558537 else None
559538 )
@@ -648,7 +627,11 @@ def __init__(self, config):
648627 self .esm = NVEsmModel (config , add_pooling_layer = False )
649628 self .dropout = nn .Dropout (config .hidden_dropout_prob )
650629 self .classifier = transformer_engine .pytorch .Linear (
651- config .hidden_size , config .num_labels , params_dtype = config .dtype
630+ config .hidden_size ,
631+ config .num_labels ,
632+ params_dtype = config .dtype ,
633+ device = "meta" if torch .get_default_device () == torch .device ("meta" ) else "cuda" ,
634+ init_method = lambda x : torch .nn .init .normal_ (x , mean = 0.0 , std = config .initializer_range ),
652635 )
653636
654637 self .init_weights ()
0 commit comments