1818import random
1919import time
2020import warnings
21+ from collections import namedtuple
2122from typing import Any
2223
2324import numpy as np
3536 is_enc_dec ,
3637 is_nemotron_vl ,
3738 load_mtp_weights ,
38- patch_config_for_unified_export ,
3939 run_nemotron_vl_preview ,
4040)
4141from torch .utils .data import DataLoader
@@ -735,9 +735,6 @@ def export_quantized(
735735 extra_state_dict = mtp_state_dict ,
736736 )
737737
738- # Exclude non-quantized modules in config.json and hf_quant_config.json
739- patch_config_for_unified_export (model_type , export_path )
740-
741738 # Restore default padding and export the tokenizer as well.
742739 if tokenizer is not None :
743740 tokenizer .padding_side = default_padding_side
@@ -757,6 +754,23 @@ def export_quantized(
757754 )
758755
759756
# Lightweight result record returned by pre_quantize(); keeps the historical
# 3-tuple unpacking (preview_input_ids, generated_ids_before_ptq, calib_batch)
# working while letting new callers access fields by name.
PreQuantizeResult = namedtuple(
    "PreQuantizeResult",
    [
        "preview_input_ids",
        "generated_ids_before_ptq",
        "calib_batch",
    ],
)
760+
761+
762+ def _qwen3omni_generate (model , calib_batch ):
763+ """Run Qwen3Omni generate and unpack the result.
764+
765+ Qwen3Omni returns a (text_ids, audio) tuple; text_ids may have a .sequences attribute.
766+ """
767+ result = model .generate (** calib_batch , return_audio = False , thinker_max_new_tokens = 100 )
768+ if isinstance (result , tuple ):
769+ text_ids , _ = result
770+ return text_ids .sequences if hasattr (text_ids , "sequences" ) else text_ids
771+ return result
772+
773+
760774def pre_quantize (
761775 args : argparse .Namespace ,
762776 full_model : torch .nn .Module ,
@@ -799,20 +813,15 @@ def pre_quantize(
799813 allow_fallback = False ,
800814 )
801815 elif model_type == "qwen3omni" :
802- # Qwen3Omni returns (text_ids, audio) tuple; text_ids has .sequences
803- # Pass full batch with all multimodal inputs
804- result = full_model .generate (** calib_batch , return_audio = False , thinker_max_new_tokens = 100 )
805- if isinstance (result , tuple ):
806- text_ids , _ = result
807- generated_ids_before_ptq = (
808- text_ids .sequences if hasattr (text_ids , "sequences" ) else text_ids
809- )
810- else :
811- generated_ids_before_ptq = result
816+ # Use only a single sample for preview generation to avoid OOM
817+ single_sample = {
818+ k : v [0 :1 ] if isinstance (v , torch .Tensor ) else v for k , v in calib_batch .items ()
819+ }
820+ generated_ids_before_ptq = _qwen3omni_generate (full_model , single_sample )
812821 else :
813822 generated_ids_before_ptq = full_model .generate (preview_input_ids , max_new_tokens = 100 )
814823
815- return preview_input_ids , generated_ids_before_ptq , calib_batch
824+ return PreQuantizeResult ( preview_input_ids , generated_ids_before_ptq , calib_batch )
816825
817826
818827def post_quantize (
@@ -861,25 +870,23 @@ def post_quantize(
861870 """
862871
863872 if args .verbose :
864- mtq .print_quant_summary (full_model , save_path = args .quant_summary_path )
865- save_expert_token_count_table (full_model , args .export_path )
873+ try :
874+ mtq .print_quant_summary (full_model , save_path = args .quant_summary_path )
875+ save_expert_token_count_table (full_model , args .export_path )
876+ except Exception as e :
877+ print (f"Warning: Failed to print quant summary: { e } " )
866878
867879 # Run some samples
868880 torch .cuda .empty_cache ()
869881 generated_ids_after_ptq = None
870882 if generated_ids_before_ptq is None :
871883 pass
872- elif model_type == "qwen3omni" :
873- # Qwen3Omni returns (text_ids, audio) tuple; text_ids has .sequences
874- # Pass full batch with all multimodal inputs
875- result = full_model .generate (** calib_batch , return_audio = False , thinker_max_new_tokens = 100 )
876- if isinstance (result , tuple ):
877- text_ids , _ = result
878- generated_ids_after_ptq = (
879- text_ids .sequences if hasattr (text_ids , "sequences" ) else text_ids
880- )
881- else :
882- generated_ids_after_ptq = result
884+ elif model_type == "qwen3omni" and calib_batch is not None :
885+ # Use only a single sample for preview generation to avoid OOM
886+ single_sample = {
887+ k : v [0 :1 ] if isinstance (v , torch .Tensor ) else v for k , v in calib_batch .items ()
888+ }
889+ generated_ids_after_ptq = _qwen3omni_generate (full_model , single_sample )
883890 elif model_type != "llama4" and not is_nemotron_vl_model :
884891 # Our fake quantizer may not be fully compatible with torch.compile.
885892 generated_ids_after_ptq = full_model .generate (preview_input_ids , max_new_tokens = 100 )
0 commit comments