diff --git a/nemo_automodel/components/checkpoint/checkpointing.py b/nemo_automodel/components/checkpoint/checkpointing.py
index 6b1c320ee9..b7c9e91924 100644
--- a/nemo_automodel/components/checkpoint/checkpointing.py
+++ b/nemo_automodel/components/checkpoint/checkpointing.py
@@ -520,6 +520,17 @@ def initialize_model_weights(
         and getattr(model.config, "n_routed_experts", None)  # is Nemotron V3
         and hasattr(model, "backbone")  # is HF remote code
     )
+    # HF's _init_weights calls init.zeros_(weight[padding_idx]) on
+    # nn.Embedding layers. When the weight is a DTensor (TP-sharded),
+    # the integer index triggers a redistribute that fails. Temporarily
+    # clear padding_idx so the zeroing is skipped, then restore it and
+    # zero the row via local-tensor ops instead.
+    has_padding_idx = any(
+        isinstance(mod, nn.Embedding)
+        and type(mod.weight).__name__ == "DTensor"
+        and getattr(mod, "padding_idx", None) is not None
+        for mod in model.modules()
+    )
     skip_initialize_weights = (
         model_class
         in [
@@ -528,6 +539,7 @@ def initialize_model_weights(
         ]
         or is_nemotron_v2
         or is_nemotron_v3_hf
+        or has_padding_idx
     )
     if not skip_initialize_weights:
         for _, module in model.named_modules():
@@ -538,7 +550,8 @@ def initialize_model_weights(
             model.initialize_weights()
         else:
             logging.warning(
-                "Warning: Model does not have initialize_weights method. Requires custom initialization to be implemented."
+                "Warning: Model does not have initialize_weights method."
+                " Requires custom initialization to be implemented."
             )
 
     if peft_init_method is not None: