Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 175 additions & 64 deletions modelopt/torch/export/unified_export_megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@

import torch
import torch.distributed
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import EntryNotFoundError
from safetensors import safe_open
from safetensors.torch import save_file

from modelopt import __version__
Expand All @@ -47,11 +44,7 @@
)
from .plugins.hf_checkpoint_utils import copy_hf_ckpt_remote_code, load_multimodal_components
from .plugins.mcore_common import all_mcore_hf_export_mapping
from .plugins.mcore_custom import (
CustomModuleMapping,
get_safetensor,
save_safetensors_by_layer_index,
)
from .plugins.mcore_custom import CustomModuleMapping, save_safetensors_by_layer_index
from .plugins.megatron_importer import GPTModelImporter
from .quant_utils import (
get_activation_scaling_factor,
Expand All @@ -71,9 +64,11 @@
has_mcore = False
with import_plugin("megatron"):
from megatron.core.models.gpt import GPTModel
from megatron.core.models.hybrid.hybrid_model import HybridModel
from megatron.core.models.mamba import MambaModel
from megatron.core.models.multimodal.llava_model import LLaVAModel
from megatron.core.parallel_state import (
get_expert_model_parallel_rank,
get_pipeline_model_parallel_rank,
get_pipeline_model_parallel_world_size,
get_tensor_model_parallel_rank,
Expand Down Expand Up @@ -121,7 +116,7 @@ def __init__(
moe_router_dtype: str | None = None,
):
"""Create a GPTModel exporter instance."""
if not isinstance(model, (GPTModel, MambaModel, LLaVAModel)):
if not isinstance(model, (GPTModel, HybridModel, MambaModel, LLaVAModel)):
raise ValueError("Input to GPTModelExport must be a megatron.core.models.GPTModel!")

self._state_dict = OrderedDict()
Expand Down Expand Up @@ -263,10 +258,11 @@ def save_pretrained(

# We use the 1st PP rank to handle VLM because vision_models
# and vision_proj only exist in the first stage.
is_first_stage_main_rank = pp_rank == 0 and tp_rank == 0
ep_rank = get_expert_model_parallel_rank()
is_first_stage_main_rank = pp_rank == 0 and tp_rank == 0 and ep_rank == 0
# We use the last PP rank to write the config because
# medusa_heads and eagle_module only exist in the last stage.
is_last_stage_main_rank = pp_rank == pp_size - 1 and tp_rank == 0
is_last_stage_main_rank = pp_rank == pp_size - 1 and tp_rank == 0 and ep_rank == 0

# Main export process
layer_state_dicts = self.layer_state_dicts
Expand Down Expand Up @@ -351,9 +347,12 @@ def save_pretrained(
if is_last_stage_main_rank and self._hf_config is not None:
copy_hf_ckpt_remote_code(pretrained_model_name_or_path, save_directory)

# Barrier after config copy to ensure config.json is fully written.
torch.distributed.barrier()

# Newer versions of VLLM expect config.json with hf_quant_config
config_json_file = save_directory + "/config.json"
if self._hf_quant_config and os.path.exists(config_json_file):
if is_last_stage_main_rank and self._hf_quant_config and os.path.exists(config_json_file):
converted_quant_config = convert_hf_quant_config_format(self._hf_quant_config)
with open(config_json_file) as f:
config_dict = json.load(f)
Expand Down Expand Up @@ -529,65 +528,177 @@ def _get_transformer_layer_state_dict(self, layer, layer_id):
self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id)

def _get_mtp_state_dict(self) -> dict[str, torch.Tensor]:
    """Export the MTP module using the rules system for proper quantized weight handling.

    Supports both repeated MTP (single MultiTokenPredictionLayer with multiple inner
    model layers, used by Nemotron) and non-repeated MTP (multiple MTP layers each
    with one inner layer, used by DeepSeek).

    While the export runs, ``self._state_dict`` is replaced by a fresh
    ``OrderedDict`` and whatever ends up in it is returned. The previous
    ``self._state_dict`` is always restored, even if the export raises.

    Returns:
        The state dict collected during the MTP export, or an empty dict when
        ``self.model`` has no ``mtp`` attribute, or ``mtp`` has no ``layers``
        attribute, or ``mtp.layers`` is empty.
    """
    mtp = getattr(self.model, "mtp", None)
    if mtp is None or not hasattr(mtp, "layers") or len(mtp.layers) == 0:
        return {}

    # Save current state_dict and create a fresh one for MTP export
    saved_state_dict = self._state_dict
    self._state_dict = OrderedDict()

    try:
        # A single entry in ``mtp.layers`` selects the repeated layout;
        # several entries select the non-repeated layout.
        if len(mtp.layers) == 1:
            self._export_repeated_mtp(mtp)
        else:
            self._export_non_repeated_mtp(mtp)
        mtp_state_dict = self._state_dict
    finally:
        # Restore the caller-visible state dict whether or not the export succeeded.
        self._state_dict = saved_state_dict

    if mtp_state_dict:
        print(f"Exporting MTP: {len(mtp_state_dict)} tensors via rules system")

    return mtp_state_dict

def _build_mtp_inner_rules(self):
    """Build rules with 'backbone'/'model' prefix replaced by 'mtp' for MTP inner layers.

    This mirrors the import side's ``is_mtp=True`` behaviour which replaces 'backbone'
    with 'mtp' (or 'model' with 'mtp') so that decoder layer export methods can be
    reused for MTP inner layers.

    Returns:
        A dict keyed like the architecture's export mapping. ``CustomModuleMapping``
        entries become callables ``rule(module, *format_args, **kwargs)`` that
        format the MTP-prefixed target name with ``format_args`` and forward to
        the matching export method; ``bool`` entries are copied unchanged.
        Entries whose key starts with ``"mtp."``, whose prefix contains neither
        ``"backbone"`` nor a leading ``"model."``, or whose ``func_name`` is not
        one of the methods listed below are left out.
    """
    arch_mapping = all_mcore_hf_export_mapping.get(self.arch, {})

    # Export method for each ``func_name`` a mapping may carry.
    method_map = {
        "name_remapping": self._name_remapping,
        "qkv_slicing": self._qkv_slicing,
        "self_attention_scaling": self._self_attention_scaling,
        "gated_mlp_slicing": self._gated_mlp_slicing,
        "grouped_mlp_slicing": self._grouped_mlp_slicing,
        "pack_name_remapping": self._pack_name_remapping,
        "pack_name_remapping_gpt_oss": self._pack_name_remapping_gpt_oss,
    }

    mtp_rules = {}
    for key, mapping in arch_mapping.items():
        if key.startswith("mtp."):
            # MTP-specific rules are used directly, not prefix-replaced
            continue
        if isinstance(mapping, CustomModuleMapping):
            prefix = mapping.target_name_or_prefix
            if "backbone" in prefix:
                mtp_prefix = prefix.replace("backbone", "mtp")
            elif prefix.startswith("model."):
                mtp_prefix = "mtp." + prefix[len("model.") :]
            else:
                continue
            func = method_map.get(mapping.func_name)
            if func is None:
                continue
            func_kwargs = dict(mapping.func_kwargs)
            # Bind func/prefix/kwargs as keyword defaults so every rule keeps the
            # values of its own loop iteration instead of the last one.
            # Call-time kwargs take precedence over the mapping's kwargs.
            mtp_rules[key] = (
                lambda m, *args, _f=func, _p=mtp_prefix, _kw=func_kwargs, **kwargs: _f(
                    m, _p.format(*args), **{**_kw, **kwargs}
                )
            )
        elif isinstance(mapping, bool):
            mtp_rules[key] = mapping

    return mtp_rules

def _export_repeated_mtp(self, mtp):
    """Export repeated MTP (single MultiTokenPredictionLayer with multiple inner layers).

    Used by architectures like Nemotron where ``model.mtp.layers`` has a single entry
    and ``mtp.layers[0].mtp_model_layer.layers`` contains the actual decoder layers.

    Args:
        mtp: The MTP module; only ``mtp.layers[0]`` is read.

    Raises:
        ValueError: If an inner layer is neither a ``MambaLayer`` nor a
            ``TransformerLayer``.
    """
    mtp_layer = mtp.layers[0]
    layer_id = 0

    # Export MTP-specific modules (enorm, hnorm, eh_proj), each only when the
    # layer has the attribute and a matching rule is registered.
    for module_name in ("enorm", "hnorm", "eh_proj"):
        rule_key = f"mtp.{module_name}"
        if hasattr(mtp_layer, module_name) and rule_key in self.rules:
            self.rules[rule_key](getattr(mtp_layer, module_name), layer_id)

    # Export inner transformer/mamba layers with MTP-prefixed rules
    saved_rules = self.rules
    self.rules = self._build_mtp_inner_rules()

    try:
        # HybridStack (nested MTP) has .layers; single TransformerLayer does not
        if hasattr(mtp_layer.mtp_model_layer, "layers"):
            inner_layers = mtp_layer.mtp_model_layer.layers
        else:
            inner_layers = [mtp_layer.mtp_model_layer]
        for inner_layer in inner_layers:
            if isinstance(inner_layer, MambaLayer):
                self._get_mamba_layer_state_dict(inner_layer, layer_id)
            elif isinstance(inner_layer, TransformerLayer):
                self._get_transformer_layer_state_dict(inner_layer, layer_id)
            else:
                raise ValueError(
                    f"Unsupported MTP inner layer type: {type(inner_layer)}.\n"
                    "Only TransformerLayer and MambaLayer are supported."
                )
            layer_id += 1
    finally:
        # Always put the regular rules back, even if an inner export failed.
        self.rules = saved_rules

    # Export final_layernorm (lives on MultiTokenPredictionLayer, not mtp_model_layer).
    # It is emitted under the id of the last inner layer, and only when at least
    # one inner layer was exported.
    if (
        layer_id > 0
        and getattr(mtp_layer, "final_layernorm", None) is not None
        and "mtp.final_layernorm" in self.rules
    ):
        self.rules["mtp.final_layernorm"](mtp_layer.final_layernorm, layer_id - 1)

def _export_non_repeated_mtp(self, mtp):
    """Export non-repeated MTP (multiple MTP layers, each with one inner decoder layer).

    Used by architectures like DeepSeek where ``model.mtp.layers`` has one entry per
    MTP prediction step. Each MTP layer has its own enorm, hnorm, eh_proj, and a
    single-layer decoder.

    Args:
        mtp: The MTP module; every entry of ``mtp.layers`` is exported in order.

    Raises:
        ValueError: If an inner layer is neither a ``TransformerLayer`` nor a
            ``MambaLayer`` (same behaviour as ``_export_repeated_mtp``; skipping
            it silently would yield an incomplete export).
    """
    # Layer ids continue from the last decoder layer
    layer_id = self.model.config.num_layers

    # Build MTP-prefixed rules for inner layer export (once, reused for every layer)
    mtp_inner_rules = self._build_mtp_inner_rules()

    for mtp_layer in mtp.layers:
        # Export MTP-specific modules, each only when the layer has the
        # attribute and a matching rule is registered.
        for module_name in ("enorm", "hnorm", "eh_proj"):
            rule_key = f"mtp.{module_name}"
            if hasattr(mtp_layer, module_name) and rule_key in self.rules:
                self.rules[rule_key](getattr(mtp_layer, module_name), layer_id)

        # Export inner decoder layers with MTP-prefixed rules
        if hasattr(mtp_layer.mtp_model_layer, "layers"):
            inner_layers = mtp_layer.mtp_model_layer.layers
        else:
            inner_layers = [mtp_layer.mtp_model_layer]
        saved_rules = self.rules
        self.rules = mtp_inner_rules
        try:
            # Every inner layer of one MTP layer is exported under the same layer_id.
            for inner_layer in inner_layers:
                if isinstance(inner_layer, TransformerLayer):
                    self._get_transformer_layer_state_dict(inner_layer, layer_id)
                elif isinstance(inner_layer, MambaLayer):
                    self._get_mamba_layer_state_dict(inner_layer, layer_id)
                else:
                    raise ValueError(
                        f"Unsupported MTP inner layer type: {type(inner_layer)}.\n"
                        "Only TransformerLayer and MambaLayer are supported."
                    )
        finally:
            # Always put the regular rules back, even if an inner export failed.
            self.rules = saved_rules

        # Export final_layernorm (lives on MultiTokenPredictionLayer itself)
        if (
            getattr(mtp_layer, "final_layernorm", None) is not None
            and "mtp.final_layernorm" in self.rules
        ):
            self.rules["mtp.final_layernorm"](mtp_layer.final_layernorm, layer_id)

        layer_id += 1

def _get_mamba_layer_state_dict(self, layer, layer_id):
if not isinstance(layer.norm, IdentityOp):
Expand Down
Loading