@@ -177,6 +177,19 @@ def _output_hook(module, input, output):
             "This is required for requantization/resmoothing optimization. "
             "Please ensure the model architecture is supported or file an issue."
         )
+    elif "qwen3omni" in model_type:
+        # For Qwen3Omni, run on the thinker (language model) component
+        # The model has structure: model.thinker.model.layers.*
+        if hasattr(model, "thinker"):
+            print(
+                f"Running optimization on Qwen3Omni thinker with fake_input shape: {fake_input.shape}"
+            )
+            model.thinker(fake_input)
+        else:
+            raise ValueError(
+                f"Cannot extract thinker from Qwen3Omni model (type: {model_type}). "
+                "This is required for requantization/resmoothing optimization."
+            )
     else:
         model(fake_input)

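For context on the new `qwen3omni` branch, here is a minimal, self-contained sketch with stand-in modules (the class names below are hypothetical, not the real Qwen3Omni implementation): the multimodal wrapper itself does not take bare token ids, so the calibration forward pass is routed to its `thinker` submodule, where the language-model layers and their hooks live.

```python
import torch
import torch.nn as nn


class ThinkerLM(nn.Module):
    """Stand-in for the language-model ("thinker") component."""

    def __init__(self, vocab_size=128, hidden=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size)

    def forward(self, input_ids):
        return self.lm_head(self.embed(input_ids))


class OmniWrapper(nn.Module):
    """Stand-in for the full multimodal model; only `thinker` accepts token ids."""

    def __init__(self):
        super().__init__()
        self.thinker = ThinkerLM()


model = OmniWrapper()
fake_input = torch.ones([1, 2], dtype=torch.long)

# Mirrors the dispatch above: forward through the thinker so any hooks
# registered on its linear layers still fire during optimization.
if hasattr(model, "thinker"):
    model.thinker(fake_input)
```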
@@ -248,6 +261,19 @@ def _export_quantized_weight(
     weight_quantizer: TensorQuantizer | SequentialQuantizer = getattr(
         sub_module, quantizer_attrs.weight_quantizer
     )
+
+    # Skip export if weight quantizer is disabled or has no amax (not calibrated)
+    if not _is_enabled_quantizer(weight_quantizer):
+        return
+
+    # Check if weight quantizer has calibrated amax
+    def _has_amax(quantizer):
+        if isinstance(quantizer, SequentialQuantizer):
+            return any(hasattr(q, "_amax") and q._amax is not None for q in quantizer)
+        return hasattr(quantizer, "_amax") and quantizer._amax is not None
+
+    if not _has_amax(weight_quantizer):
+        return
     input_quantizer: TensorQuantizer | SequentialQuantizer | None = getattr(
         sub_module, quantizer_attrs.input_quantizer, None
     )
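The two early returns above hinge on whether a quantizer ever received calibration data. Below is a standalone sketch of the same amax check using simple stand-in objects rather than modelopt's `TensorQuantizer`/`SequentialQuantizer`; a sequential quantizer is treated as calibrated if any of its stages carries an amax.

```python
class FakeQuantizer:
    """Stand-in quantizer: calibration leaves a non-None `_amax` behind."""

    def __init__(self, amax=None):
        self._amax = amax


def has_amax(quantizer):
    # A plain list plays the role of SequentialQuantizer here: any calibrated stage counts.
    if isinstance(quantizer, (list, tuple)):
        return any(getattr(q, "_amax", None) is not None for q in quantizer)
    return getattr(quantizer, "_amax", None) is not None


assert has_amax(FakeQuantizer(amax=448.0))                    # calibrated -> export
assert not has_amax(FakeQuantizer())                          # never calibrated -> skip
assert has_amax([FakeQuantizer(), FakeQuantizer(amax=6.0)])   # one calibrated stage is enough
```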
@@ -392,7 +418,11 @@ def _export_quantized_weight(


 def _export_hf_checkpoint(
-    model: nn.Module, dtype: torch.dtype | None = None, is_modelopt_qlora: bool = False, **kwargs
+    model: nn.Module,
+    dtype: torch.dtype | None = None,
+    is_modelopt_qlora: bool = False,
+    pack_weights: bool = True,
+    **kwargs,
 ) -> tuple[dict[str, Any], dict[str, Any]]:
     """Exports the torch model to the packed checkpoint with original HF naming.

@@ -402,6 +432,7 @@ def _export_hf_checkpoint(
         model: the full torch model to export. The actual quantized model may be a submodule.
         dtype: the weights data type to export the unquantized layers or the default model data type if None.
         accelerator: the accelerator instance in case of distributed export setup.
+        pack_weights: whether to pack quantized weights (False keeps original shapes for HF reload).

     Returns:
         post_state_dict: Dict containing quantized weights
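A call sketch for the new keyword, assuming a modelopt-quantized `model` is already in scope and using the signature introduced in this diff; with `pack_weights=False` the packing steps below are skipped, so exported tensors keep their original shapes and can be reloaded as a regular HF checkpoint.

```python
import torch

# Hypothetical usage; `model` is a quantized HF model prepared elsewhere.
post_state_dict, hf_quant_config = _export_hf_checkpoint(
    model,
    dtype=torch.bfloat16,
    pack_weights=False,  # keep original-precision weights instead of TRT-LLM packing
)
```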
@@ -518,8 +549,9 @@ def _export_hf_checkpoint(

         if get_quantization_format(sub_module) != QUANTIZATION_NONE:
             if is_quantlinear(sub_module):
-                with fsdp2_aware_weight_update(model, sub_module, reshard=False):
-                    _export_quantized_weight(sub_module, dtype)
+                if pack_weights:
+                    with fsdp2_aware_weight_update(model, sub_module, reshard=False):
+                        _export_quantized_weight(sub_module, dtype)
             elif (
                 "Llama4TextExperts" in type(sub_module).__name__
                 or "GptOssExperts" in type(sub_module).__name__
@@ -536,9 +568,10 @@ def _export_hf_checkpoint(
                     quantizer_attrs=["gate_up_proj_input_quantizer", "down_proj_input_quantizer"],
                 )
                 # Export the quantized weights
-                with fsdp2_aware_weight_update(model, sub_module, reshard=False):
-                    for weight_name in ["gate_up_proj", "down_proj"]:
-                        _export_quantized_weight(sub_module, dtype, weight_name)
+                if pack_weights:
+                    with fsdp2_aware_weight_update(model, sub_module, reshard=False):
+                        for weight_name in ["gate_up_proj", "down_proj"]:
+                            _export_quantized_weight(sub_module, dtype, weight_name)

     if accelerator is not None:
         # Gather state_dict from all ranks
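To make the pack/no-pack distinction concrete, here is a generic illustration of 4-bit weight packing (not modelopt's actual packing layout): two signed 4-bit values are stored per `uint8` byte, which is why packed tensors change shape and dtype and cannot be loaded back as a plain HF model, hence the `pack_weights` switch.

```python
import torch


def pack_int4_pairs(w_int4: torch.Tensor) -> torch.Tensor:
    # w_int4: integer tensor with values in [-8, 7], last dim of even length
    u = (w_int4 + 8).to(torch.uint8)        # shift to [0, 15]
    lo, hi = u[..., 0::2], u[..., 1::2]     # split into nibble pairs
    return lo | (hi << 4)                   # half as many elements, dtype uint8


w = torch.randint(-8, 8, (4, 8))
packed = pack_int4_pairs(w)
print(w.shape, w.dtype, "->", packed.shape, packed.dtype)  # (4, 8) int64 -> (4, 4) uint8
```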
@@ -579,7 +612,12 @@ def export_hf_checkpoint(
         return

     try:
-        post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)
+        # Packed weights are only for TRT-LLM consumption.
+        # Set this to False if you want to save the weights in the original precision.
+        pack_weights = True
+        post_state_dict, hf_quant_config = _export_hf_checkpoint(
+            model, dtype, pack_weights=pack_weights
+        )

         if hf_quant_config is not None:
             # Save hf_quant_config.json for backward compatibility
@@ -588,6 +626,16 @@ def export_hf_checkpoint(

             hf_quant_config = convert_hf_quant_config_format(hf_quant_config)

+        # Fix generation_config conflicts before saving
+        # Some models have temperature/top_p/top_k set but do_sample=False, which causes validation errors
+        if hasattr(model, "generation_config") and model.generation_config is not None:
+            gen_config = model.generation_config
+            if not getattr(gen_config, "do_sample", True):
+                # Remove sampling-related params when do_sample is False
+                for attr in ["temperature", "top_p", "top_k"]:
+                    if hasattr(gen_config, attr):
+                        setattr(gen_config, attr, None)
+
         # Save model
         model.save_pretrained(
             export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
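Finally, a standalone sketch of the `generation_config` conflict this hunk works around, built only on the public `transformers.GenerationConfig` API. Whether `None` satisfies the validator depends on the installed transformers version; that is the same assumption the change above makes.

```python
from transformers import GenerationConfig

# A greedy config (do_sample=False) that still carries sampling-only knobs.
gen_config = GenerationConfig(do_sample=False, temperature=0.6, top_p=0.9, top_k=20)

# Same cleanup as in the export path: clear the sampling-only fields instead
# of saving them alongside do_sample=False.
if not getattr(gen_config, "do_sample", True):
    for attr in ["temperature", "top_p", "top_k"]:
        if hasattr(gen_config, attr):
            setattr(gen_config, attr, None)

print(gen_config.do_sample, gen_config.temperature, gen_config.top_p, gen_config.top_k)
```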