diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index e8f5575d36..93687a8d01 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -18,7 +18,6 @@ import inspect import json import os -import re import shutil import sys import warnings @@ -317,8 +316,10 @@ def get_processor( return None -def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> list[str]: - """Load MTP weights from separate safetensors if needed (e.g., GLM-4.7). +def load_mtp_weights( + model: torch.nn.Module, model_path: str +) -> tuple[list[str], dict[str, torch.Tensor]]: + """Load MTP weights from the model checkpoint. Some models store additional layers in separate safetensors files with non-standard names (e.g., mtp.safetensors). HuggingFace's from_pretrained() may not load these @@ -334,87 +335,76 @@ def load_mtp_weights_if_needed(model: torch.nn.Module, model_path: str) -> list[ List of layer prefixes that were loaded from non-standard safetensors files. These layers should typically be excluded from quantization. Empty list if no additional weights were loaded. + Dictionary of MTP weights that were not loaded into the model state dict. 
""" model_path = Path(model_path) index_file = model_path / "model.safetensors.index.json" - mtp_layer_prefixes: list[str] = [] if not index_file.exists(): - return mtp_layer_prefixes + return [], {} # Load the index to find all referenced safetensors files - with open(index_file) as f: - index = json.load(f) - - # Find all unique safetensors files referenced - all_files = set(index["weight_map"].values()) - - # Find non-standard shard files (not matching model-XXXXX-of-XXXXX.safetensors pattern) - standard_pattern = re.compile(r"model-\d{5}-of-\d{5}\.safetensors") - non_standard_files = [f for f in all_files if not standard_pattern.match(f)] + index = json.load(open(index_file)) + weight_map = index["weight_map"] + # Find all files in weight_map whose key or value contains "mtp" + mtp_weight_map = {} + for k, v in weight_map.items(): + if "mtp" in k or "mtp" in v: + mtp_weight_map.setdefault(v, []).append(k) + + if not mtp_weight_map: + return [], {} + + def _extract_layer_prefixes(keys): + mtp_layer_prefixes = set() + for key in keys: + parts = key.split(".") + for i, part in enumerate(parts): + if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit(): + prefix = ".".join(parts[: i + 2]) + mtp_layer_prefixes.add(prefix) + break - if not non_standard_files: return mtp_layer_prefixes + # Flatten mtp_weight_map.values() (list of list of str) to a single list of str + mtp_keys = [k for keys in mtp_weight_map.values() for k in keys] + mtp_layer_prefixes = _extract_layer_prefixes(mtp_keys) + # Check which non-standard files exist and have missing weights model_state = model.state_dict() total_loaded = 0 - for filename in non_standard_files: + not_in_state_dict = {} + + for filename, mtp_keys in mtp_weight_map.items(): filepath = model_path / filename if not filepath.exists(): continue - # Find keys that should be in this file - expected_keys = [k for k, v in index["weight_map"].items() if v == filename] - - # Check which are missing from the model - 
missing_keys = [k for k in expected_keys if k not in model_state] - - if not missing_keys: - # Even if weights are loaded, record the layer prefixes for exclusion - # Extract unique layer prefixes (e.g., "model.layers.92" from "model.layers.92.mlp.weight") - for key in expected_keys: - # Extract layer prefix like "model.layers.92" or "layers.92" - parts = key.split(".") - for i, part in enumerate(parts): - if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit(): - prefix = ".".join(parts[: i + 2]) # e.g., "model.layers.92" - if prefix not in mtp_layer_prefixes: - mtp_layer_prefixes.append(prefix) - break - continue - - print(f"Loading {len(missing_keys)} missing weights from (unknown)...") - - # Extract unique layer prefixes for exclusion from quantization - for key in missing_keys: - parts = key.split(".") - for i, part in enumerate(parts): - if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit(): - prefix = ".".join(parts[: i + 2]) # e.g., "model.layers.92" - if prefix not in mtp_layer_prefixes: - mtp_layer_prefixes.append(prefix) - break - - # Load the weights to CPU first, load_state_dict will handle device placement + print(f"Loading {len(mtp_keys)} MTP weights from {filename}...") weights = load_file(str(filepath), device="cpu") - weights_to_load = {k: v for k, v in weights.items() if k in missing_keys} - - # Load into model - missing, unexpected = model.load_state_dict(weights_to_load, strict=False) - total_loaded += len(weights_to_load) + weights = {k: v for k, v in weights.items() if k in mtp_keys} + # Load the MTP weights to the model state dict + in_state_dict = {k: weights[k] for k in weights if k in model_state} + not_in_state_dict = not_in_state_dict | { + k: weights[k] for k in weights if k not in model_state + } - if missing: - print(f" Warning: {len(missing)} keys still missing after loading (unknown)") + if in_state_dict: + model.load_state_dict(in_state_dict, strict=False) + total_loaded += len(in_state_dict) if 
total_loaded > 0: - print(f"✓ Successfully loaded {total_loaded} weights from non-standard safetensors files") + print( + f"✓ Successfully loaded {total_loaded} MTP weights, " + f"{len(not_in_state_dict)} MTP weights not in model.state_dict" + ) if mtp_layer_prefixes: print(f"✓ Detected MTP layers to exclude from quantization: {mtp_layer_prefixes}") - return mtp_layer_prefixes + return list(mtp_layer_prefixes), not_in_state_dict def get_dtype(dtype): @@ -576,12 +566,6 @@ def get_model( if device == "cuda" and not is_model_on_gpu(model): print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM") - # Load any missing weights from non-standard safetensors files (e.g., GLM-4.7's mtp.safetensors) - # Store the MTP layer prefixes on the model for later exclusion from quantization - mtp_layer_prefixes = load_mtp_weights_if_needed(model, ckpt_path) - if mtp_layer_prefixes: - model._mtp_layer_prefixes = mtp_layer_prefixes - return model diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index a5af5e97d4..d9a6ca8934 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -31,7 +31,7 @@ get_tokenizer, is_enc_dec, is_nemotron_vl, - load_mtp_weights_if_needed, + load_mtp_weights, run_nemotron_vl_preview, ) from torch.utils.data import DataLoader @@ -349,12 +349,6 @@ def load_model(args: argparse.Namespace): ) calibration_only = True - # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode) - # Store the MTP layer prefixes on the model for later exclusion from quantization - mtp_layer_prefixes = load_mtp_weights_if_needed(full_model, args.pyt_ckpt_path) - if mtp_layer_prefixes: - full_model._mtp_layer_prefixes = mtp_layer_prefixes - model_type = get_model_type(full_model) device = full_model.device @@ -632,9 +626,17 @@ def export_quantized( "They will be set at deployment time." 
) + # Load MTP weights from non-standard safetensors files (e.g., GLM-4.7's mtp.safetensors) + # Store the MTP layer prefixes on the model for later exclusion from quantization + mtp_layer_prefixes, mtp_state_dict = load_mtp_weights(full_model, args.pyt_ckpt_path) + + if mtp_layer_prefixes: + full_model._mtp_layer_prefixes = mtp_layer_prefixes + export_hf_checkpoint( full_model, export_dir=export_path, + extra_state_dict=mtp_state_dict, ) # Copy custom model files (Python files and JSON configs) if trust_remote_code is used diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 61bebb51da..5703f45156 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -960,6 +960,7 @@ def export_hf_checkpoint( export_dir: Path | str = tempfile.gettempdir(), save_modelopt_state: bool = False, components: list[str] | None = None, + extra_state_dict: dict[str, torch.Tensor] | None = None, ): """Export quantized HuggingFace model checkpoint (transformers or diffusers). @@ -976,6 +977,7 @@ def export_hf_checkpoint( save_modelopt_state: Whether to save the modelopt state_dict. components: Only used for diffusers pipelines. Optional list of component names to export. If None, all quantized components are exported. + extra_state_dict: Extra state dictionary to add to the exported model. """ export_dir = Path(export_dir) export_dir.mkdir(parents=True, exist_ok=True) @@ -1012,7 +1014,9 @@ def export_hf_checkpoint( # Save model model.save_pretrained( - export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state + export_dir, + state_dict={**post_state_dict, **(extra_state_dict or {})}, + save_modelopt_state=save_modelopt_state, ) original_config = f"{export_dir}/config.json"