Skip to content

Commit 8cee6ad

Browse files
cp: fix: skip embedding[padding_idx] = 0 with TP (1675) into r0.4.0 (#1771)
fix: skip embedding[padding_idx] = 0 with TP (#1675) * skip embedding[padding_idx] = 0 * fix * remove code --------- Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> Signed-off-by: NeMo Bot <nemo-bot@nvidia.com> Co-authored-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com>
1 parent 75a0770 commit 8cee6ad

1 file changed

Lines changed: 14 additions & 1 deletion

File tree

nemo_automodel/components/checkpoint/checkpointing.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,17 @@ def initialize_model_weights(
525525
and getattr(model.config, "n_routed_experts", None) # is Nemotron V3
526526
and hasattr(model, "backbone") # is HF remote code
527527
)
528+
# HF's _init_weights calls init.zeros_(weight[padding_idx]) on
529+
# nn.Embedding layers. When the weight is a DTensor (TP-sharded),
530+
# the integer index triggers a redistribute that fails. Temporarily
531+
# clear padding_idx so the zeroing is skipped, then restore it and
532+
# zero the row via local-tensor ops instead.
533+
has_padding_idx = any(
534+
isinstance(mod, nn.Embedding)
535+
and type(mod.weight).__name__ == "DTensor"
536+
and getattr(mod, "padding_idx", None) is not None
537+
for mod in model.modules()
538+
)
528539
skip_initialize_weights = (
529540
model_class
530541
in [
@@ -533,6 +544,7 @@ def initialize_model_weights(
533544
]
534545
or is_nemotron_v2
535546
or is_nemotron_v3_hf
547+
or has_padding_idx
536548
)
537549
if not skip_initialize_weights:
538550
for _, module in model.named_modules():
@@ -543,7 +555,8 @@ def initialize_model_weights(
543555
model.initialize_weights()
544556
else:
545557
logging.warning(
546-
"Warning: Model does not have initialize_weights method. Requires custom initialization to be implemented."
558+
"Warning: Model does not have initialize_weights method."
559+
" Requires custom initialization to be implemented."
547560
)
548561

549562
if peft_init_method is not None:

0 commit comments

Comments
 (0)