1414# limitations under the License.
1515
1616import argparse
17+ import io
1718import random
19+ import sys
1820import time
1921import warnings
2022from typing import Any
6163 create_forward_loop ,
6264 get_dataset_dataloader ,
6365 get_max_batch_size ,
66+ get_qwen3omni_text_dataloader ,
6467 get_supported_datasets ,
6568)
66- from modelopt .torch .utils .image_processor import BaseImageProcessor , MllamaImageProcessor
69+ from modelopt .torch .utils .image_processor import (
70+ BaseImageProcessor ,
71+ MllamaImageProcessor ,
72+ Qwen3OmniImageProcessor ,
73+ Qwen3OmniTextProcessor ,
74+ )
6775from modelopt .torch .utils .memory_monitor import launch_memory_monitor
6876from modelopt .torch .utils .speech_dataset_utils import get_speech_dataset_dataloader
69- from modelopt .torch .utils .vlm_dataset_utils import get_vlm_dataset_dataloader
77+ from modelopt .torch .utils .video_dataset_utils import (
78+ Qwen3OmniVideoProcessor ,
79+ get_supported_video_datasets ,
80+ get_video_dataset_dataloader ,
81+ )
82+ from modelopt .torch .utils .vlm_dataset_utils import (
83+ get_supported_vlm_datasets ,
84+ get_vlm_dataset_dataloader ,
85+ )
7086
7187RAND_SEED = 1234  # fixed seed constant — presumably used to make calibration sampling reproducible; TODO confirm where it is consumed (random/dataloader setup not visible in this chunk)
7288
@@ -179,6 +195,51 @@ def make_calib_dataloader(
179195 batch_size = args .batch_size ,
180196 num_samples = args .calib_size [0 ],
181197 )
198+ elif model_type == "qwen3omni" :
199+ assert processor is not None , "The processor must be set for qwen3omni model."
200+ dataset_name = args .dataset [0 ] if args .dataset else "cnn_dailymail"
201+ # Check if using video dataset (e.g., finevideo)
202+ if dataset_name in get_supported_video_datasets ():
203+ video_processor = Qwen3OmniVideoProcessor (
204+ processor .tokenizer if hasattr (processor , "tokenizer" ) else processor ,
205+ device = device ,
206+ dtype = language_model .dtype ,
207+ use_audio_in_video = True ,
208+ )
209+ calib_dataloader = get_video_dataset_dataloader (
210+ dataset_name = dataset_name ,
211+ processor = video_processor ,
212+ batch_size = args .batch_size ,
213+ num_samples = args .calib_size [0 ],
214+ )
215+ elif dataset_name in get_supported_vlm_datasets ():
216+ assert isinstance (processor , Qwen3OmniImageProcessor ), (
217+ "The Qwen3OmniImageProcessor must be set."
218+ )
219+ # Set the dtype for proper tensor conversion in collate_function
220+ processor .dtype = language_model .dtype
221+ calib_dataloader = get_vlm_dataset_dataloader (
222+ dataset_name = dataset_name ,
223+ processor = processor ,
224+ batch_size = args .batch_size ,
225+ num_samples = args .calib_size [0 ],
226+ )
227+ else :
228+ # Text-only datasets (e.g., cnn_dailymail)
229+ # Use Qwen3OmniTextProcessor to apply proper conversation template
230+ # See: https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Thinking
231+ text_processor = Qwen3OmniTextProcessor (
232+ processor = processor .tokenizer , # Pass the underlying HF processor
233+ device = device ,
234+ dtype = language_model .dtype ,
235+ )
236+ calib_dataloader = get_qwen3omni_text_dataloader (
237+ dataset_name = dataset_name ,
238+ processor = text_processor ,
239+ batch_size = args .batch_size ,
240+ num_samples = args .calib_size [0 ],
241+ )
242+ print (f"Selected dataset for calibration: { dataset_name } " )
182243 elif model_type == "whisper" :
183244 assert processor is not None and isinstance (processor , WhisperProcessor ), (
184245 "The AutoProcessor must be set."
@@ -349,6 +410,9 @@ def load_model(args: argparse.Namespace):
349410 calibration_only = True
350411
351412 model_type = get_model_type (full_model )
413+ if model_type == "qwen3omni" :
414+ print ("Disabling talker for Qwen3Omni model" )
415+ full_model .disable_talker ()
352416
353417 device = full_model .device
354418 if hasattr (full_model , "model" ):
@@ -360,7 +424,7 @@ def load_model(args: argparse.Namespace):
360424 default_pad_token = None
361425
362426 is_nemotron_vl_model = is_nemotron_vl (full_model )
363- if model_type == "mllama" :
427+ if model_type in [ "mllama" , "qwen3omni" ] :
364428 processor = get_processor (
365429 args .pyt_ckpt_path ,
366430 model_type ,
@@ -502,6 +566,15 @@ def mono_quantize(
502566 quant_cfg ["quant_cfg" ]["*radio*" ] = {"enable" : False }
503567 quant_cfg ["quant_cfg" ]["*visual*" ] = {"enable" : False }
504568
569+ # For Qwen3Omni models, disable quantization of conv layers, the audio tower, and the visual encoder
570+ if model_type == "qwen3omni" :
571+ print (
572+ "Disabling quantization for conv layers, audio tower and visual encoder in Qwen3Omni model"
573+ )
574+ quant_cfg ["quant_cfg" ]["*conv*" ] = {"enable" : False }
575+ quant_cfg ["quant_cfg" ]["*audio_tower*" ] = {"enable" : False }
576+ quant_cfg ["quant_cfg" ]["*visual*" ] = {"enable" : False }
577+
505578 if not model_is_already_quantized or calibration_only :
506579 if model_type == "gptoss" and args .qformat == "nvfp4_mlp_only" :
507580 print ("Applying nvfp4 quantization (MoE only) for gpt-oss" )
@@ -662,9 +735,10 @@ def pre_quantize(
662735
663736 """
664737 # Only run single sample for preview
665- preview_input_ids = next (iter (calib_dataloader ))[
666- "input_features" if model_type == "whisper" else "input_ids"
667- ][0 :1 ]
738+ calib_batch = next (iter (calib_dataloader ))
739+ preview_input_ids = calib_batch ["input_features" if model_type == "whisper" else "input_ids" ][
740+ 0 :1
741+ ]
668742
669743 # Generate preview before quantization
670744 if model_type == "deepseek" :
@@ -679,13 +753,24 @@ def pre_quantize(
679753 "before quantization" ,
680754 allow_fallback = True ,
681755 )
756+ elif model_type == "qwen3omni" :
757+ # Qwen3Omni returns (text_ids, audio) tuple; text_ids has .sequences
758+ # Pass full batch with all multimodal inputs
759+ result = full_model .generate (** calib_batch , max_new_tokens = 100 )
760+ if isinstance (result , tuple ):
761+ text_ids , _ = result
762+ generated_ids_before_ptq = (
763+ text_ids .sequences if hasattr (text_ids , "sequences" ) else text_ids
764+ )
765+ else :
766+ generated_ids_before_ptq = result
682767 else :
683768 # Standard generation for non-Nemotron VL models
684769 generated_ids_before_ptq = full_model .generate (preview_input_ids , max_new_tokens = 100 )
685770 if model_type == "gptoss" and args .qformat == "nvfp4_mlp_only" :
686771 print ("Applying nvfp4 quantization (MoE only) for gpt-oss" )
687772
688- return preview_input_ids , generated_ids_before_ptq
773+ return preview_input_ids , generated_ids_before_ptq , calib_batch
689774
690775
691776def post_quantize (
@@ -698,6 +783,7 @@ def post_quantize(
698783 generated_ids_before_ptq ,
699784 is_nemotron_vl_model ,
700785 first_text_speech_dataset ,
786+ calib_batch : dict | None = None ,
701787):
702788 """
703789 Processing after the quantization.
@@ -708,13 +794,37 @@ def post_quantize(
708794 """
709795
710796 if args .verbose :
711- mtq .print_quant_summary (full_model )
797+ if args .quant_summary_path :
798+ # Capture the summary output to a file
799+ old_stdout = sys .stdout
800+ sys .stdout = buffer = io .StringIO ()
801+ try :
802+ mtq .print_quant_summary (full_model )
803+ finally :
804+ sys .stdout = old_stdout
805+ summary = buffer .getvalue ()
806+ with open (args .quant_summary_path , "w" ) as f :
807+ f .write (summary )
808+ print (f"Quantization summary saved to { args .quant_summary_path } " )
809+ else :
810+ mtq .print_quant_summary (full_model )
712811
713812 # Run some samples
714813 torch .cuda .empty_cache ()
715814 generated_ids_after_ptq = None
716815 if generated_ids_before_ptq is None :
717816 pass
817+ elif model_type == "qwen3omni" :
818+ # Qwen3Omni returns (text_ids, audio) tuple; text_ids has .sequences
819+ # Pass full batch with all multimodal inputs
820+ result = full_model .generate (** calib_batch , max_new_tokens = 100 )
821+ if isinstance (result , tuple ):
822+ text_ids , _ = result
823+ generated_ids_after_ptq = (
824+ text_ids .sequences if hasattr (text_ids , "sequences" ) else text_ids
825+ )
826+ else :
827+ generated_ids_after_ptq = result
718828 elif model_type != "llama4" and not is_nemotron_vl_model :
719829 # Our fake quantizer may not be fully compatible with torch.compile.
720830 generated_ids_after_ptq = full_model .generate (preview_input_ids , max_new_tokens = 100 )
@@ -733,12 +843,13 @@ def post_quantize(
733843 )
734844
def input_decode(input_ids):
    """Decode preview prompt token ids back into text.

    Tries, in order: an image-processor's tokenizer (BaseImageProcessor
    covers both MllamaImageProcessor and Qwen3OmniImageProcessor), the
    cached speech dataset text for Whisper, and finally the plain
    tokenizer. Raises ValueError when none is available.
    """
    # isinstance(None, X) is False, so an explicit None check is redundant.
    if isinstance(processor, BaseImageProcessor):
        return processor.tokenizer.batch_decode(input_ids, skip_special_tokens=True)
    if isinstance(processor, WhisperProcessor):
        # Whisper prompts are audio features; show the paired text instead.
        return first_text_speech_dataset
    if tokenizer is not None:
        return tokenizer.batch_decode(input_ids, skip_special_tokens=True)
    raise ValueError("The processor or tokenizer must be set")
744855
@@ -750,6 +861,12 @@ def output_decode(generated_ids, input_shape):
750861 return tokenizer .batch_decode (generated_ids , skip_special_tokens = True )
751862 elif processor is not None and isinstance (processor , MllamaImageProcessor ):
752863 return processor .tokenizer .batch_decode (generated_ids [:, input_shape :])
864+ elif processor is not None and isinstance (processor , Qwen3OmniImageProcessor ):
865+ return processor .tokenizer .batch_decode (
866+ generated_ids [:, input_shape :],
867+ skip_special_tokens = True ,
868+ clean_up_tokenization_spaces = False ,
869+ )
753870 elif tokenizer is not None :
754871 return tokenizer .batch_decode (generated_ids [:, input_shape :])
755872 else :
@@ -831,7 +948,7 @@ def quantize_main(
831948 # Detect if this is a Nemotron VL model using architecture-based detection
832949 is_nemotron_vl_model = is_nemotron_vl (full_model )
833950
834- preview_input_ids , generated_ids_before_ptq = pre_quantize (
951+ preview_input_ids , generated_ids_before_ptq , calib_batch = pre_quantize (
835952 args , full_model , model_type , tokenizer , calib_dataloader , is_nemotron_vl_model
836953 )
837954
@@ -903,6 +1020,7 @@ def quantize_main(
9031020 generated_ids_before_ptq ,
9041021 is_nemotron_vl_model ,
9051022 first_text_speech_dataset ,
1023+ calib_batch ,
9061024 )
9071025 export_quantized (
9081026 args ,
@@ -1083,6 +1201,15 @@ def parse_args() -> argparse.Namespace:
10831201 "(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified."
10841202 ),
10851203 )
1204+ parser .add_argument (
1205+ "--quant_summary_path" ,
1206+ type = str ,
1207+ default = None ,
1208+ help = (
1209+ "Path to save the quantization summary. If not specified, summary is printed to stdout. "
1210+ "Requires --verbose to be enabled (default: True)."
1211+ ),
1212+ )
10861213
10871214 return parser .parse_args ()
10881215
0 commit comments