
Commit d7911a4

fix mtp export

Signed-off-by: jenchen13 <jennifchen@nvidia.com>
1 parent 5cf6d06

1 file changed: modelopt/torch/export/unified_export_megatron.py (9 additions, 6 deletions)
@@ -377,6 +377,10 @@ def save_pretrained(
         # Add multimodal components to state_dict
         state_dict.update(multimodal_state_dict)
 
+        mtp_state_dict = self._get_mtp_state_dict()
+        state_dict.update(mtp_state_dict)
+        print(f"Successfully loaded {len(mtp_state_dict)} MTP tensors")
+
         # Barrier to ensure the export_dir has been created.
         torch.distributed.barrier()
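The hunk above merges the MTP tensors returned by `_get_mtp_state_dict()` into the export state_dict before the checkpoint is written. A rough sketch of that flow, with hypothetical keys and shapes and a single-file `save_file` call for illustration (the real exporter may shard the output and handle quantized tensors differently):

import torch
from safetensors.torch import save_file

# Hypothetical stand-ins for the exporter's state; names and shapes are illustrative only.
state_dict = {"model.layers.0.mlp.linear_fc1.weight": torch.zeros(8, 8)}
mtp_state_dict = {"mtp.layers.0.eh_proj.weight": torch.zeros(8, 8)}  # copied from the HF checkpoint

# Mirrors the added lines: merge the MTP tensors into the export state_dict.
state_dict.update(mtp_state_dict)
print(f"Successfully loaded {len(mtp_state_dict)} MTP tensors")

# Single-file save for illustration; the actual export path may differ.
save_file(state_dict, "model.safetensors")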

@@ -478,9 +482,7 @@ def _get_state_dict(self):
             else:
                 raise ValueError("Only TransformerLayer or MambaLayer are supported.")
 
-        # Get MTP layer if exists. Only on rank 0 to avoid duplicate weights.
-        if torch.distributed.get_rank() == 0:
-            self._get_mtp_state_dict()
+        # TODO export MTP layer in the future
 
     def _get_transformer_layer_state_dict(self, layer, layer_id):
         if not isinstance(layer.input_layernorm, IdentityOp):
@@ -558,13 +560,14 @@ def _get_transformer_layer_state_dict(self, layer, layer_id):
             self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id)
             self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id)
 
-    def _get_mtp_state_dict(self):
+    def _get_mtp_state_dict(self) -> dict[str, torch.Tensor]:
         """Export the MTP module.
 
         Currently, we copy the BF16 MTP weights from the pretrained model if the pretrained model has MTP layers.
         """
         # TODO Implement MTP export for quantized MTP
         # Hacky version for now: copy MTP weights from pretrained model
+        mtp_state_dict = {}
         if self._hf_pretrained_model_name:
             if os.path.isdir(self._hf_pretrained_model_name):
                 safetensors_index_file = (
@@ -583,11 +586,12 @@ def _get_mtp_state_dict(self):
                 model_dir = Path(safetensors_index_file).parent
                 for key in safetensors_index["weight_map"]:
                     if key.startswith("mtp.") and key not in self._state_dict:
-                        self._state_dict[key] = get_safetensor(model_dir, key)
+                        mtp_state_dict[key] = get_safetensor(model_dir, key)
                         mtp_exists = True
 
         if mtp_exists:
             self.exclude_modules.append("mtp*")
+        return mtp_state_dict
 
     def _get_mamba_layer_state_dict(self, layer, layer_id):
         if not isinstance(layer.norm, IdentityOp):
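The loop above resolves each `mtp.*` key through the checkpoint's `weight_map` and pulls the tensor from the corresponding shard. A minimal sketch of what a `get_safetensor`-style lookup amounts to, assuming the standard Hugging Face sharded-checkpoint layout (the helper below is hypothetical; the exporter's own `get_safetensor` may differ):

import json
from pathlib import Path

from safetensors.torch import load_file


def load_tensor_from_index(model_dir: Path, key: str):
    # The index file maps each tensor name to the shard file that stores it.
    index = json.loads((model_dir / "model.safetensors.index.json").read_text())
    shard_name = index["weight_map"][key]
    # load_file reads the whole shard; safetensors.safe_open could fetch just the one tensor.
    return load_file(model_dir / shard_name)[key]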
@@ -855,7 +859,6 @@ def _name_remapping(
             else:
                 source_key = mapping.get(key, key)
             self._state_dict[prefix + source_key] = val
-            print(f"{prefix + source_key}: {self._state_dict[prefix + source_key].dtype}")
 
     def _gated_mlp_slicing(
         self, module, prefix, gate_proj_name="gate_proj", up_proj_name="up_proj"
