Skip to content

Commit 9e313ad

Browse files
cjluo-nv, Edwardf0t1
authored and committed
Support Qwen3 Next MTP load and export (#860)
## What does this PR do? Fix MTP export for Qwen3 Next **Overview:** For Qwen3 Next, the MTP weights are not stored in a separate safetensors file, so we look for "mtp" in the weight key to decide whether a weight belongs to MTP. ## Testing Run Qwen3 Next PTQ and check that MTP is in the exported checkpoint. scripts/huggingface_example.sh --model <Qwen3-Next-80B-A3B-Instruct/Thinking> --quant nvfp4 --trust_remote_code ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes/No <!--- If No, explain why. --> - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Refactor** * Optimized Multi-Token Prediction weight loading with improved layer detection and handling. * **Chores** * Simplified status reporting to display total loaded weights and detected layers. * Removed verbose per-file warnings for cleaner console output. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Chenjie Luo <chenjiel@nvidia.com> Co-authored-by: Zhiyu <zhiyuc@nvidia.com>
1 parent 733ede0 commit 9e313ad

File tree

3 files changed

+62
-72
lines changed

3 files changed

+62
-72
lines changed

examples/llm_ptq/example_utils.py

Lines changed: 48 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import inspect
1919
import json
2020
import os
21-
import re
2221
import shutil
2322
import sys
2423
import warnings
@@ -317,8 +316,10 @@ def get_processor(
317316
return None
318317

319318

320-
def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> list[str]:
321-
"""Load MTP weights from separate safetensors if needed (e.g., GLM-4.7).
319+
def load_mtp_weights(
320+
model: torch.nn.Module, model_path: str
321+
) -> tuple[list[str], dict[str, torch.Tensor]]:
322+
"""Load MTP weights from the model checkpoint.
322323
323324
Some models store additional layers in separate safetensors files with non-standard
324325
names (e.g., mtp.safetensors). HuggingFace's from_pretrained() may not load these
@@ -334,87 +335,76 @@ def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> list[
334335
List of layer prefixes that were loaded from non-standard safetensors files.
335336
These layers should typically be excluded from quantization.
336337
Empty list if no additional weights were loaded.
338+
Dictionary of MTP weights that were not loaded into the model state dict.
337339
"""
338340
model_path = Path(model_path)
339341
index_file = model_path / "model.safetensors.index.json"
340-
mtp_layer_prefixes: list[str] = []
341342

342343
if not index_file.exists():
343-
return mtp_layer_prefixes
344+
return [], {}
344345

345346
# Load the index to find all referenced safetensors files
346-
with open(index_file) as f:
347-
index = json.load(f)
348-
349-
# Find all unique safetensors files referenced
350-
all_files = set(index["weight_map"].values())
351-
352-
# Find non-standard shard files (not matching model-XXXXX-of-XXXXX.safetensors pattern)
353-
standard_pattern = re.compile(r"model-\d{5}-of-\d{5}\.safetensors")
354-
non_standard_files = [f for f in all_files if not standard_pattern.match(f)]
347+
index = json.load(open(index_file))
348+
weight_map = index["weight_map"]
349+
# Find all files in weight_map whose key or value contains "mtp"
350+
mtp_weight_map = {}
351+
for k, v in weight_map.items():
352+
if "mtp" in k or "mtp" in v:
353+
mtp_weight_map.setdefault(v, []).append(k)
354+
355+
if not mtp_weight_map:
356+
return [], {}
357+
358+
def _extract_layer_prefixes(keys):
359+
mtp_layer_prefixes = set()
360+
for key in keys:
361+
parts = key.split(".")
362+
for i, part in enumerate(parts):
363+
if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
364+
prefix = ".".join(parts[: i + 2])
365+
mtp_layer_prefixes.add(prefix)
366+
break
355367

356-
if not non_standard_files:
357368
return mtp_layer_prefixes
358369

370+
# Flatten mtp_weight_map.values() (list of list of str) to a single list of str
371+
mtp_keys = [k for keys in mtp_weight_map.values() for k in keys]
372+
mtp_layer_prefixes = _extract_layer_prefixes(mtp_keys)
373+
359374
# Check which non-standard files exist and have missing weights
360375
model_state = model.state_dict()
361376
total_loaded = 0
362377

363-
for filename in non_standard_files:
378+
not_in_state_dict = {}
379+
380+
for filename, mtp_keys in mtp_weight_map.items():
364381
filepath = model_path / filename
365382
if not filepath.exists():
366383
continue
367384

368-
# Find keys that should be in this file
369-
expected_keys = [k for k, v in index["weight_map"].items() if v == filename]
370-
371-
# Check which are missing from the model
372-
missing_keys = [k for k in expected_keys if k not in model_state]
373-
374-
if not missing_keys:
375-
# Even if weights are loaded, record the layer prefixes for exclusion
376-
# Extract unique layer prefixes (e.g., "model.layers.92" from "model.layers.92.mlp.weight")
377-
for key in expected_keys:
378-
# Extract layer prefix like "model.layers.92" or "layers.92"
379-
parts = key.split(".")
380-
for i, part in enumerate(parts):
381-
if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
382-
prefix = ".".join(parts[: i + 2]) # e.g., "model.layers.92"
383-
if prefix not in mtp_layer_prefixes:
384-
mtp_layer_prefixes.append(prefix)
385-
break
386-
continue
387-
388-
print(f"Loading {len(missing_keys)} missing weights from {filename}...")
389-
390-
# Extract unique layer prefixes for exclusion from quantization
391-
for key in missing_keys:
392-
parts = key.split(".")
393-
for i, part in enumerate(parts):
394-
if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
395-
prefix = ".".join(parts[: i + 2]) # e.g., "model.layers.92"
396-
if prefix not in mtp_layer_prefixes:
397-
mtp_layer_prefixes.append(prefix)
398-
break
399-
400-
# Load the weights to CPU first, load_state_dict will handle device placement
385+
print(f"Loading {len(mtp_keys)} mtp weights from {filename}...")
401386
weights = load_file(str(filepath), device="cpu")
402-
weights_to_load = {k: v for k, v in weights.items() if k in missing_keys}
403-
404-
# Load into model
405-
missing, unexpected = model.load_state_dict(weights_to_load, strict=False)
406-
total_loaded += len(weights_to_load)
387+
weights = {k: v for k, v in weights.items() if k in mtp_keys}
388+
# Load the MTP weights to the model state dict
389+
in_state_dict = {k: weights[k] for k in weights if k in model_state}
390+
not_in_state_dict = not_in_state_dict | {
391+
k: weights[k] for k in weights if k not in model_state
392+
}
407393

408-
if missing:
409-
print(f" Warning: {len(missing)} keys still missing after loading {filename}")
394+
if in_state_dict:
395+
model.load_state_dict(in_state_dict, strict=False)
396+
total_loaded += len(in_state_dict)
410397

411398
if total_loaded > 0:
412-
print(f"✓ Successfully loaded {total_loaded} weights from non-standard safetensors files")
399+
print(
400+
f"✓ Successfully loaded {total_loaded} MTP weights, "
401+
f"{len(not_in_state_dict)} MTP weights not in model.state_dict"
402+
)
413403

414404
if mtp_layer_prefixes:
415405
print(f"✓ Detected MTP layers to exclude from quantization: {mtp_layer_prefixes}")
416406

417-
return mtp_layer_prefixes
407+
return list(mtp_layer_prefixes), not_in_state_dict
418408

419409

420410
def get_dtype(dtype):
@@ -576,12 +566,6 @@ def get_model(
576566
if device == "cuda" and not is_model_on_gpu(model):
577567
print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM")
578568

579-
# Load any missing weights from non-standard safetensors files (e.g., GLM-4.7's mtp.safetensors)
580-
# Store the MTP layer prefixes on the model for later exclusion from quantization
581-
mtp_layer_prefixes = load_mtp_weights_if_needed(model, ckpt_path)
582-
if mtp_layer_prefixes:
583-
model._mtp_layer_prefixes = mtp_layer_prefixes
584-
585569
return model
586570

587571

examples/llm_ptq/hf_ptq.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
get_tokenizer,
3232
is_enc_dec,
3333
is_nemotron_vl,
34-
load_mtp_weights_if_needed,
34+
load_mtp_weights,
3535
run_nemotron_vl_preview,
3636
)
3737
from torch.utils.data import DataLoader
@@ -349,12 +349,6 @@ def load_model(args: argparse.Namespace):
349349
)
350350
calibration_only = True
351351

352-
# Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode)
353-
# Store the MTP layer prefixes on the model for later exclusion from quantization
354-
mtp_layer_prefixes = load_mtp_weights_if_needed(full_model, args.pyt_ckpt_path)
355-
if mtp_layer_prefixes:
356-
full_model._mtp_layer_prefixes = mtp_layer_prefixes
357-
358352
model_type = get_model_type(full_model)
359353

360354
device = full_model.device
@@ -632,9 +626,17 @@ def export_quantized(
632626
"They will be set at deployment time."
633627
)
634628

629+
# Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode)
630+
# Store the MTP layer prefixes on the model for later exclusion from quantization
631+
mtp_layer_prefixes, mtp_state_dict = load_mtp_weights(full_model, args.pyt_ckpt_path)
632+
633+
if mtp_layer_prefixes:
634+
full_model._mtp_layer_prefixes = mtp_layer_prefixes
635+
635636
export_hf_checkpoint(
636637
full_model,
637638
export_dir=export_path,
639+
extra_state_dict=mtp_state_dict,
638640
)
639641

640642
# Copy custom model files (Python files and JSON configs) if trust_remote_code is used

modelopt/torch/export/unified_export_hf.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -960,6 +960,7 @@ def export_hf_checkpoint(
960960
export_dir: Path | str = tempfile.gettempdir(),
961961
save_modelopt_state: bool = False,
962962
components: list[str] | None = None,
963+
extra_state_dict: dict[str, torch.Tensor] | None = None,
963964
):
964965
"""Export quantized HuggingFace model checkpoint (transformers or diffusers).
965966
@@ -976,6 +977,7 @@ def export_hf_checkpoint(
976977
save_modelopt_state: Whether to save the modelopt state_dict.
977978
components: Only used for diffusers pipelines. Optional list of component names
978979
to export. If None, all quantized components are exported.
980+
extra_state_dict: Extra state dictionary to add to the exported model.
979981
"""
980982
export_dir = Path(export_dir)
981983
export_dir.mkdir(parents=True, exist_ok=True)
@@ -1012,7 +1014,9 @@ def export_hf_checkpoint(
10121014

10131015
# Save model
10141016
model.save_pretrained(
1015-
export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
1017+
export_dir,
1018+
state_dict={**post_state_dict, **(extra_state_dict or {})},
1019+
save_modelopt_state=save_modelopt_state,
10161020
)
10171021

10181022
original_config = f"{export_dir}/config.json"

0 commit comments

Comments
 (0)