Commit 01c44e9

Fix

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
1 parent ae4ae22

2 files changed: 12 additions & 25 deletions

examples/llm_ptq/example_utils.py

Lines changed: 3 additions & 10 deletions
@@ -316,10 +316,10 @@ def get_processor(
     return None
 
 
-def load_mtp_weights_if_needed(
+def load_mtp_weights(
     model: torch.nn.Module, model_path: str
 ) -> tuple[list[str], dict[str, torch.Tensor]]:
-    """Load MTP weights from separate safetensors if needed (e.g., GLM-4.7).
+    """Load MTP weights from the model checkpoint.
 
     Some models store additional layers in separate safetensors files with non-standard
     names (e.g., mtp.safetensors). HuggingFace's from_pretrained() may not load these
@@ -335,6 +335,7 @@ def load_mtp_weights_if_needed(
         List of layer prefixes that were loaded from non-standard safetensors files.
         These layers should typically be excluded from quantization.
         Empty list if no additional weights were loaded.
+        Dictionary of MTP weights that were not loaded into the model state dict.
     """
     model_path = Path(model_path)
     index_file = model_path / "model.safetensors.index.json"
@@ -565,14 +566,6 @@ def get_model(
     if device == "cuda" and not is_model_on_gpu(model):
         print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM")
 
-    # Load any missing weights from non-standard safetensors files (e.g., GLM-4.7's mtp.safetensors)
-    # Store the MTP layer prefixes on the model for later exclusion from quantization
-    mtp_layer_prefixes, mtp_state_dict = load_mtp_weights_if_needed(model, ckpt_path)
-    if mtp_layer_prefixes:
-        model._mtp_layer_prefixes = mtp_layer_prefixes
-    if mtp_state_dict:
-        model._mtp_state_dict = mtp_state_dict
-
     return model
 
 
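For orientation, here is a minimal sketch of a loader that satisfies the (prefixes, leftover tensors) contract described in the docstring above. Everything below is assembled from the docstring alone: the name load_extra_safetensors, the strict=False load, and the module-prefix heuristic are assumptions for illustration, not this repository's implementation.

import json
from pathlib import Path

import torch
from safetensors.torch import load_file


def load_extra_safetensors(
    model: torch.nn.Module, model_path: str
) -> tuple[list[str], dict[str, torch.Tensor]]:
    """Hypothetical loader matching the (prefixes, leftover-tensors) contract."""
    root = Path(model_path)
    index_file = root / "model.safetensors.index.json"

    # Shards listed in the index were already loaded by from_pretrained().
    indexed: set[str] = set()
    if index_file.exists():
        indexed = set(json.loads(index_file.read_text())["weight_map"].values())

    prefixes: set[str] = set()
    leftover: dict[str, torch.Tensor] = {}
    for shard in sorted(root.glob("*.safetensors")):
        if shard.name in indexed:
            continue  # standard shard, nothing extra to do
        tensors = load_file(str(shard))
        # Keys that match the model land in its state dict; unmatched keys
        # are handed back so the caller can re-attach them at export time.
        result = model.load_state_dict(tensors, strict=False)
        loaded = tensors.keys() - set(result.unexpected_keys)
        prefixes |= {k.rsplit(".", 1)[0] for k in loaded}  # crude module prefix
        leftover |= {k: tensors[k] for k in result.unexpected_keys}
    return sorted(prefixes), leftover

Returning the unmatched tensors instead of dropping them is what lets the export path below re-attach them via extra_state_dict.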
examples/llm_ptq/hf_ptq.py

Lines changed: 9 additions & 15 deletions
@@ -31,7 +31,7 @@
     get_tokenizer,
     is_enc_dec,
     is_nemotron_vl,
-    load_mtp_weights_if_needed,
+    load_mtp_weights,
     run_nemotron_vl_preview,
 )
 from torch.utils.data import DataLoader
@@ -349,17 +349,6 @@ def load_model(args: argparse.Namespace):
         )
         calibration_only = True
 
-    # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode)
-    # Store the MTP layer prefixes on the model for later exclusion from quantization
-    mtp_layer_prefixes, mtp_state_dict = load_mtp_weights_if_needed(
-        full_model, args.pyt_ckpt_path
-    )
-    if mtp_layer_prefixes:
-        full_model._mtp_layer_prefixes = mtp_layer_prefixes
-
-    if mtp_state_dict:
-        full_model._mtp_state_dict = mtp_state_dict
-
     model_type = get_model_type(full_model)
 
     device = full_model.device
@@ -637,12 +626,17 @@ def export_quantized(
             "They will be set at deployment time."
         )
 
+    # Load any missing weights from non-standard safetensors (handled in get_model for non-low-memory mode)
+    # Store the MTP layer prefixes on the model for later exclusion from quantization
+    mtp_layer_prefixes, mtp_state_dict = load_mtp_weights(full_model, args.pyt_ckpt_path)
+
+    if mtp_layer_prefixes:
+        full_model._mtp_layer_prefixes = mtp_layer_prefixes
+
     export_hf_checkpoint(
         full_model,
         export_dir=export_path,
-        extra_state_dict=full_model._mtp_state_dict
-        if hasattr(full_model, "_mtp_state_dict")
-        else None,
+        extra_state_dict=mtp_state_dict,
     )
 
     # Copy custom model files (Python files and JSON configs) if trust_remote_code is used
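The docstring notes that layers loaded this way "should typically be excluded from quantization", which is why the prefixes are stashed on the model as _mtp_layer_prefixes. A hypothetical sketch of how such a prefix filter could be applied; none of this is from the commit, and should_quantize and quantizable_modules are made-up names:

import torch


def should_quantize(module_name: str, mtp_layer_prefixes: list[str]) -> bool:
    """Skip modules whose weights came from non-standard safetensors files."""
    return not any(module_name.startswith(p) for p in mtp_layer_prefixes)


def quantizable_modules(model: torch.nn.Module) -> list[str]:
    # _mtp_layer_prefixes is set by the code above; fall back to an empty
    # list for models that have no extra MTP shards.
    prefixes = getattr(model, "_mtp_layer_prefixes", [])
    return [
        name
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear) and should_quantize(name, prefixes)
    ]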
