6363 create_forward_loop ,
6464 get_dataset_dataloader ,
6565 get_max_batch_size ,
66- get_qwen3omni_text_dataloader ,
6766 get_supported_datasets ,
6867)
6968from modelopt .torch .utils .image_processor import (
7069 BaseImageProcessor ,
7170 MllamaImageProcessor ,
7271 Qwen3OmniImageProcessor ,
73- Qwen3OmniTextProcessor ,
7472)
7573from modelopt .torch .utils .memory_monitor import launch_memory_monitor
7674from modelopt .torch .utils .speech_dataset_utils import get_speech_dataset_dataloader
@@ -196,50 +194,47 @@ def make_calib_dataloader(
196194 num_samples = args .calib_size [0 ],
197195 )
198196 elif model_type == "qwen3omni" :
199- assert processor is not None , "The processor must be set for qwen3omni model."
200197 dataset_name = args .dataset [0 ] if args .dataset else "cnn_dailymail"
201198 # Check if using video dataset (e.g., finevideo)
202- if dataset_name in get_supported_video_datasets ():
203- video_processor = Qwen3OmniVideoProcessor (
204- processor .tokenizer if hasattr (processor , "tokenizer" ) else processor ,
205- device = device ,
206- dtype = language_model .dtype ,
207- use_audio_in_video = True ,
208- )
209- calib_dataloader = get_video_dataset_dataloader (
210- dataset_name = dataset_name ,
211- processor = video_processor ,
212- batch_size = args .batch_size ,
213- num_samples = args .calib_size [0 ],
214- )
215- elif dataset_name in get_supported_vlm_datasets ():
216- assert isinstance (processor , Qwen3OmniImageProcessor ), (
217- "The Qwen3OmniImageProcessor must be set."
218- )
219- # Set the dtype for proper tensor conversion in collate_function
220- processor .dtype = language_model .dtype
221- calib_dataloader = get_vlm_dataset_dataloader (
222- dataset_name = dataset_name ,
223- processor = processor ,
224- batch_size = args .batch_size ,
225- num_samples = args .calib_size [0 ],
226- )
199+ if processor is not None :
200+ if dataset_name in get_supported_video_datasets ():
201+ video_processor = Qwen3OmniVideoProcessor (
202+ processor .tokenizer if hasattr (processor , "tokenizer" ) else processor ,
203+ device = device ,
204+ dtype = language_model .dtype ,
205+ use_audio_in_video = True ,
206+ )
207+ calib_dataloader = get_video_dataset_dataloader (
208+ dataset_name = dataset_name ,
209+ processor = video_processor ,
210+ batch_size = args .batch_size ,
211+ num_samples = args .calib_size [0 ],
212+ )
213+ elif dataset_name in get_supported_vlm_datasets ():
214+ assert isinstance (processor , Qwen3OmniImageProcessor ), (
215+ "The Qwen3OmniImageProcessor must be set."
216+ )
217+ # Set the dtype for proper tensor conversion in collate_function
218+ processor .dtype = language_model .dtype
219+ calib_dataloader = get_vlm_dataset_dataloader (
220+ dataset_name = dataset_name ,
221+ processor = processor ,
222+ batch_size = args .batch_size ,
223+ num_samples = args .calib_size [0 ],
224+ )
227225 else :
228- # Text-only datasets (e.g., cnn_dailymail)
229- # Use Qwen3OmniTextProcessor to apply proper conversation template
230- # See: https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Thinking
231- text_processor = Qwen3OmniTextProcessor (
232- processor = processor .tokenizer , # Pass the underlying HF processor
233- device = device ,
234- dtype = language_model .dtype ,
226+ # Labels are only needed for gradient-based auto_quantize
227+ include_labels = (
228+ args .auto_quantize_bits is not None and args .auto_quantize_method == "gradient"
235229 )
236- calib_dataloader = get_qwen3omni_text_dataloader (
237- dataset_name = dataset_name ,
238- processor = text_processor ,
230+ calib_dataloader = get_dataset_dataloader (
231+ dataset_name = args . dataset ,
232+ tokenizer = tokenizer ,
239233 batch_size = args .batch_size ,
240- num_samples = args .calib_size [0 ],
234+ num_samples = args .calib_size ,
235+ device = device ,
236+ include_labels = include_labels ,
241237 )
242- print (f"Selected dataset for calibration: { dataset_name } " )
243238 elif model_type == "whisper" :
244239 assert processor is not None and isinstance (processor , WhisperProcessor ), (
245240 "The AutoProcessor must be set."
@@ -410,9 +405,6 @@ def load_model(args: argparse.Namespace):
410405 calibration_only = True
411406
412407 model_type = get_model_type (full_model )
413- if model_type == "qwen3omni" :
414- print ("Disabling talker for Qwen3Omni model" )
415- full_model .disable_talker ()
416408
417409 device = full_model .device
418410 if hasattr (full_model , "model" ):
@@ -432,6 +424,14 @@ def load_model(args: argparse.Namespace):
432424 trust_remote_code = args .trust_remote_code ,
433425 attn_implementation = args .attn_implementation ,
434426 )
427+ if model_type == "qwen3omni" :
428+ print ("Disabling talker for Qwen3Omni model" )
429+ full_model .disable_talker ()
430+ language_model = full_model .thinker .model
431+ tokenizer = processor .tokenizer .tokenizer
432+ processor = None
433+ default_padding_side = tokenizer .padding_side
434+ default_pad_token = tokenizer .pad_token
435435 elif model_type == "whisper" :
436436 processor = get_processor (
437437 args .pyt_ckpt_path ,
@@ -567,13 +567,16 @@ def mono_quantize(
567567 quant_cfg ["quant_cfg" ]["*visual*" ] = {"enable" : False }
568568
569569 # For Qwen3Omni models, disable quantization of conv layers
570+ generation_kwargs = {}
570571 if model_type == "qwen3omni" :
571572 print (
572573 "Disabling quantization for conv layers, audio tower and visual encoder in Qwen3Omni model"
573574 )
574575 quant_cfg ["quant_cfg" ]["*conv*" ] = {"enable" : False }
575576 quant_cfg ["quant_cfg" ]["*audio_tower*" ] = {"enable" : False }
576577 quant_cfg ["quant_cfg" ]["*visual*" ] = {"enable" : False }
578+ generation_kwargs ["return_audio" ] = False
579+ generation_kwargs ["thinker_max_new_tokens" ] = 1
577580
578581 if not model_is_already_quantized or calibration_only :
579582 if model_type == "gptoss" and args .qformat == "nvfp4_mlp_only" :
@@ -592,7 +595,9 @@ def mono_quantize(
592595 if args .calib_with_images and is_nemotron_vl_model :
593596 calibrate_loop = create_vlm_calibration_loop (full_model , calib_dataloader )
594597 else :
595- calibrate_loop = create_forward_loop (dataloader = calib_dataloader )
598+ calibrate_loop = create_forward_loop (
599+ dataloader = calib_dataloader , generation_kwargs = generation_kwargs
600+ )
596601
597602 if calibration_only :
598603 language_model = mtq .calibrate (
@@ -756,7 +761,7 @@ def pre_quantize(
756761 elif model_type == "qwen3omni" :
757762 # Qwen3Omni returns (text_ids, audio) tuple; text_ids has .sequences
758763 # Pass full batch with all multimodal inputs
759- result = full_model .generate (** calib_batch , max_new_tokens = 100 )
764+ result = full_model .generate (** calib_batch , return_audio = False , thinker_max_new_tokens = 100 )
760765 if isinstance (result , tuple ):
761766 text_ids , _ = result
762767 generated_ids_before_ptq = (
@@ -817,7 +822,7 @@ def post_quantize(
817822 elif model_type == "qwen3omni" :
818823 # Qwen3Omni returns (text_ids, audio) tuple; text_ids has .sequences
819824 # Pass full batch with all multimodal inputs
820- result = full_model .generate (** calib_batch , max_new_tokens = 100 )
825+ result = full_model .generate (** calib_batch , return_audio = False , thinker_max_new_tokens = 100 )
821826 if isinstance (result , tuple ):
822827 text_ids , _ = result
823828 generated_ids_after_ptq = (
0 commit comments