Skip to content

Commit aa77565

Browse files
committed
Support export to hf format
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent e3337a0 commit aa77565

5 files changed

Lines changed: 167 additions & 101 deletions

File tree

examples/llm_ptq/hf_ptq.py

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -394,28 +394,11 @@ def load_model(args: argparse.Namespace):
394394
attn_implementation=args.attn_implementation,
395395
)
396396

397+
# Uncomment this to load the model from a .pt file
398+
# model = mto.restore(model, "./qwen3_omni_30b_nvfp4/model.pt")
399+
# print("Qwen3Omni model restored from checkpoint")
400+
397401
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
398-
# Qwen3 specific quantizer disabling patterns (thinker.model.layers only)
399-
if "qkv_disabled" in args.qformat:
400-
# Disable q_proj, k_proj, v_proj quantizers
401-
for proj in ["q_proj", "k_proj", "v_proj"]:
402-
quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = {
403-
"enable": False
404-
}
405-
if "qkvo_disabled" in args.qformat:
406-
# Disable q_proj, k_proj, v_proj, o_proj quantizers
407-
for proj in ["o_proj"]:
408-
quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = {
409-
"enable": False
410-
}
411-
if "first_and_last_n_disabled" in args.qformat:
412-
# Disable both first N and last N layers
413-
total_layers = 48
414-
n_layers_to_disable = 4
415-
for i in range(n_layers_to_disable):
416-
quant_cfg["quant_cfg"][f"*thinker.model.layers.{i}.*"] = {"enable": False}
417-
for i in range(total_layers - n_layers_to_disable, total_layers):
418-
quant_cfg["quant_cfg"][f"*thinker.model.layers.{i}.*"] = {"enable": False}
419402
else:
420403
assert args.qformat in QUANT_CFG_CHOICES, (
421404
f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"
@@ -637,6 +620,37 @@ def mono_quantize(
637620
if language_model_lineage is not None:
638621
print("Updating full_model with quantized language_model...")
639622
language_model_lineage[-2].language_model = language_model
623+
624+
# Qwen3 specific quantizer disabling patterns (thinker.model.layers only)
625+
if "qkv_disabled" in args.qformat:
626+
# Disable q_proj, k_proj, v_proj quantizers
627+
for proj in ["q_proj", "k_proj", "v_proj"]:
628+
quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = {
629+
"enable": False
630+
}
631+
if "qkvo_disabled" in args.qformat:
632+
# Disable o_proj quantizers (this loop lists only o_proj; q_proj/k_proj/v_proj are disabled solely by the "qkv_disabled" branch above)
633+
for proj in ["o_proj"]:
634+
quant_cfg["quant_cfg"][f"*thinker.model.layers.*.self_attn.{proj}*"] = {
635+
"enable": False
636+
}
637+
if "first_and_last_n_disabled" in args.qformat:
638+
# Disable both first N and last N layers
639+
total_layers = 48
640+
n_layers_to_disable = 4
641+
for i in range(n_layers_to_disable):
642+
quant_cfg["quant_cfg"][f"*thinker.model.layers.{i}.*"] = {"enable": False}
643+
for i in range(total_layers - n_layers_to_disable, total_layers):
644+
quant_cfg["quant_cfg"][f"*thinker.model.layers.{i}.*"] = {"enable": False}
645+
646+
if not model_is_already_quantized or calibration_only:
647+
# Only run a single sample for the preview
648+
calib_batch = next(iter(calib_dataloader))
649+
input_ids = calib_batch["input_features" if model_type == "whisper" else "input_ids"][
650+
0:1
651+
]
652+
653+
# Generate preview before quantization
640654
if is_nemotron_vl_model and tokenizer is not None:
641655
generated_ids_before_ptq = run_nemotron_vl_preview(
642656
full_model,
@@ -771,11 +785,11 @@ def export_quantized(
771785
default_padding_side,
772786
default_pad_token,
773787
):
774-
if model_type == "qwen3omni":
775-
print("Export of Qwen3Omni model is not supported yet. Saving .pt file instead.")
776-
os.makedirs(os.path.dirname(args.export_path), exist_ok=True)
777-
mto.save(model, args.export_path)
778-
return
788+
# Uncomment this to save the model as a .pt file
789+
# if model_type == "qwen3omni":
790+
# print("Export of Qwen3Omni model is not supported yet. Saving .pt file instead.")
791+
# os.makedirs(os.path.dirname(args.export_path), exist_ok=True)
792+
# mto.save(full_model, f"{args.export_path}/model.pt")
779793

780794
with torch.inference_mode():
781795
if model_type is None:
@@ -857,6 +871,7 @@ def export_quantized(
857871
export_hf_checkpoint(
858872
full_model,
859873
export_dir=export_path,
874+
save_modelopt_state=model_type == "qwen3omni",
860875
)
861876

862877
# Copy custom model files (Python files and JSON configs) if trust_remote_code is used

examples/llm_ptq/run_quantized_qwen3omni.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1717
# SPDX-License-Identifier: Apache-2.0
1818

19-
"""Script to load and run a quantized Qwen3Omni model from mto checkpoint."""
19+
"""Script to load and run a quantized Qwen3Omni model from export_hf_checkpoint."""
2020

2121
import argparse
2222
import time
@@ -27,38 +27,41 @@
2727

2828
import modelopt.torch.opt as mto
2929

30+
# Enable HuggingFace checkpointing for modelopt quantized models
31+
mto.enable_huggingface_checkpointing()
32+
3033

3134
def main(args):
32-
print(f"Loading base model from {args.model_path}...")
35+
print(f"Loading quantized model from {args.checkpoint_path}...")
3336
model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
34-
args.model_path,
37+
args.checkpoint_path,
3538
torch_dtype="auto",
36-
device_map="cuda",
39+
device_map="auto",
3740
attn_implementation="flash_attention_2",
3841
trust_remote_code=True,
3942
)
4043

41-
print(f"Restoring quantized state from {args.checkpoint_path}...")
42-
model = mto.restore(model, args.checkpoint_path)
43-
4444
model.disable_talker()
4545

4646
print("Loading processor...")
4747
processor = Qwen3OmniMoeProcessor.from_pretrained(
48-
args.model_path,
48+
"Qwen/Qwen3-Omni-30B-A3B-Thinking",
4949
trust_remote_code=True,
5050
)
5151

5252
# Build conversation with user prompt
5353
prompt = args.prompt or "What is the capital of France?"
54-
conversation = [{"role": "user", "content": [{"type": "text", "text": f"{prompt}"}]}]
54+
conversation = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
5555
conversations = [conversation]
5656

5757
# Set whether to use audio in video
5858
use_audio_in_video = True
5959

6060
# Preparation for inference
61-
texts = processor.apply_chat_template(conversations, add_generation_prompt=True, tokenize=False)
61+
texts = processor.apply_chat_template(
62+
conversations, add_generation_prompt=True, tokenize=False, enable_thinking=False
63+
)
64+
print(f"Texts: {texts}")
6265
audios, images, videos = process_mm_info(conversations, use_audio_in_video=use_audio_in_video)
6366

6467
inputs = processor(
@@ -99,17 +102,11 @@ def main(args):
99102

100103
if __name__ == "__main__":
101104
parser = argparse.ArgumentParser(description="Run quantized Qwen3Omni model")
102-
parser.add_argument(
103-
"--model_path",
104-
type=str,
105-
default="Qwen/Qwen3-Omni-30B-A3B-Instruct",
106-
help="Path to the base Qwen3Omni model (HF format)",
107-
)
108105
parser.add_argument(
109106
"--checkpoint_path",
110107
type=str,
111-
default="/home/scratch.arasane_hw/models/qwen3omni_nvfp4_qkv_disabled_text_bs512_calib512.pt",
112-
help="Path to the mto.save() quantized checkpoint",
108+
required=True,
109+
help="Path to the export_hf_checkpoint() quantized checkpoint directory",
113110
)
114111
parser.add_argument(
115112
"--prompt",

modelopt/torch/export/unified_export_hf.py

Lines changed: 82 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -300,29 +300,43 @@ def llm_dummy_forward():
300300
[1, model.config.num_mel_bins, feature_extractor.nb_max_frames], dtype=model.dtype
301301
).to(model.device)
302302

303-
if getattr(model.config, "is_encoder_decoder", False):
304-
# For encoder-decoder models, we need to pass both the encoder and decoder input ids
305-
model(fake_input, decoder_input_ids=decoder_fake_input)
306-
elif is_vl_model and "nemotron" in model_type:
307-
# For Nemotron VL models, try to run optimization on just the language model part
308-
language_model_lineage = get_language_model_from_vl(model)
309-
310-
if language_model_lineage is not None:
311-
# Run optimization on just the language model with the same input format as regular LLMs
312-
# Use the same fake_input tensor that regular LLMs use
313-
language_model = language_model_lineage[-1]
314-
print(
315-
f"Running optimization on language model with fake_input shape: {fake_input.shape}"
316-
)
317-
language_model(fake_input)
303+
with set_quantizer_by_cfg_context(model, {"*": {"enable": False}}):
304+
if getattr(model.config, "is_encoder_decoder", False):
305+
# For encoder-decoder models, we need to pass both the encoder and decoder input ids
306+
model(fake_input, decoder_input_ids=decoder_fake_input)
307+
elif is_vl_model and "nemotron" in model_type:
308+
# For Nemotron VL models, try to run optimization on just the language model part
309+
language_model_lineage = get_language_model_from_vl(model)
310+
311+
if language_model_lineage is not None:
312+
# Run optimization on just the language model with the same input format as regular LLMs
313+
# Use the same fake_input tensor that regular LLMs use
314+
language_model = language_model_lineage[-1]
315+
print(
316+
f"Running optimization on language model with fake_input shape: {fake_input.shape}"
317+
)
318+
language_model(fake_input)
319+
else:
320+
raise ValueError(
321+
f"Cannot extract language_model from Nemotron VL model (type: {model_type}). "
322+
"This is required for requantization/resmoothing optimization. "
323+
"Please ensure the model architecture is supported or file an issue."
324+
)
325+
elif "qwen3omni" in model_type:
326+
# For Qwen3Omni, run on the thinker (language model) component
327+
# The model has structure: model.thinker.model.layers.*
328+
if hasattr(model, "thinker"):
329+
print(
330+
f"Running optimization on Qwen3Omni thinker with fake_input shape: {fake_input.shape}"
331+
)
332+
model.thinker(fake_input)
333+
else:
334+
raise ValueError(
335+
f"Cannot extract thinker from Qwen3Omni model (type: {model_type}). "
336+
"This is required for requantization/resmoothing optimization."
337+
)
318338
else:
319-
raise ValueError(
320-
f"Cannot extract language_model from Nemotron VL model (type: {model_type}). "
321-
"This is required for requantization/resmoothing optimization. "
322-
"Please ensure the model architecture is supported or file an issue."
323-
)
324-
else:
325-
model(fake_input)
339+
model(fake_input)
326340

327341
input_to_linear, output_to_layernorm = _collect_shared_input_modules(
328342
model, llm_dummy_forward, collect_layernorms=True
@@ -380,6 +394,19 @@ def _export_quantized_weight(
380394
weight_quantizer: TensorQuantizer | SequentialQuantizer = getattr(
381395
sub_module, quantizer_attrs.weight_quantizer
382396
)
397+
398+
# Skip export if weight quantizer is disabled or has no amax (not calibrated)
399+
if not _is_enabled_quantizer(weight_quantizer):
400+
return
401+
402+
# Check if weight quantizer has calibrated amax
403+
def _has_amax(quantizer):
404+
if isinstance(quantizer, SequentialQuantizer):
405+
return any(hasattr(q, "_amax") and q._amax is not None for q in quantizer)
406+
return hasattr(quantizer, "_amax") and quantizer._amax is not None
407+
408+
if not _has_amax(weight_quantizer):
409+
return
383410
input_quantizer: TensorQuantizer | SequentialQuantizer | None = getattr(
384411
sub_module, quantizer_attrs.input_quantizer, None
385412
)
@@ -543,6 +570,7 @@ def _process_quantized_modules(
543570
model: nn.Module,
544571
dtype: torch.dtype,
545572
is_modelopt_qlora: bool = False,
573+
pack_weights: bool = True,
546574
) -> None:
547575
"""Process all quantized modules in model, export weights in-place.
548576
@@ -555,6 +583,7 @@ def _process_quantized_modules(
555583
dtype: The data type for weight conversion.
556584
is_modelopt_qlora: Whether the model is a modelopt-trained QLoRA model.
557585
If True, modules with base_layer attribute are skipped.
586+
pack_weights: Whether to pack quantized weights.
558587
"""
559588
fsdp_module_to_reshard = None
560589

@@ -577,8 +606,9 @@ def _process_quantized_modules(
577606
sub_module.unpack_weight()
578607
if get_quantization_format(sub_module) != QUANTIZATION_NONE:
579608
if is_quantlinear(sub_module):
580-
with fsdp2_aware_weight_update(model, sub_module, reshard=False):
581-
_export_quantized_weight(sub_module, dtype)
609+
if pack_weights:
610+
with fsdp2_aware_weight_update(model, sub_module, reshard=False):
611+
_export_quantized_weight(sub_module, dtype)
582612
elif (
583613
"Llama4TextExperts" in type(sub_module).__name__
584614
or "GptOssExperts" in type(sub_module).__name__
@@ -595,13 +625,18 @@ def _process_quantized_modules(
595625
quantizer_attrs=["gate_up_proj_input_quantizer", "down_proj_input_quantizer"],
596626
)
597627
# Export the quantized weights
598-
with fsdp2_aware_weight_update(model, sub_module, reshard=False):
599-
for weight_name in ["gate_up_proj", "down_proj"]:
600-
_export_quantized_weight(sub_module, dtype, weight_name)
628+
if pack_weights:
629+
with fsdp2_aware_weight_update(model, sub_module, reshard=False):
630+
for weight_name in ["gate_up_proj", "down_proj"]:
631+
_export_quantized_weight(sub_module, dtype, weight_name)
601632

602633

603-
def _export_transformers_checkpoint(
604-
model: nn.Module, dtype: torch.dtype | None = None, is_modelopt_qlora: bool = False, **kwargs
634+
def _export_hf_checkpoint(
635+
model: nn.Module,
636+
dtype: torch.dtype | None = None,
637+
is_modelopt_qlora: bool = False,
638+
pack_weights: bool = True,
639+
**kwargs,
605640
) -> tuple[dict[str, Any], dict[str, Any]]:
606641
"""Exports the torch model to the packed checkpoint with original HF naming.
607642
@@ -611,6 +646,7 @@ def _export_transformers_checkpoint(
611646
model: the full torch model to export. The actual quantized model may be a submodule.
612647
dtype: the weights data type to export the unquantized layers or the default model data type if None.
613648
accelerator: the accelerator instance in case of distributed export setup.
649+
pack_weights: whether to pack quantized weights (False keeps original shapes for HF reload).
614650
615651
Returns:
616652
post_state_dict: Dict containing quantized weights
@@ -695,7 +731,7 @@ def _export_transformers_checkpoint(
695731
quant_config = get_quant_config(model, is_modelopt_qlora=is_modelopt_qlora)
696732

697733
# Process all quantized modules and export weights
698-
_process_quantized_modules(model, dtype, is_modelopt_qlora)
734+
_process_quantized_modules(model, dtype, is_modelopt_qlora, pack_weights)
699735

700736
if accelerator is not None:
701737
# Gather state_dict from all ranks
@@ -964,7 +1000,12 @@ def export_hf_checkpoint(
9641000
return
9651001

9661002
try:
967-
post_state_dict, hf_quant_config = _export_transformers_checkpoint(model, dtype)
1003+
# Packed weights are only for TRT-LLM consumption
1004+
# Set this to False to skip weight packing and keep the weights in their original shapes (for HF reload)
1005+
pack_weights = True
1006+
post_state_dict, hf_quant_config = _export_hf_checkpoint(
1007+
model, dtype, pack_weights=pack_weights
1008+
)
9681009

9691010
if hf_quant_config is not None:
9701011
# Save hf_quant_config.json for backward compatibility
@@ -977,6 +1018,16 @@ def export_hf_checkpoint(
9771018
if getattr(model, "hf_quantizer", None) is not None:
9781019
model.hf_quantizer = None
9791020

1021+
# Fix generation_config conflicts before saving
1022+
# Some models have temperature/top_p/top_k set but do_sample=False which causes validation errors
1023+
if hasattr(model, "generation_config") and model.generation_config is not None:
1024+
gen_config = model.generation_config
1025+
if not getattr(gen_config, "do_sample", True):
1026+
# Remove sampling-related params when do_sample is False
1027+
for attr in ["temperature", "top_p", "top_k"]:
1028+
if hasattr(gen_config, attr):
1029+
setattr(gen_config, attr, None)
1030+
9801031
# Save model
9811032
model.save_pretrained(
9821033
export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state

modelopt/torch/utils/image_processor.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -145,13 +145,9 @@ def preprocess_function(self, text: str) -> dict:
145145
Dictionary with tokenized inputs.
146146
"""
147147
# Build conversation in Qwen format (text-only)
148-
conversation = [
149-
{"role": "user", "content": [{"type": "text", "text": "/no_think " + text}]}
150-
]
151-
152-
# Apply chat template (tokenize=False to get formatted string)
148+
conversation = [{"role": "user", "content": [{"type": "text", "text": text}]}]
153149
formatted_text = self.tokenizer.apply_chat_template(
154-
conversation, add_generation_prompt=True, tokenize=False
150+
conversation, add_generation_prompt=True, tokenize=False, enable_thinking=False
155151
)
156152

157153
# Tokenize with the processor (no multimodal inputs)
@@ -212,10 +208,8 @@ def preprocess_function(self, examples):
212208
content.append({"type": "text", "text": question})
213209

214210
conversation = [{"role": "user", "content": content}]
215-
216-
# Apply chat template (tokenize=False to get string)
217211
text = self.tokenizer.apply_chat_template(
218-
conversation, add_generation_prompt=True, tokenize=False
212+
conversation, add_generation_prompt=True, tokenize=False, enable_thinking=False
219213
)
220214

221215
# Extract multimodal info using qwen_omni_utils

0 commit comments

Comments
 (0)