Commit e287c0b

Refactor model specific code to example_utils

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>

1 parent: dfedafa

6 files changed: 184 additions & 113 deletions

examples/llm_ptq/example_utils.py

Lines changed: 112 additions & 0 deletions
@@ -42,11 +42,21 @@
     snapshot_download = None

 import modelopt.torch.quantization as mtq
+from modelopt.torch.utils.dataset_utils import get_dataset_dataloader
 from modelopt.torch.utils.image_processor import (
     BaseImageProcessor,
     MllamaImageProcessor,
     Qwen3OmniImageProcessor,
 )
+from modelopt.torch.utils.video_dataset_utils import (
+    Qwen3OmniVideoProcessor,
+    get_supported_video_datasets,
+    get_video_dataset_dataloader,
+)
+from modelopt.torch.utils.vlm_dataset_utils import (
+    get_supported_vlm_datasets,
+    get_vlm_dataset_dataloader,
+)

 SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"]

@@ -244,9 +254,33 @@ def build_quant_cfg(
         quant_cfg["quant_cfg"]["*self_attn.q*"] = {"enable": False}
         quant_cfg["quant_cfg"]["*self_attn.kv*"] = {"enable": False}

+    if model_type == "qwen3omni":
+        print(
+            "Disabling quantization for conv layers, audio tower and visual encoder in Qwen3Omni model"
+        )
+        quant_cfg["quant_cfg"]["*conv*"] = {"enable": False}
+        quant_cfg["quant_cfg"]["*audio_tower*"] = {"enable": False}
+        quant_cfg["quant_cfg"]["*visual*"] = {"enable": False}
+
     return quant_cfg


+def get_generation_kwargs(model_type: str) -> dict[str, Any]:
+    """Get model-specific generation kwargs for calibration.
+
+    Args:
+        model_type: The model type string.
+
+    Returns:
+        Dictionary of generation kwargs for the model.
+    """
+    generation_kwargs = {}
+    if model_type == "qwen3omni":
+        generation_kwargs["return_audio"] = False
+        generation_kwargs["thinker_max_new_tokens"] = 1
+    return generation_kwargs
+
+
 def is_speculative(hf_config):
     """Check if the model architecture is a speculative model."""
     return hf_config.architectures and any(
@@ -617,3 +651,81 @@ def copy_custom_model_files(source_path: str, export_path: str, trust_remote_cod
         print(f"Successfully copied {len(copied_files)} custom model files to {export_path}")
     else:
         print("No custom model files found to copy")
+
+
+def get_qwen3omni_dataloader(
+    dataset_name: str | list[str] | None,
+    processor: Qwen3OmniImageProcessor | None,
+    tokenizer,
+    batch_size: int,
+    num_samples: int | list[int],
+    device: torch.device,
+    model_dtype: torch.dtype,
+    include_labels: bool = False,
+):
+    """Create a calibration dataloader for Qwen3Omni models.
+
+    Handles video, VLM, and text-only dataset configurations.
+
+    Args:
+        dataset_name: Name of the dataset(s) to use for calibration.
+        processor: The Qwen3OmniImageProcessor for multimodal inputs.
+        tokenizer: The tokenizer for text-only fallback.
+        batch_size: Batch size for the dataloader.
+        num_samples: Number of samples to use (int or list for multi-dataset).
+        device: Target device for tensors.
+        model_dtype: Model dtype for proper tensor conversion.
+        include_labels: Whether to include labels (for gradient-based auto_quantize).
+
+    Returns:
+        DataLoader for calibration.
+    """
+    if dataset_name is None:
+        dataset_name = ["cnn_dailymail", "nemotron-post-training-dataset-v2"]
+
+    if processor is not None:
+        if dataset_name in get_supported_video_datasets():
+            assert isinstance(dataset_name, str)
+            video_processor = Qwen3OmniVideoProcessor(
+                processor.tokenizer if hasattr(processor, "tokenizer") else processor,
+                device=device,
+                dtype=model_dtype,
+                use_audio_in_video=True,
+            )
+            calib_dataloader = get_video_dataset_dataloader(
+                dataset_name=dataset_name,
+                processor=video_processor,
+                batch_size=batch_size,
+                num_samples=num_samples if isinstance(num_samples, int) else num_samples[0],
+            )
+        elif dataset_name in get_supported_vlm_datasets():
+            assert isinstance(dataset_name, str)
+            assert isinstance(processor, Qwen3OmniImageProcessor), (
+                "The Qwen3OmniImageProcessor must be set."
+            )
+            # Set the dtype for proper tensor conversion in collate_function
+            processor.dtype = model_dtype
+            calib_dataloader = get_vlm_dataset_dataloader(
+                dataset_name=dataset_name,
+                processor=processor,
+                batch_size=batch_size,
+                num_samples=num_samples if isinstance(num_samples, int) else num_samples[0],
+            )
+        else:
+            raise ValueError(
+                f"Dataset '{dataset_name}' not supported for Qwen3Omni with processor. "
+                f"Supported video datasets: {get_supported_video_datasets()}, "
+                f"Supported VLM datasets: {get_supported_vlm_datasets()}"
+            )
+    else:
+        # Text-only fallback
+        calib_dataloader = get_dataset_dataloader(
+            dataset_name=dataset_name if isinstance(dataset_name, list) else [dataset_name],
+            tokenizer=tokenizer,
+            batch_size=batch_size,
+            num_samples=num_samples if isinstance(num_samples, list) else [num_samples],
+            device=device,
+            include_labels=include_labels,
+        )
+
+    return calib_dataloader
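
Note: a minimal sketch of how the two new helpers could be driven from a calibration script run in examples/llm_ptq. The checkpoint id, batch size, and sample counts are illustrative (not taken from this commit), and only the text-only fallback path (no processor) is exercised:

import torch

from example_utils import get_generation_kwargs, get_qwen3omni_dataloader, get_tokenizer

tokenizer = get_tokenizer("Qwen/Qwen3-Omni-30B-A3B-Instruct")  # hypothetical checkpoint id

# Model-specific generation kwargs: for qwen3omni this disables audio output and
# caps thinker decoding at one new token during calibration forward passes.
generation_kwargs = get_generation_kwargs("qwen3omni")
assert generation_kwargs == {"return_audio": False, "thinker_max_new_tokens": 1}

# No processor -> text-only fallback over the two default calibration datasets.
calib_dataloader = get_qwen3omni_dataloader(
    dataset_name=None,  # defaults to ["cnn_dailymail", "nemotron-post-training-dataset-v2"]
    processor=None,
    tokenizer=tokenizer,
    batch_size=4,
    num_samples=[256, 256],  # one sample count per default dataset
    device=torch.device("cuda"),
    model_dtype=torch.bfloat16,
)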

examples/llm_ptq/hf_ptq.py

Lines changed: 20 additions & 77 deletions
@@ -14,9 +14,7 @@
 # limitations under the License.

 import argparse
-import io
 import random
-import sys
 import time
 import warnings
 from typing import Any
@@ -28,8 +26,10 @@
     build_quant_cfg,
     copy_custom_model_files,
     create_vlm_calibration_loop,
+    get_generation_kwargs,
     get_model,
     get_processor,
+    get_qwen3omni_dataloader,
     get_tokenizer,
     is_enc_dec,
     is_nemotron_vl,
@@ -72,15 +72,7 @@
 )
 from modelopt.torch.utils.memory_monitor import launch_memory_monitor
 from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader
-from modelopt.torch.utils.video_dataset_utils import (
-    Qwen3OmniVideoProcessor,
-    get_supported_video_datasets,
-    get_video_dataset_dataloader,
-)
-from modelopt.torch.utils.vlm_dataset_utils import (
-    get_supported_vlm_datasets,
-    get_vlm_dataset_dataloader,
-)
+from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader

 RAND_SEED = 1234

@@ -194,47 +186,20 @@ def make_calib_dataloader(
             num_samples=args.calib_size[0],
         )
     elif model_type == "qwen3omni":
-        dataset_name = args.dataset[0] if args.dataset else "cnn_dailymail"
-        # Check if using video dataset (e.g., finevideo)
-        if processor is not None:
-            if dataset_name in get_supported_video_datasets():
-                video_processor = Qwen3OmniVideoProcessor(
-                    processor.tokenizer if hasattr(processor, "tokenizer") else processor,
-                    device=device,
-                    dtype=language_model.dtype,
-                    use_audio_in_video=True,
-                )
-                calib_dataloader = get_video_dataset_dataloader(
-                    dataset_name=dataset_name,
-                    processor=video_processor,
-                    batch_size=args.batch_size,
-                    num_samples=args.calib_size[0],
-                )
-            elif dataset_name in get_supported_vlm_datasets():
-                assert isinstance(processor, Qwen3OmniImageProcessor), (
-                    "The Qwen3OmniImageProcessor must be set."
-                )
-                # Set the dtype for proper tensor conversion in collate_function
-                processor.dtype = language_model.dtype
-                calib_dataloader = get_vlm_dataset_dataloader(
-                    dataset_name=dataset_name,
-                    processor=processor,
-                    batch_size=args.batch_size,
-                    num_samples=args.calib_size[0],
-                )
-        else:
-            # Labels are only needed for gradient-based auto_quantize
-            include_labels = (
-                args.auto_quantize_bits is not None and args.auto_quantize_method == "gradient"
-            )
-            calib_dataloader = get_dataset_dataloader(
-                dataset_name=args.dataset,
-                tokenizer=tokenizer,
-                batch_size=args.batch_size,
-                num_samples=args.calib_size,
-                device=device,
-                include_labels=include_labels,
-            )
+        # Labels are only needed for gradient-based auto_quantize
+        include_labels = (
+            args.auto_quantize_bits is not None and args.auto_quantize_method == "gradient"
+        )
+        calib_dataloader = get_qwen3omni_dataloader(
+            dataset_name=args.dataset[0] if args.dataset else None,
+            processor=processor,
+            tokenizer=tokenizer,
+            batch_size=args.batch_size,
+            num_samples=args.calib_size[0] if processor else args.calib_size,
+            device=device,
+            model_dtype=language_model.dtype,
+            include_labels=include_labels,
+        )
     elif model_type == "whisper":
         assert processor is not None and isinstance(processor, WhisperProcessor), (
             "The AutoProcessor must be set."
@@ -566,17 +531,8 @@ def mono_quantize(
         quant_cfg["quant_cfg"]["*radio*"] = {"enable": False}
         quant_cfg["quant_cfg"]["*visual*"] = {"enable": False}

-    # For Qwen3Omni models, disable quantization of conv layers
-    generation_kwargs = {}
-    if model_type == "qwen3omni":
-        print(
-            "Disabling quantization for conv layers, audio tower and visual encoder in Qwen3Omni model"
-        )
-        quant_cfg["quant_cfg"]["*conv*"] = {"enable": False}
-        quant_cfg["quant_cfg"]["*audio_tower*"] = {"enable": False}
-        quant_cfg["quant_cfg"]["*visual*"] = {"enable": False}
-        generation_kwargs["return_audio"] = False
-        generation_kwargs["thinker_max_new_tokens"] = 1
+    # Get model-specific generation kwargs (e.g., for Qwen3Omni)
+    generation_kwargs = get_generation_kwargs(model_type)

     if not model_is_already_quantized or calibration_only:
         if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only":
@@ -799,20 +755,7 @@ def post_quantize(
     """

     if args.verbose:
-        if args.quant_summary_path:
-            # Capture the summary output to a file
-            old_stdout = sys.stdout
-            sys.stdout = buffer = io.StringIO()
-            try:
-                mtq.print_quant_summary(full_model)
-            finally:
-                sys.stdout = old_stdout
-            summary = buffer.getvalue()
-            with open(args.quant_summary_path, "w") as f:
-                f.write(summary)
-            print(f"Quantization summary saved to {args.quant_summary_path}")
-        else:
-            mtq.print_quant_summary(full_model)
+        mtq.print_quant_summary(full_model, save_path=args.quant_summary_path)

     # Run some samples
     torch.cuda.empty_cache()
modelopt/torch/export/model_utils.py

Lines changed: 27 additions & 27 deletions
@@ -17,45 +17,45 @@
 import torch.nn as nn

 MODEL_NAME_TO_TYPE = {
-    "ArcticForCausalLM": "llama",
-    "baichuan": "baichuan",
-    "Bart": "bart",
-    "Bloom": "bloom",
-    "ChatGLM": "chatglm",
-    "Dbrx": "dbrx",
-    "Deepseek": "deepseek",
-    "ExaoneForCausalLM": "exaone",
-    "FalconForCausalLM": "falcon",
-    "Gemma": "gemma",
-    "Gemma2": "gemma2",
-    "Gemma3": "gemma3",
-    "GLM": "glm",
     "GPT2": "gpt",
-    "GPTJ": "gptj",
-    "gptoss": "gptoss",
-    "InternLM2ForCausalLM": "internlm",
-    "Llama": "llama",
+    "Mllama": "mllama",
     "Llama4": "llama4",
+    "Llama": "llama",
     "Mistral": "llama",
-    "MixtralForCausalLM": "llama",
-    "Mllama": "mllama",
+    "GPTJ": "gptj",
+    "FalconForCausalLM": "falcon",
+    "RWForCausalLM": "falcon",
+    "baichuan": "baichuan",
     "MPT": "mpt",
-    "Nemotron": "gpt",
-    "phi": "phi",
-    "phi3": "phi3",
-    "phi3small": "phi3small",
-    "Phi4MMForCausalLM": "phi4mm",
-    "PhiMoEForCausalLM": "phi3",
+    "Bloom": "bloom",
+    "ChatGLM": "chatglm",
     "Qwen3Moe": "qwen3moe",
     "Qwen3Next": "qwen3next",
     "Qwen3OmniMoeForConditionalGeneration": "qwen3omni",
     "QWen": "qwen",
     "RecurrentGemma": "recurrentgemma",
-    "RWForCausalLM": "falcon",
+    "Gemma3": "gemma3",
+    "Gemma2": "gemma2",
+    "Gemma": "gemma",
+    "phi3small": "phi3small",
+    "phi3": "phi3",
+    "PhiMoEForCausalLM": "phi3",
+    "Phi4MMForCausalLM": "phi4mm",
+    "phi": "phi",
+    "TLGv4ForCausalLM": "phi",
+    "MixtralForCausalLM": "llama",
+    "ArcticForCausalLM": "llama",
     "StarCoder": "gpt",
+    "Dbrx": "dbrx",
     "T5": "t5",
-    "TLGv4ForCausalLM": "phi",
+    "Bart": "bart",
+    "GLM": "glm",
+    "InternLM2ForCausalLM": "internlm",
+    "ExaoneForCausalLM": "exaone",
+    "Nemotron": "gpt",
+    "Deepseek": "deepseek",
     "Whisper": "whisper",
+    "gptoss": "gptoss",
 }

 __doc__ = f"""Utility functions for model type detection and classification.
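
Note: the reordering of MODEL_NAME_TO_TYPE is behavior-relevant under a first-match lookup over architecture names (an assumption about how this mapping is consumed): when keys are scanned in insertion order, specific names must precede their prefixes, e.g. "Mllama" and "Llama4" before "Llama", "Gemma3"/"Gemma2" before "Gemma", and "phi3small" before "phi3" before "phi". A minimal sketch of such a lookup; the actual helper in model_utils.py may differ:

def get_model_type(model) -> str | None:
    # First match wins, so insertion order in MODEL_NAME_TO_TYPE is load-bearing.
    architecture = type(model).__name__  # e.g. "Llama4ForConditionalGeneration"
    for name, model_type in MODEL_NAME_TO_TYPE.items():
        if name.lower() in architecture.lower():
            return model_type
    return None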

modelopt/torch/export/unified_export_hf.py

Lines changed: 6 additions & 4 deletions
@@ -1003,10 +1003,12 @@ def export_hf_checkpoint(
         if hasattr(model, "generation_config") and model.generation_config is not None:
             gen_config = model.generation_config
             if not getattr(gen_config, "do_sample", True):
-                # Remove sampling-related params when do_sample is False
-                for attr in ["temperature", "top_p", "top_k"]:
-                    if hasattr(gen_config, attr):
-                        setattr(gen_config, attr, None)
+                # Enable sampling if sampling params are present
+                if any(
+                    getattr(gen_config, attr, None) is not None
+                    for attr in ["temperature", "top_p", "top_k"]
+                ):
+                    gen_config.do_sample = True

         # Save model
         model.save_pretrained(

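Note: this reverses the direction of the previous fix-up: instead of clearing temperature/top_p/top_k when do_sample is False, the exported generation_config now turns sampling on whenever any of those parameters is set, presumably to keep the saved config self-consistent and avoid Hugging Face warnings about unused sampling parameters. A runnable sketch of the new branch using a stand-in config object (values are hypothetical):

from types import SimpleNamespace

# Stand-in for a generation_config saved with greedy decoding but leftover
# sampling parameters.
gen_config = SimpleNamespace(do_sample=False, temperature=0.7, top_p=0.9, top_k=None)

if not getattr(gen_config, "do_sample", True):
    # Enable sampling if sampling params are present (new behavior).
    if any(
        getattr(gen_config, attr, None) is not None
        for attr in ["temperature", "top_p", "top_k"]
    ):
        gen_config.do_sample = True

print(gen_config.do_sample)  # True: sampling re-enabled rather than params cleared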