14 | 14 | # limitations under the License. |
15 | 15 |
16 | 16 | import argparse |
17 | | -import io |
18 | 17 | import random |
19 | | -import sys |
20 | 18 | import time |
21 | 19 | import warnings |
22 | 20 | from typing import Any |
28 | 26 | build_quant_cfg, |
29 | 27 | copy_custom_model_files, |
30 | 28 | create_vlm_calibration_loop, |
| 29 | + get_generation_kwargs, |
31 | 30 | get_model, |
32 | 31 | get_processor, |
| 32 | + get_qwen3omni_dataloader, |
33 | 33 | get_tokenizer, |
34 | 34 | is_enc_dec, |
35 | 35 | is_nemotron_vl, |
72 | 72 | ) |
73 | 73 | from modelopt.torch.utils.memory_monitor import launch_memory_monitor |
74 | 74 | from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader |
75 | | -from modelopt.torch.utils.video_dataset_utils import ( |
76 | | - Qwen3OmniVideoProcessor, |
77 | | - get_supported_video_datasets, |
78 | | - get_video_dataset_dataloader, |
79 | | -) |
80 | | -from modelopt.torch.utils.vlm_dataset_utils import ( |
81 | | - get_supported_vlm_datasets, |
82 | | - get_vlm_dataset_dataloader, |
83 | | -) |
| 75 | +from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader |
84 | 76 |
85 | 77 | RAND_SEED = 1234 |
86 | 78 |
@@ -194,47 +186,20 @@ def make_calib_dataloader( |
194 | 186 | num_samples=args.calib_size[0], |
195 | 187 | ) |
196 | 188 | elif model_type == "qwen3omni": |
197 | | - dataset_name = args.dataset[0] if args.dataset else "cnn_dailymail" |
198 | | - # Check if using video dataset (e.g., finevideo) |
199 | | - if processor is not None: |
200 | | - if dataset_name in get_supported_video_datasets(): |
201 | | - video_processor = Qwen3OmniVideoProcessor( |
202 | | - processor.tokenizer if hasattr(processor, "tokenizer") else processor, |
203 | | - device=device, |
204 | | - dtype=language_model.dtype, |
205 | | - use_audio_in_video=True, |
206 | | - ) |
207 | | - calib_dataloader = get_video_dataset_dataloader( |
208 | | - dataset_name=dataset_name, |
209 | | - processor=video_processor, |
210 | | - batch_size=args.batch_size, |
211 | | - num_samples=args.calib_size[0], |
212 | | - ) |
213 | | - elif dataset_name in get_supported_vlm_datasets(): |
214 | | - assert isinstance(processor, Qwen3OmniImageProcessor), ( |
215 | | - "The Qwen3OmniImageProcessor must be set." |
216 | | - ) |
217 | | - # Set the dtype for proper tensor conversion in collate_function |
218 | | - processor.dtype = language_model.dtype |
219 | | - calib_dataloader = get_vlm_dataset_dataloader( |
220 | | - dataset_name=dataset_name, |
221 | | - processor=processor, |
222 | | - batch_size=args.batch_size, |
223 | | - num_samples=args.calib_size[0], |
224 | | - ) |
225 | | - else: |
226 | | - # Labels are only needed for gradient-based auto_quantize |
227 | | - include_labels = ( |
228 | | - args.auto_quantize_bits is not None and args.auto_quantize_method == "gradient" |
229 | | - ) |
230 | | - calib_dataloader = get_dataset_dataloader( |
231 | | - dataset_name=args.dataset, |
232 | | - tokenizer=tokenizer, |
233 | | - batch_size=args.batch_size, |
234 | | - num_samples=args.calib_size, |
235 | | - device=device, |
236 | | - include_labels=include_labels, |
237 | | - ) |
| 189 | + # Labels are only needed for gradient-based auto_quantize |
| 190 | + include_labels = ( |
| 191 | + args.auto_quantize_bits is not None and args.auto_quantize_method == "gradient" |
| 192 | + ) |
| 193 | + calib_dataloader = get_qwen3omni_dataloader( |
| 194 | + dataset_name=args.dataset[0] if args.dataset else None, |
| 195 | + processor=processor, |
| 196 | + tokenizer=tokenizer, |
| 197 | + batch_size=args.batch_size, |
| 198 | + num_samples=args.calib_size[0] if processor else args.calib_size, |
| 199 | + device=device, |
| 200 | + model_dtype=language_model.dtype, |
| 201 | + include_labels=include_labels, |
| 202 | + ) |
238 | 203 | elif model_type == "whisper": |
239 | 204 | assert processor is not None and isinstance(processor, WhisperProcessor), ( |
240 | 205 | "The AutoProcessor must be set." |
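
The three dataset branches deleted above (video via `Qwen3OmniVideoProcessor`, image via the VLM dataloader, and the text-only fallback) are now hidden behind `get_qwen3omni_dataloader`. A minimal sketch of what the consolidated helper plausibly looks like, assuming it mirrors the deleted dispatch logic: the video/VLM module paths come from the imports removed at the top of this diff, the `get_dataset_dataloader` path is assumed, and the real helper in the example utils may differ.

```python
# Sketch only: assumes the helper mirrors the deleted branches above.
from modelopt.torch.utils.dataset_utils import get_dataset_dataloader  # assumed path
from modelopt.torch.utils.video_dataset_utils import (
    Qwen3OmniVideoProcessor,
    get_supported_video_datasets,
    get_video_dataset_dataloader,
)
from modelopt.torch.utils.vlm_dataset_utils import (
    get_supported_vlm_datasets,
    get_vlm_dataset_dataloader,
)


def get_qwen3omni_dataloader(
    dataset_name=None,
    processor=None,
    tokenizer=None,
    batch_size=1,
    num_samples=512,
    device=None,
    model_dtype=None,
    include_labels=False,
):
    dataset_name = dataset_name or "cnn_dailymail"
    if processor is not None and dataset_name in get_supported_video_datasets():
        # Video datasets (e.g., finevideo) need the dedicated video processor.
        video_processor = Qwen3OmniVideoProcessor(
            getattr(processor, "tokenizer", processor),
            device=device,
            dtype=model_dtype,
            use_audio_in_video=True,
        )
        return get_video_dataset_dataloader(
            dataset_name=dataset_name,
            processor=video_processor,
            batch_size=batch_size,
            num_samples=num_samples,
        )
    if processor is not None and dataset_name in get_supported_vlm_datasets():
        # Image datasets reuse the VLM dataloader; setting the dtype lets the
        # collate function convert pixel tensors to the language model's dtype.
        processor.dtype = model_dtype
        return get_vlm_dataset_dataloader(
            dataset_name=dataset_name,
            processor=processor,
            batch_size=batch_size,
            num_samples=num_samples,
        )
    # Text-only fallback; labels are only needed for gradient-based auto_quantize.
    return get_dataset_dataloader(
        dataset_name=dataset_name,
        tokenizer=tokenizer,
        batch_size=batch_size,
        num_samples=num_samples,
        device=device,
        include_labels=include_labels,
    )
```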
@@ -566,17 +531,8 @@ def mono_quantize( |
566 | 531 | quant_cfg["quant_cfg"]["*radio*"] = {"enable": False} |
567 | 532 | quant_cfg["quant_cfg"]["*visual*"] = {"enable": False} |
568 | 533 |
569 | | - # For Qwen3Omni models, disable quantization of conv layers |
570 | | - generation_kwargs = {} |
571 | | - if model_type == "qwen3omni": |
572 | | - print( |
573 | | - "Disabling quantization for conv layers, audio tower and visual encoder in Qwen3Omni model" |
574 | | - ) |
575 | | - quant_cfg["quant_cfg"]["*conv*"] = {"enable": False} |
576 | | - quant_cfg["quant_cfg"]["*audio_tower*"] = {"enable": False} |
577 | | - quant_cfg["quant_cfg"]["*visual*"] = {"enable": False} |
578 | | - generation_kwargs["return_audio"] = False |
579 | | - generation_kwargs["thinker_max_new_tokens"] = 1 |
| 534 | + # Get model-specific generation kwargs (e.g., for Qwen3Omni) |
| 535 | + generation_kwargs = get_generation_kwargs(model_type) |
580 | 536 |
581 | 537 | if not model_is_already_quantized or calibration_only: |
582 | 538 | if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only": |
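
The inline Qwen3Omni generation settings deleted above are now returned by a shared helper. A minimal sketch, assuming `get_generation_kwargs` simply centralizes the two kwargs that used to be set here; the quantizer disables for `*conv*`, `*audio_tower*`, and `*visual*` removed in the same hunk are presumably folded into the shared quant-config path and are not part of this sketch.

```python
# Sketch only: the real helper may cover more model types.
def get_generation_kwargs(model_type: str) -> dict:
    generation_kwargs = {}
    if model_type == "qwen3omni":
        # Skip talker audio synthesis and cap thinker decoding at one token
        # so the post-quantization sample generation stays cheap.
        generation_kwargs["return_audio"] = False
        generation_kwargs["thinker_max_new_tokens"] = 1
    return generation_kwargs
```

The returned dict is presumably splatted into the later sample-generation call, e.g. `model.generate(**inputs, **generation_kwargs)`.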
@@ -799,20 +755,7 @@ def post_quantize( |
799 | 755 | """ |
800 | 756 |
801 | 757 | if args.verbose: |
802 | | - if args.quant_summary_path: |
803 | | - # Capture the summary output to a file |
804 | | - old_stdout = sys.stdout |
805 | | - sys.stdout = buffer = io.StringIO() |
806 | | - try: |
807 | | - mtq.print_quant_summary(full_model) |
808 | | - finally: |
809 | | - sys.stdout = old_stdout |
810 | | - summary = buffer.getvalue() |
811 | | - with open(args.quant_summary_path, "w") as f: |
812 | | - f.write(summary) |
813 | | - print(f"Quantization summary saved to {args.quant_summary_path}") |
814 | | - else: |
815 | | - mtq.print_quant_summary(full_model) |
| 758 | + mtq.print_quant_summary(full_model, save_path=args.quant_summary_path) |
816 | 759 |
817 | 760 | # Run some samples |
818 | 761 | torch.cuda.empty_cache() |
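
Capturing the summary now happens inside `mtq.print_quant_summary` via the new `save_path` argument instead of swapping `sys.stdout` at the call site, which is also why the `io` and `sys` imports disappear at the top of this diff. A hypothetical sketch of how such a parameter could be implemented with `contextlib.redirect_stdout`, which restores stdout even if the printer raises, unlike the deleted manual try/finally swap; `_print_summary_to_stdout` is a made-up stand-in, and the actual modelopt internals may differ.

```python
import contextlib
import io

# Sketch only: _print_summary_to_stdout stands in for the existing printer.
def print_quant_summary(model, save_path=None):
    if save_path is None:
        _print_summary_to_stdout(model)
        return
    buffer = io.StringIO()
    # redirect_stdout restores sys.stdout on exit, even on exceptions.
    with contextlib.redirect_stdout(buffer):
        _print_summary_to_stdout(model)
    with open(save_path, "w") as f:
        f.write(buffer.getvalue())
    print(f"Quantization summary saved to {save_path}")
```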