Skip to content

Commit 277a990

Browse files
cjluo-nv and Edwardf0t1
authored and committed
Support Qwen3 Next MTP load and export (#860)
Fix MTP export for Qwen3 Next

**Overview:** For Qwen3 Next, the MTP weights are not stored in a separate safetensors file, so we check whether a weight key contains "mtp" to decide if the weight belongs to MTP.

Run Qwen3 Next PTQ and check that the MTP weights are in the exported checkpoint:

scripts/huggingface_example.sh --model <Qwen3-Next-80B-A3B-Instruct/Thinking> --quant nvfp4 --trust_remote_code

<!-- If you haven't finished some of the above items you can still open `Draft` PR. -->

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No <!--- If No, explain why. -->
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. -->

<!-- E.g. related issue. -->

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

* **Refactor**
  * Optimized Multi-Token Prediction weight loading with improved layer detection and handling.
* **Chores**
  * Simplified status reporting to display total loaded weights and detected layers.
  * Removed verbose per-file warnings for cleaner console output.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
Co-authored-by: Zhiyu <zhiyuc@nvidia.com>
1 parent dec7161 commit 277a990

File tree

3 files changed

+62
-72
lines changed

3 files changed

+62
-72
lines changed

examples/llm_ptq/example_utils.py

Lines changed: 48 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import inspect
1919
import json
2020
import os
21-
import re
2221
import shutil
2322
import sys
2423
import warnings
@@ -317,8 +316,10 @@ def get_processor(
317316
return None
318317

319318

320-
def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> list[str]:
321-
"""Load MTP weights from separate safetensors if needed (e.g., GLM-4.7).
319+
def load_mtp_weights(
320+
model: torch.nn.Module, model_path: str
321+
) -> tuple[list[str], dict[str, torch.Tensor]]:
322+
"""Load MTP weights from the model checkpoint.
322323
323324
Some models store additional layers in separate safetensors files with non-standard
324325
names (e.g., mtp.safetensors). HuggingFace's from_pretrained() may not load these
@@ -334,87 +335,76 @@ def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> list[
334335
List of layer prefixes that were loaded from non-standard safetensors files.
335336
These layers should typically be excluded from quantization.
336337
Empty list if no additional weights were loaded.
338+
Dictionary of MTP weights that were not loaded into the model state dict.
337339
"""
338340
model_path = Path(model_path)
339341
index_file = model_path / "model.safetensors.index.json"
340-
mtp_layer_prefixes: list[str] = []
341342

342343
if not index_file.exists():
343-
return mtp_layer_prefixes
344+
return [], {}
344345

345346
# Load the index to find all referenced safetensors files
346-
with open(index_file) as f:
347-
index = json.load(f)
348-
349-
# Find all unique safetensors files referenced
350-
all_files = set(index["weight_map"].values())
351-
352-
# Find non-standard shard files (not matching model-XXXXX-of-XXXXX.safetensors pattern)
353-
standard_pattern = re.compile(r"model-\d{5}-of-\d{5}\.safetensors")
354-
non_standard_files = [f for f in all_files if not standard_pattern.match(f)]
347+
index = json.load(open(index_file))
348+
weight_map = index["weight_map"]
349+
# Find all files in weight_map whose key or value contains "mtp"
350+
mtp_weight_map = {}
351+
for k, v in weight_map.items():
352+
if "mtp" in k or "mtp" in v:
353+
mtp_weight_map.setdefault(v, []).append(k)
354+
355+
if not mtp_weight_map:
356+
return [], {}
357+
358+
def _extract_layer_prefixes(keys):
359+
mtp_layer_prefixes = set()
360+
for key in keys:
361+
parts = key.split(".")
362+
for i, part in enumerate(parts):
363+
if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
364+
prefix = ".".join(parts[: i + 2])
365+
mtp_layer_prefixes.add(prefix)
366+
break
355367

356-
if not non_standard_files:
357368
return mtp_layer_prefixes
358369

370+
# Flatten mtp_weight_map.values() (list of list of str) to a single list of str
371+
mtp_keys = [k for keys in mtp_weight_map.values() for k in keys]
372+
mtp_layer_prefixes = _extract_layer_prefixes(mtp_keys)
373+
359374
# Check which non-standard files exist and have missing weights
360375
model_state = model.state_dict()
361376
total_loaded = 0
362377

363-
for filename in non_standard_files:
378+
not_in_state_dict = {}
379+
380+
for filename, mtp_keys in mtp_weight_map.items():
364381
filepath = model_path / filename
365382
if not filepath.exists():
366383
continue
367384

368-
# Find keys that should be in this file
369-
expected_keys = [k for k, v in index["weight_map"].items() if v == filename]
370-
371-
# Check which are missing from the model
372-
missing_keys = [k for k in expected_keys if k not in model_state]
373-
374-
if not missing_keys:
375-
# Even if weights are loaded, record the layer prefixes for exclusion
376-
# Extract unique layer prefixes (e.g., "model.layers.92" from "model.layers.92.mlp.weight")
377-
for key in expected_keys:
378-
# Extract layer prefix like "model.layers.92" or "layers.92"
379-
parts = key.split(".")
380-
for i, part in enumerate(parts):
381-
if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
382-
prefix = ".".join(parts[: i + 2]) # e.g., "model.layers.92"
383-
if prefix not in mtp_layer_prefixes:
384-
mtp_layer_prefixes.append(prefix)
385-
break
386-
continue
387-
388-
print(f"Loading {len(missing_keys)} missing weights from {filename}...")
389-
390-
# Extract unique layer prefixes for exclusion from quantization
391-
for key in missing_keys:
392-
parts = key.split(".")
393-
for i, part in enumerate(parts):
394-
if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
395-
prefix = ".".join(parts[: i + 2]) # e.g., "model.layers.92"
396-
if prefix not in mtp_layer_prefixes:
397-
mtp_layer_prefixes.append(prefix)
398-
break
399-
400-
# Load the weights to CPU first, load_state_dict will handle device placement
385+
print(f"Loading {len(mtp_keys)} mtp weights from {filename}...")
401386
weights = load_file(str(filepath), device="cpu")
402-
weights_to_load = {k: v for k, v in weights.items() if k in missing_keys}
403-
404-
# Load into model
405-
missing, unexpected = model.load_state_dict(weights_to_load, strict=False)
406-
total_loaded += len(weights_to_load)
387+
weights = {k: v for k, v in weights.items() if k in mtp_keys}
388+
# Load the MTP weights to the model state dict
389+
in_state_dict = {k: weights[k] for k in weights if k in model_state}
390+
not_in_state_dict = not_in_state_dict | {
391+
k: weights[k] for k in weights if k not in model_state
392+
}
407393

408-
if missing:
409-
print(f" Warning: {len(missing)} keys still missing after loading {filename}")
394+
if in_state_dict:
395+
model.load_state_dict(in_state_dict, strict=False)
396+
total_loaded += len(in_state_dict)
410397

411398
if total_loaded > 0:
412-
print(f"✓ Successfully loaded {total_loaded} weights from non-standard safetensors files")
399+
print(
400+
f"✓ Successfully loaded {total_loaded} MTP weights, "
401+
f"{len(not_in_state_dict)} MTP weights not in model.state_dict"
402+
)
413403

414404
if mtp_layer_prefixes:
415405
print(f"✓ Detected MTP layers to exclude from quantization: {mtp_layer_prefixes}")
416406

417-
return mtp_layer_prefixes
407+
return list(mtp_layer_prefixes), not_in_state_dict
418408

419409

420410
def get_dtype(dtype):
@@ -576,12 +566,6 @@ def get_model(
576566
if device == "cuda" and not is_model_on_gpu(model):
577567
print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM")
578568

579-
# Load any missing weights from non-standard safetensors files (e.g., GLM-4.7's mtp.safetensors)
580-
# Store the MTP layer prefixes on the model for later exclusion from quantization
581-
mtp_layer_prefixes = load_mtp_weights_if_needed(model, ckpt_path)
582-
if mtp_layer_prefixes:
583-
model._mtp_layer_prefixes = mtp_layer_prefixes
584-
585569
return model
586570

587571

examples/llm_ptq/hf_ptq.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
get_tokenizer,
3232
is_enc_dec,
3333
is_nemotron_vl,
34-
load_mtp_weights_if_needed,
34+
load_mtp_weights,
3535
run_nemotron_vl_preview,
3636
)
3737
from torch.utils.data import DataLoader
@@ -359,12 +359,6 @@ def load_model(args: argparse.Namespace):
359359
)
360360
calibration_only = True
361361

362-
# Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode)
363-
# Store the MTP layer prefixes on the model for later exclusion from quantization
364-
mtp_layer_prefixes = load_mtp_weights_if_needed(full_model, args.pyt_ckpt_path)
365-
if mtp_layer_prefixes:
366-
full_model._mtp_layer_prefixes = mtp_layer_prefixes
367-
368362
model_type = get_model_type(full_model)
369363

370364
device = full_model.device
@@ -720,9 +714,17 @@ def _compute_perplexity(model, data, batch_size: int = 1):
720714
print(f"Saving model to {args.export_path}")
721715
full_model.save_pretrained(args.export_path)
722716
else:
717+
# Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode)
718+
# Store the MTP layer prefixes on the model for later exclusion from quantization
719+
mtp_layer_prefixes, mtp_state_dict = load_mtp_weights(full_model, args.pyt_ckpt_path)
720+
721+
if mtp_layer_prefixes:
722+
full_model._mtp_layer_prefixes = mtp_layer_prefixes
723+
723724
export_hf_checkpoint(
724725
full_model,
725726
export_dir=export_path,
727+
extra_state_dict=mtp_state_dict,
726728
)
727729

728730
# Copy custom model files (Python files and JSON configs) if trust_remote_code is used

modelopt/torch/export/unified_export_hf.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -960,6 +960,7 @@ def export_hf_checkpoint(
960960
export_dir: Path | str = tempfile.gettempdir(),
961961
save_modelopt_state: bool = False,
962962
components: list[str] | None = None,
963+
extra_state_dict: dict[str, torch.Tensor] | None = None,
963964
):
964965
"""Export quantized HuggingFace model checkpoint (transformers or diffusers).
965966
@@ -976,6 +977,7 @@ def export_hf_checkpoint(
976977
save_modelopt_state: Whether to save the modelopt state_dict.
977978
components: Only used for diffusers pipelines. Optional list of component names
978979
to export. If None, all quantized components are exported.
980+
extra_state_dict: Extra state dictionary to add to the exported model.
979981
"""
980982
export_dir = Path(export_dir)
981983
export_dir.mkdir(parents=True, exist_ok=True)
@@ -1012,7 +1014,9 @@ def export_hf_checkpoint(
10121014

10131015
# Save model
10141016
model.save_pretrained(
1015-
export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
1017+
export_dir,
1018+
state_dict={**post_state_dict, **(extra_state_dict or {})},
1019+
save_modelopt_state=save_modelopt_state,
10161020
)
10171021

10181022
original_config = f"{export_dir}/config.json"

0 commit comments

Comments
 (0)