"""
Quantize a timm vision model and export to ONNX for TensorRT deployment.

Supports FP8, INT8, MXFP8, NVFP4, and AUTO (mixed-precision) quantization modes end-to-end
(quantize + ONNX export + TRT build). INT4_AWQ is quantize/export-only; it is not compatible
with ``--trt_build``.

The script will:
1. Load a pretrained timm model (e.g., ViT, Swin, ResNet).
8890
8991# Auto-quantize format configs that use block quantization and need Conv2d overrides for TRT.
9092# TRT DynamicQuantize requires 2D/3D input, but Conv2d operates on 4D tensors.
91- _NEEDS_FP8_CONV_OVERRIDE : set [str ] = {"NVFP4_AWQ_LITE_CFG" , "NVFP4_DEFAULT_CFG" , "MXFP8_DEFAULT_CFG" }
93+ _NEEDS_FP8_CONV_OVERRIDE : set [str ] = {
94+ "NVFP4_AWQ_LITE_CFG" ,
95+ "NVFP4_DEFAULT_CFG" ,
96+ "MXFP8_DEFAULT_CFG" ,
97+ }
9298_NEEDS_INT8_CONV_OVERRIDE : set [str ] = {"INT4_AWQ_CFG" }
9399
94100
@@ -125,19 +131,20 @@ def filter_func(name):
125131 return pattern .match (name ) is not None
126132
127133
128- def load_calibration_data (model_name , data_size , batch_size , device , with_labels = False ):
134+ def load_calibration_data (model , data_size , batch_size , device , with_labels = False ):
129135 """Load and prepare calibration data.
130136
131137 Args:
132- model_name: Name of the timm model
138+ model: The timm model being quantized; used to derive the calibration transforms so the
139+ data pipeline matches the exact model config (respects --no_pretrained and
140+ --model_kwargs).
133141 data_size: Number of samples to load
134142 batch_size: Batch size for data loader
135143 device: Device to load data to
136144 with_labels: If True, return dict with 'image' and 'label' keys (for auto_quantize)
137145 If False, return just the images (for standard quantize)
138146 """
139147 dataset = load_dataset ("zh-plus/tiny-imagenet" )
140- model = timm .create_model (model_name , pretrained = True , num_classes = 1000 )
141148 data_config = timm .data .resolve_model_data_config (model )
142149 transforms = timm .data .create_transform (** data_config , is_training = False )
143150
@@ -171,11 +178,7 @@ def _calibrate_uncalibrated_quantizers(model, data_loader):
171178 if not hasattr (module , attr_name ):
172179 continue
173180 quantizer = getattr (module , attr_name )
174- if (
175- quantizer .is_enabled
176- and not quantizer .block_sizes
177- and not hasattr (quantizer , "_amax" )
178- ):
181+ if quantizer .is_enabled and not quantizer .block_sizes and quantizer .amax is None :
179182 quantizer .enable_calib ()
180183 uncalibrated .append (quantizer )
181184
@@ -204,11 +207,14 @@ def forward_loop(model):
204207 else :
205208 quantized_model = mtq .quantize (model , config )
206209
207- # Calibrate any FP8 override quantizers that weren't calibrated by mtq.quantize()
210+ # Disable filtered quantizers BEFORE calibrating override quantizers so we don't
211+ # waste time calibrating quantizers that are about to be turned off.
212+ mtq .disable_quantizer (quantized_model , filter_func )
213+
214+ # Calibrate any FP8 override quantizers that weren't calibrated by mtq.quantize().
208215 if data_loader is not None :
209216 _calibrate_uncalibrated_quantizers (quantized_model , data_loader )
210217
211- mtq .disable_quantizer (quantized_model , filter_func )
212218 return quantized_model
213219
214220
@@ -305,7 +311,10 @@ def get_model_input_shape(model):
305311
306312def main ():
307313 parser = argparse .ArgumentParser (
308- description = "Quantize timm models to FP8, MXFP8, INT8, NVFP4, INT4_AWQ, or use AUTO quantization"
314+ description = (
315+ "Quantize timm models to FP8, MXFP8, INT8, NVFP4, or use AUTO quantization. "
316+ "INT4_AWQ is supported for quantize/export only and is not compatible with --trt_build."
317+ )
309318 )
310319
311320 # Model hyperparameters
@@ -424,7 +433,7 @@ def main():
424433 if args .quantize_mode == "auto" :
425434 # Auto quantization requires labels for loss computation
426435 data_loader = load_calibration_data (
427- args . timm_model_name ,
436+ model ,
428437 args .calibration_data_size ,
429438 args .batch_size ,
430439 device ,
@@ -446,7 +455,7 @@ def main():
446455 # quantizers require calibration data.
447456 config = get_quant_config (args .quantize_mode )
448457 data_loader = load_calibration_data (
449- args . timm_model_name ,
458+ model ,
450459 args .calibration_data_size ,
451460 args .batch_size ,
452461 device ,
@@ -496,7 +505,14 @@ def build_trt_engine(onnx_path):
496505 "--builderOptimizationLevel=4" ,
497506 ]
498507 print (f"\n Building TensorRT engine: { ' ' .join (cmd )} " )
499- result = subprocess .run (cmd , capture_output = True , text = True , timeout = 600 )
508+ try :
509+ result = subprocess .run (cmd , capture_output = True , text = True , timeout = 600 )
510+ except FileNotFoundError as e :
511+ raise RuntimeError (
512+ "trtexec not found on PATH; install TensorRT or drop --trt_build."
513+ ) from e
514+ except subprocess .TimeoutExpired as e :
515+ raise RuntimeError (f"trtexec timed out building { onnx_path } after 600s." ) from e
500516 if result .returncode != 0 :
501517 raise RuntimeError (
502518 f"TensorRT engine build failed for { onnx_path } :\n { result .stdout } \n { result .stderr } "
0 commit comments