@@ -300,29 +300,43 @@ def llm_dummy_forward():
300300 [1 , model .config .num_mel_bins , feature_extractor .nb_max_frames ], dtype = model .dtype
301301 ).to (model .device )
302302
303- if getattr (model .config , "is_encoder_decoder" , False ):
304- # For encoder-decoder models, we need to pass both the encoder and decoder input ids
305- model (fake_input , decoder_input_ids = decoder_fake_input )
306- elif is_vl_model and "nemotron" in model_type :
307- # For Nemotron VL models, try to run optimization on just the language model part
308- language_model_lineage = get_language_model_from_vl (model )
309-
310- if language_model_lineage is not None :
311- # Run optimization on just the language model with the same input format as regular LLMs
312- # Use the same fake_input tensor that regular LLMs use
313- language_model = language_model_lineage [- 1 ]
314- print (
315- f"Running optimization on language model with fake_input shape: { fake_input .shape } "
316- )
317- language_model (fake_input )
303+ with set_quantizer_by_cfg_context (model , {"*" : {"enable" : False }}):
304+ if getattr (model .config , "is_encoder_decoder" , False ):
305+ # For encoder-decoder models, we need to pass both the encoder and decoder input ids
306+ model (fake_input , decoder_input_ids = decoder_fake_input )
307+ elif is_vl_model and "nemotron" in model_type :
308+ # For Nemotron VL models, try to run optimization on just the language model part
309+ language_model_lineage = get_language_model_from_vl (model )
310+
311+ if language_model_lineage is not None :
312+ # Run optimization on just the language model with the same input format as regular LLMs
313+ # Use the same fake_input tensor that regular LLMs use
314+ language_model = language_model_lineage [- 1 ]
315+ print (
316+ f"Running optimization on language model with fake_input shape: { fake_input .shape } "
317+ )
318+ language_model (fake_input )
319+ else :
320+ raise ValueError (
321+ f"Cannot extract language_model from Nemotron VL model (type: { model_type } ). "
322+ "This is required for requantization/resmoothing optimization. "
323+ "Please ensure the model architecture is supported or file an issue."
324+ )
325+ elif "qwen3omni" in model_type :
326+ # For Qwen3Omni, run on the thinker (language model) component
327+ # The model has structure: model.thinker.model.layers.*
328+ if hasattr (model , "thinker" ):
329+ print (
330+ f"Running optimization on Qwen3Omni thinker with fake_input shape: { fake_input .shape } "
331+ )
332+ model .thinker (fake_input )
333+ else :
334+ raise ValueError (
335+ f"Cannot extract thinker from Qwen3Omni model (type: { model_type } ). "
336+ "This is required for requantization/resmoothing optimization."
337+ )
318338 else :
319- raise ValueError (
320- f"Cannot extract language_model from Nemotron VL model (type: { model_type } ). "
321- "This is required for requantization/resmoothing optimization. "
322- "Please ensure the model architecture is supported or file an issue."
323- )
324- else :
325- model (fake_input )
339+ model (fake_input )
326340
327341 input_to_linear , output_to_layernorm = _collect_shared_input_modules (
328342 model , llm_dummy_forward , collect_layernorms = True
@@ -380,6 +394,19 @@ def _export_quantized_weight(
380394 weight_quantizer : TensorQuantizer | SequentialQuantizer = getattr (
381395 sub_module , quantizer_attrs .weight_quantizer
382396 )
397+
398+ # Skip export if weight quantizer is disabled or has no amax (not calibrated)
399+ if not _is_enabled_quantizer (weight_quantizer ):
400+ return
401+
402+ # Check if weight quantizer has calibrated amax
403+ def _has_amax (quantizer ):
404+ if isinstance (quantizer , SequentialQuantizer ):
405+ return any (hasattr (q , "_amax" ) and q ._amax is not None for q in quantizer )
406+ return hasattr (quantizer , "_amax" ) and quantizer ._amax is not None
407+
408+ if not _has_amax (weight_quantizer ):
409+ return
383410 input_quantizer : TensorQuantizer | SequentialQuantizer | None = getattr (
384411 sub_module , quantizer_attrs .input_quantizer , None
385412 )
@@ -543,6 +570,7 @@ def _process_quantized_modules(
543570 model : nn .Module ,
544571 dtype : torch .dtype ,
545572 is_modelopt_qlora : bool = False ,
573+ pack_weights : bool = True ,
546574) -> None :
547575 """Process all quantized modules in model, export weights in-place.
548576
@@ -555,6 +583,7 @@ def _process_quantized_modules(
555583 dtype: The data type for weight conversion.
556584 is_modelopt_qlora: Whether the model is a modelopt-trained QLoRA model.
557585 If True, modules with base_layer attribute are skipped.
586+ pack_weights: Whether to pack quantized weights.
558587 """
559588 fsdp_module_to_reshard = None
560589
@@ -577,8 +606,9 @@ def _process_quantized_modules(
577606 sub_module .unpack_weight ()
578607 if get_quantization_format (sub_module ) != QUANTIZATION_NONE :
579608 if is_quantlinear (sub_module ):
580- with fsdp2_aware_weight_update (model , sub_module , reshard = False ):
581- _export_quantized_weight (sub_module , dtype )
609+ if pack_weights :
610+ with fsdp2_aware_weight_update (model , sub_module , reshard = False ):
611+ _export_quantized_weight (sub_module , dtype )
582612 elif (
583613 "Llama4TextExperts" in type (sub_module ).__name__
584614 or "GptOssExperts" in type (sub_module ).__name__
@@ -595,13 +625,18 @@ def _process_quantized_modules(
595625 quantizer_attrs = ["gate_up_proj_input_quantizer" , "down_proj_input_quantizer" ],
596626 )
597627 # Export the quantized weights
598- with fsdp2_aware_weight_update (model , sub_module , reshard = False ):
599- for weight_name in ["gate_up_proj" , "down_proj" ]:
600- _export_quantized_weight (sub_module , dtype , weight_name )
628+ if pack_weights :
629+ with fsdp2_aware_weight_update (model , sub_module , reshard = False ):
630+ for weight_name in ["gate_up_proj" , "down_proj" ]:
631+ _export_quantized_weight (sub_module , dtype , weight_name )
601632
602633
603- def _export_transformers_checkpoint (
604- model : nn .Module , dtype : torch .dtype | None = None , is_modelopt_qlora : bool = False , ** kwargs
634+ def _export_hf_checkpoint (
635+ model : nn .Module ,
636+ dtype : torch .dtype | None = None ,
637+ is_modelopt_qlora : bool = False ,
638+ pack_weights : bool = True ,
639+ ** kwargs ,
605640) -> tuple [dict [str , Any ], dict [str , Any ]]:
606641 """Exports the torch model to the packed checkpoint with original HF naming.
607642
@@ -611,6 +646,7 @@ def _export_transformers_checkpoint(
611646 model: the full torch model to export. The actual quantized model may be a submodule.
612647 dtype: the weights data type to export the unquantized layers or the default model data type if None.
613648 accelerator: the accelerator instance in case of distributed export setup.
649+ pack_weights: whether to pack quantized weights (False keeps original shapes for HF reload).
614650
615651 Returns:
616652 post_state_dict: Dict containing quantized weights
@@ -695,7 +731,7 @@ def _export_transformers_checkpoint(
695731 quant_config = get_quant_config (model , is_modelopt_qlora = is_modelopt_qlora )
696732
697733 # Process all quantized modules and export weights
698- _process_quantized_modules (model , dtype , is_modelopt_qlora )
734+ _process_quantized_modules (model , dtype , is_modelopt_qlora , pack_weights )
699735
700736 if accelerator is not None :
701737 # Gather state_dict from all ranks
@@ -964,7 +1000,12 @@ def export_hf_checkpoint(
9641000 return
9651001
9661002 try :
967- post_state_dict , hf_quant_config = _export_transformers_checkpoint (model , dtype )
1003+ # Packed weights are only for TRT-LLM consumption
1004+ # Set this to False if you want to save the weights in the original precision
1005+ pack_weights = True
1006+ post_state_dict , hf_quant_config = _export_hf_checkpoint (
1007+ model , dtype , pack_weights = pack_weights
1008+ )
9681009
9691010 if hf_quant_config is not None :
9701011 # Save hf_quant_config.json for backward compatibility
@@ -977,6 +1018,16 @@ def export_hf_checkpoint(
9771018 if getattr (model , "hf_quantizer" , None ) is not None :
9781019 model .hf_quantizer = None
9791020
1021+ # Fix generation_config conflicts before saving
1022+ # Some models have temperature/top_p/top_k set but do_sample=False which causes validation errors
1023+ if hasattr (model , "generation_config" ) and model .generation_config is not None :
1024+ gen_config = model .generation_config
1025+ if not getattr (gen_config , "do_sample" , True ):
1026+ # Remove sampling-related params when do_sample is False
1027+ for attr in ["temperature" , "top_p" , "top_k" ]:
1028+ if hasattr (gen_config , attr ):
1029+ setattr (gen_config , attr , None )
1030+
9801031 # Save model
9811032 model .save_pretrained (
9821033 export_dir , state_dict = post_state_dict , save_modelopt_state = save_modelopt_state
0 commit comments