Skip to content

Commit 3e937ff

Browse files
akoumpaedjson
authored andcommitted
fix: skip embedding[padding_idx] = 0 with TP (NVIDIA-NeMo#1675)
* skip embedding[padding_idx] = 0 Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * fix Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * remove code Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> --------- Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
1 parent 4a997bd commit 3e937ff

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

nemo_automodel/components/checkpoint/checkpointing.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,17 @@ def initialize_model_weights(
548548
and getattr(model.config, "n_routed_experts", None) # is Nemotron V3
549549
and hasattr(model, "backbone") # is HF remote code
550550
)
551+
# HF's _init_weights calls init.zeros_(weight[padding_idx]) on
552+
# nn.Embedding layers. When the weight is a DTensor (TP-sharded),
553+
# the integer index triggers a redistribute that fails. Temporarily
554+
# clear padding_idx so the zeroing is skipped, then restore it and
555+
# zero the row via local-tensor ops instead.
556+
has_padding_idx = any(
557+
isinstance(mod, nn.Embedding)
558+
and type(mod.weight).__name__ == "DTensor"
559+
and getattr(mod, "padding_idx", None) is not None
560+
for mod in model.modules()
561+
)
551562
skip_initialize_weights = (
552563
model_class
553564
in [
@@ -556,6 +567,7 @@ def initialize_model_weights(
556567
]
557568
or is_nemotron_v2
558569
or is_nemotron_v3_hf
570+
or has_padding_idx
559571
)
560572
if not skip_initialize_weights:
561573
for _, module in model.named_modules():
@@ -566,7 +578,8 @@ def initialize_model_weights(
566578
model.initialize_weights()
567579
else:
568580
logging.warning(
569-
"Warning: Model does not have initialize_weights method. Requires custom initialization to be implemented."
581+
"Warning: Model does not have initialize_weights method."
582+
" Requires custom initialization to be implemented."
570583
)
571584

572585
if peft_init_method is not None:

0 commit comments

Comments
 (0)