15 changes: 14 additions & 1 deletion nemo_automodel/components/checkpoint/checkpointing.py
@@ -520,6 +520,17 @@ def initialize_model_weights(
and getattr(model.config, "n_routed_experts", None) # is Nemotron V3
and hasattr(model, "backbone") # is HF remote code
)
# HF's _init_weights calls init.zeros_(weight[padding_idx]) on
# nn.Embedding layers. When the weight is a DTensor (TP-sharded),
# the integer index triggers a redistribute that fails. Temporarily
# clear padding_idx so the zeroing is skipped, then restore it and
# zero the row via local-tensor ops instead.
has_padding_idx = any(
isinstance(mod, nn.Embedding)
and type(mod.weight).__name__ == "DTensor"
and getattr(mod, "padding_idx", None) is not None
for mod in model.modules()
)
skip_initialize_weights = (
model_class
in [
@@ -528,6 +539,7 @@ def initialize_model_weights(
]
or is_nemotron_v2
or is_nemotron_v3_hf
or has_padding_idx
)
if not skip_initialize_weights:
for _, module in model.named_modules():
@@ -538,7 +550,8 @@ def initialize_model_weights(
model.initialize_weights()
else:
logging.warning(
"Warning: Model does not have initialize_weights method. Requires custom initialization to be implemented."
"Warning: Model does not have initialize_weights method."
" Requires custom initialization to be implemented."
)

if peft_init_method is not None:
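
The comment added in the first hunk describes the workaround only in prose: hide `padding_idx` while HF's `_init_weights` runs, then restore it and zero the padding row via local-tensor ops. Below is a minimal sketch of that idea, not the PR's actual implementation; the function name is illustrative, and it assumes the embedding weight is a DTensor sharded along the embedding dimension (dim 1), so the padding row keeps the same local row index on every rank.

```python
# Illustrative sketch of the padding_idx workaround described in the diff comment;
# not the code added by this PR.
import torch
import torch.nn as nn


def init_weights_with_padding_idx_workaround(model: nn.Module) -> None:
    patched = []

    # 1) Hide padding_idx on TP-sharded (DTensor) embeddings so _init_weights
    #    never indexes the distributed weight with a plain integer.
    for mod in model.modules():
        if (
            isinstance(mod, nn.Embedding)
            and type(mod.weight).__name__ == "DTensor"
            and mod.padding_idx is not None
        ):
            patched.append((mod, mod.padding_idx))
            mod.padding_idx = None

    try:
        # 2) Run the model's own weight initialization, as the diff does when
        #    skip_initialize_weights is False.
        if hasattr(model, "initialize_weights"):
            model.initialize_weights()
    finally:
        # 3) Restore padding_idx and zero the padding row on the local shard only.
        #    Assumes dim-1 sharding; a vocab-sharded weight would additionally
        #    need a per-rank row-offset check before indexing.
        for mod, padding_idx in patched:
            mod.padding_idx = padding_idx
            local = mod.weight.to_local()  # plain torch.Tensor shard on this rank
            with torch.no_grad():
                local[padding_idx].zero_()
```

In the diff itself, the new `has_padding_idx` flag simply routes such models away from the generic per-module `_init_weights` path; the sketch only illustrates the mechanism the comment describes.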