Skip to content

Commit eaa16a6

Browse files
ajrasaneclaude
andcommitted
Add EfficientViT support for torch_onnx quantization workflow
Add end-to-end support for efficientvit_l2 (Conv2d-heavy timm model) in the torch_onnx quantization-to-ONNX-to-TRT pipeline. This required several fixes to handle Conv2d layers with FP8 quantization: - Disable FP8 autocast during ONNX export to avoid dynamic shape issues - Disable Conv2d FP8 weight quantizer during ONNX export (TRT_FP8 custom ops produce dynamic shapes incompatible with ONNX Conv kernel shape requirement) - Add fix_fp16_fp32_mismatches() to insert Cast nodes resolving FP32/FP16 type mismatches after blocked-op FP16 conversion - Extend configure_linear_module_onnx_quantizers() to handle non-Linear modules with block-quantized input quantizers (e.g., pooling layers) - Add _disable_conv2d_dynamic_quantizers() to disable NVFP4/MXFP8 dynamic quantizers on Conv2d (TRT dynamic quantize requires 2D/3D, Conv2d is 4D) - Set calibration algorithm for MXFP8 Conv2d FP8 overrides - Add global_pool to filter_func exclusions - Relax is_fp8_quantized() to detect models with only input_quantizer FP8 Supported modes: FP8, INT8, MXFP8, NVFP4. Auto mode excluded due to Conv2d FP8 input/weight type mismatch in TRT stronglyTyped. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent 202c3d3 commit eaa16a6

7 files changed

Lines changed: 272 additions & 58 deletions

File tree

examples/torch_onnx/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ python torch_quant_to_onnx.py \
307307
| [vit_base_patch16_224](https://huggingface.co/timm/vit_base_patch16_224.augreg_in21k_ft_in1k) |||||||
308308
| [swin_tiny_patch4_window7_224](https://huggingface.co/timm/swin_tiny_patch4_window7_224.ms_in1k) |||||||
309309
| [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) |||||||
310+
| [efficientvit_l2](https://huggingface.co/timm/efficientvit_l2.r224_in1k) ||||| | |
310311

311312
## Resources
312313

examples/torch_onnx/torch_quant_to_onnx.py

Lines changed: 66 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import copy
1818
import json
1919
import re
20+
import subprocess
2021
import sys
2122
import warnings
2223
from pathlib import Path
@@ -96,6 +97,10 @@ def get_quant_config(quantize_mode):
9697
f"Overriding Conv2d quantization to FP8 for '{quantize_mode}' mode."
9798
)
9899
config["quant_cfg"].extend(_FP8_CONV_OVERRIDE)
100+
# The FP8 Conv2d overrides use static quantization which requires
101+
# calibration (amax). Ensure the calibration algorithm is set.
102+
if config.get("algorithm") is None:
103+
config["algorithm"] = "max"
99104
elif quantize_mode == "int4_awq":
100105
warnings.warn(
101106
"TensorRT only supports FP8/INT8 for Conv layers. "
@@ -109,7 +114,8 @@ def filter_func(name):
109114
"""Filter function to exclude certain layers from quantization."""
110115
pattern = re.compile(
111116
r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|"
112-
r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|downsample).*"
117+
r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|"
118+
r"downsample|global_pool).*"
113119
)
114120
return pattern.match(name) is not None
115121

@@ -147,6 +153,21 @@ def load_calibration_data(model_name, data_size, batch_size, device, with_labels
147153
)
148154

149155

156+
def _disable_conv2d_dynamic_quantizers(model):
157+
"""Disable dynamic block quantizers (NVFP4/MXFP8) on Conv2d modules.
158+
159+
TRT's FP4/MXFP8 DynamicQuantize only supports 2D/3D input tensors, but Conv2d
160+
layers have 4D inputs. Disable these quantizers to avoid TRT build failures.
161+
"""
162+
for name, module in model.named_modules():
163+
if not isinstance(module, torch.nn.Conv2d):
164+
continue
165+
for qname in ("input_quantizer", "weight_quantizer"):
166+
quantizer = getattr(module, qname, None)
167+
if quantizer is not None and getattr(quantizer, "block_sizes", None):
168+
quantizer.disable()
169+
170+
150171
def quantize_model(model, config, data_loader=None):
151172
"""Quantize the model using the given config and calibration data."""
152173
if data_loader is not None:
@@ -160,6 +181,7 @@ def forward_loop(model):
160181
quantized_model = mtq.quantize(model, config)
161182

162183
mtq.disable_quantizer(quantized_model, filter_func)
184+
_disable_conv2d_dynamic_quantizers(quantized_model)
163185
return quantized_model
164186

165187

@@ -235,6 +257,7 @@ def auto_quantize_model(
235257

236258
# Disable quantization for specified layers
237259
mtq.disable_quantizer(quantized_model, filter_func)
260+
_disable_conv2d_dynamic_quantizers(quantized_model)
238261

239262
return quantized_model, search_state
240263

@@ -320,6 +343,17 @@ def main():
320343
default=128,
321344
help="Number of scoring steps for auto quantization. Default is 128.",
322345
)
346+
parser.add_argument(
347+
"--trt_build",
348+
action="store_true",
349+
help="Build a TRT engine from the exported ONNX model to verify compatibility.",
350+
)
351+
parser.add_argument(
352+
"--trt_builder_opt_level",
353+
type=int,
354+
default=4,
355+
help="TRT builder optimization level (default: 4).",
356+
)
323357
parser.add_argument(
324358
"--no_pretrained",
325359
action="store_true",
@@ -378,18 +412,19 @@ def main():
378412
args.num_score_steps,
379413
)
380414
else:
381-
# Standard quantization - only load calibration data if needed
415+
# Standard quantization
382416
config = get_quant_config(args.quantize_mode)
383-
if args.quantize_mode == "mxfp8":
384-
data_loader = None
385-
else:
386-
data_loader = load_calibration_data(
387-
args.timm_model_name,
388-
args.calibration_data_size,
389-
args.batch_size,
390-
device,
391-
with_labels=False,
392-
)
417+
# Always load calibration data. Even though MXFP8 uses dynamic quantization
418+
# and doesn't strictly require calibration, the Conv2d FP8 overrides (applied
419+
# by get_quant_config for MXFP8/NVFP4) use static FP8 quantization which
420+
# needs calibration data to compute amax values.
421+
data_loader = load_calibration_data(
422+
args.timm_model_name,
423+
args.calibration_data_size,
424+
args.batch_size,
425+
device,
426+
with_labels=False,
427+
)
393428

394429
quantized_model = quantize_model(model, config, data_loader)
395430

@@ -421,6 +456,25 @@ def main():
421456

422457
print(f"Quantized ONNX model is saved to {args.onnx_save_path}")
423458

459+
if args.trt_build:
460+
print("\n=== Building TRT Engine ===")
461+
cmd = [
462+
"trtexec",
463+
f"--onnx={args.onnx_save_path}",
464+
"--stronglyTyped",
465+
f"--builderOptimizationLevel={args.trt_builder_opt_level}",
466+
]
467+
print(f"Running: {' '.join(cmd)}")
468+
result = subprocess.run(cmd, capture_output=True, text=True)
469+
if result.returncode != 0:
470+
print("TRT engine build FAILED:")
471+
for line in result.stderr.splitlines():
472+
if "Error" in line or "FAIL" in line or "error" in line:
473+
print(f" {line.strip()}")
474+
sys.exit(1)
475+
else:
476+
print("TRT engine build succeeded.")
477+
424478

425479
if __name__ == "__main__":
426480
main()

modelopt/onnx/utils.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1504,6 +1504,130 @@ def remove_redundant_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
15041504
return onnx_model
15051505

15061506

1507+
def fix_fp16_fp32_mismatches(model: onnx.ModelProto) -> onnx.ModelProto:
1508+
"""Insert Cast nodes to resolve FP32/FP16 type mismatches after blocked-op FP16 conversion.
1509+
1510+
After convert_float_to_float16 with an op_block_list, FP32 data from blocked ops
1511+
(e.g., QDQ paths) can flow into nodes whose other inputs are FP16. TensorRT
1512+
--stronglyTyped rejects such mismatches. This function propagates "real" types
1513+
through the graph and inserts FP32->FP16 Cast nodes where needed.
1514+
1515+
Note: value_info types are unreliable after convert_float_to_float16 with blocked ops
1516+
(metadata may say FP16 even when actual data is FP32), so this function re-derives
1517+
types by following op semantics.
1518+
1519+
Args:
1520+
model: The ONNX model to fix.
1521+
1522+
Returns:
1523+
The modified ONNX model with Cast nodes inserted to resolve mismatches.
1524+
"""
1525+
FLOAT = onnx.TensorProto.FLOAT
1526+
FLOAT16 = onnx.TensorProto.FLOAT16
1527+
1528+
# Ops whose data inputs must all have the same type in TRT stronglyTyped mode.
1529+
_ELEMENTWISE_OPS = {
1530+
"Add", "Sub", "Mul", "Div", "Pow", "Min", "Max", "Equal", "Less",
1531+
"Greater", "Where", "Sum", "Mean", "Concat",
1532+
}
1533+
1534+
# Ops that are FP32-only (QDQ) — never cast their I/O.
1535+
_BLOCKED_OPS = {"QuantizeLinear", "DequantizeLinear"}
1536+
1537+
# --- Step 1: Propagate real element types through the graph. ---
1538+
real_type: dict[str, int] = {}
1539+
1540+
# Seed from graph inputs and initializers (these are authoritative).
1541+
for inp in model.graph.input:
1542+
real_type[inp.name] = inp.type.tensor_type.elem_type
1543+
for init in model.graph.initializer:
1544+
real_type[init.name] = init.data_type
1545+
1546+
# Process nodes in topological order.
1547+
for node in model.graph.node:
1548+
if node.op_type == "Constant":
1549+
for attr in node.attribute:
1550+
if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR:
1551+
for out in node.output:
1552+
real_type[out] = attr.t.data_type
1553+
continue
1554+
1555+
if node.op_type == "Cast":
1556+
cast_to = get_cast_to_type(node)
1557+
for out in node.output:
1558+
real_type[out] = cast_to
1559+
continue
1560+
1561+
if node.op_type in _BLOCKED_OPS:
1562+
for out in node.output:
1563+
real_type[out] = FLOAT
1564+
continue
1565+
1566+
# For other ops: output type matches the predominant data-input type.
1567+
data_types = []
1568+
for inp_name in node.input:
1569+
if inp_name and inp_name in real_type and real_type[inp_name] in (FLOAT, FLOAT16):
1570+
data_types.append(real_type[inp_name])
1571+
1572+
if data_types:
1573+
out_type = FLOAT if FLOAT in data_types else FLOAT16
1574+
else:
1575+
out_type = FLOAT16
1576+
1577+
for out in node.output:
1578+
real_type[out] = out_type
1579+
1580+
# --- Step 2: Find nodes with mixed real types and insert Casts. ---
1581+
nodes_to_insert: list[tuple[int, onnx.NodeProto]] = []
1582+
1583+
for node_idx, node in enumerate(model.graph.node):
1584+
if node.op_type not in _ELEMENTWISE_OPS:
1585+
continue
1586+
1587+
input_real_types = []
1588+
for inp_name in node.input:
1589+
if inp_name and inp_name in real_type and real_type[inp_name] in (FLOAT, FLOAT16):
1590+
input_real_types.append((inp_name, real_type[inp_name]))
1591+
1592+
if not input_real_types:
1593+
continue
1594+
1595+
has_fp32 = any(t == FLOAT for _, t in input_real_types)
1596+
has_fp16 = any(t == FLOAT16 for _, t in input_real_types)
1597+
if not (has_fp32 and has_fp16):
1598+
continue
1599+
1600+
# Insert Cast(FP32 -> FP16) for each FP32 input.
1601+
# Reuse existing Cast if the same input was already cast (avoids duplicate names).
1602+
for inp_idx, inp_name in enumerate(node.input):
1603+
if not inp_name or inp_name not in real_type:
1604+
continue
1605+
if real_type[inp_name] != FLOAT:
1606+
continue
1607+
cast_out_name = inp_name + "_cast_to_fp16"
1608+
if cast_out_name not in real_type:
1609+
cast_node = onnx.helper.make_node(
1610+
"Cast",
1611+
inputs=[inp_name],
1612+
outputs=[cast_out_name],
1613+
to=FLOAT16,
1614+
)
1615+
real_type[cast_out_name] = FLOAT16
1616+
nodes_to_insert.append((node_idx, cast_node))
1617+
node.input[inp_idx] = cast_out_name
1618+
1619+
# Insert cast nodes in reverse order so positions stay valid.
1620+
for pos, cast_node in sorted(nodes_to_insert, key=lambda x: x[0], reverse=True):
1621+
model.graph.node.insert(pos, cast_node)
1622+
1623+
if nodes_to_insert:
1624+
logger.info(
1625+
f"Inserted {len(nodes_to_insert)} Cast node(s) to fix FP32/FP16 mismatches"
1626+
)
1627+
1628+
return model
1629+
1630+
15071631
def remove_node_training_mode(onnx_model: onnx.ModelProto, node_op_type: str) -> onnx.ModelProto:
15081632
"""Remove `training_mode` attribute and extra training outputs from nodes of a given op type.
15091633

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from modelopt.onnx.utils import (
4747
change_casts_to_fp16,
4848
check_model_uses_external_data,
49+
fix_fp16_fp32_mismatches,
4950
get_input_names,
5051
get_input_shapes,
5152
get_node_names,
@@ -382,22 +383,30 @@ def is_int8_quantized(model: nn.Module) -> bool:
382383
return False
383384

384385

386+
def _is_fp8_quantizer(quantizer) -> bool:
387+
"""Check if a single quantizer is configured for FP8 (not MXFP8)."""
388+
return (
389+
quantizer.is_enabled
390+
and quantizer._num_bits == (4, 3)
391+
and not (
392+
quantizer.block_sizes
393+
and quantizer.block_sizes.get("scale_bits", None) == (8, 0)
394+
)
395+
)
396+
397+
385398
def is_fp8_quantized(model: nn.Module) -> bool:
386-
"""Check if the model is quantized in FP8 mode."""
399+
"""Check if the model is quantized in FP8 mode.
400+
401+
Returns True if any module has an FP8-configured quantizer (weight or input).
402+
This covers mixed-precision scenarios (e.g., auto_quantize) where only the
403+
input_quantizer might be FP8 while the weight_quantizer is disabled or uses
404+
a different format.
405+
"""
387406
for _, module in model.named_modules():
388-
if (
389-
hasattr(module, "weight_quantizer")
390-
and hasattr(module, "input_quantizer")
391-
and module.weight_quantizer.is_enabled
392-
and module.input_quantizer.is_enabled
393-
and module.weight_quantizer._num_bits == (4, 3)
394-
and module.input_quantizer._num_bits == (4, 3)
395-
# Exclude MXFP8 which also uses (4,3) but has block_sizes with scale_bits
396-
and not (
397-
module.input_quantizer.block_sizes
398-
and module.input_quantizer.block_sizes.get("scale_bits", None) == (8, 0)
399-
)
400-
):
407+
if hasattr(module, "weight_quantizer") and _is_fp8_quantizer(module.weight_quantizer):
408+
return True
409+
if hasattr(module, "input_quantizer") and _is_fp8_quantizer(module.input_quantizer):
401410
return True
402411
return False
403412

@@ -522,7 +531,10 @@ def get_onnx_bytes_and_metadata(
522531
input_none_names = list(set(tree_spec_input.names) - set(input_names))
523532

524533
use_torch_autocast = not (
525-
is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
534+
is_fp4_quantized(model)
535+
or is_mxfp8_quantized(model)
536+
or is_fp8_quantized(model)
537+
or weights_dtype == "fp32"
526538
)
527539
autocast = torch.autocast("cuda") if use_torch_autocast else nullcontext()
528540

@@ -556,6 +568,22 @@ def get_onnx_bytes_and_metadata(
556568
if is_fp4_quantized(model) or is_mxfp8_quantized(model)
557569
else nullcontext()
558570
)
571+
572+
# Disable Conv2d FP8 weight quantizer for ONNX export.
573+
# FP8 TRT_FP8QuantizeLinear/DequantizeLinear custom ops produce tensors with
574+
# dynamic shapes, and the ONNX Conv exporter requires static kernel shapes.
575+
# Disabling the weight quantizer keeps Conv2d weights as static constants in
576+
# the ONNX graph. Input quantizer remains enabled so TRT still uses FP8 for
577+
# Conv2d activations. Weights are converted to FP16 by post-export processing.
578+
conv_quantizers_to_reenable: list[tuple[nn.Module, str]] = []
579+
if is_fp8_quantized(model):
580+
for module in model.modules():
581+
if not isinstance(module, nn.Conv2d):
582+
continue
583+
quantizer = getattr(module, "weight_quantizer", None)
584+
if quantizer is not None and _is_fp8_quantizer(quantizer):
585+
quantizer.disable()
586+
conv_quantizers_to_reenable.append((module, "weight_quantizer"))
559587
with torch.inference_mode(), autocast, quantizer_context:
560588
additional_kwargs = {}
561589
if not dynamo_export:
@@ -571,6 +599,10 @@ def get_onnx_bytes_and_metadata(
571599
**additional_kwargs,
572600
)
573601

602+
# Re-enable Conv2d quantizers that were temporarily disabled for FP8 export
603+
for module, qname in conv_quantizers_to_reenable:
604+
getattr(module, qname).enable()
605+
574606
# Check that export worked
575607
assert len(os.listdir(onnx_path)) > 0, "Torch to onnx export failed."
576608

@@ -617,6 +649,16 @@ def get_onnx_bytes_and_metadata(
617649

618650
onnx_opt_graph = remove_redundant_casts(onnx_opt_graph)
619651

652+
# Fix remaining FP32/FP16 mismatches AFTER remove_redundant_casts.
653+
# Only needed for the convert_float_to_float16 path (FP8/MXFP8/NVFP4) where
654+
# blocked QDQ ops produce FP32 that flows into nodes with FP16 inputs.
655+
# Must run after remove_redundant_casts because that function uses unreliable
656+
# value_info metadata and would incorrectly remove the Cast nodes we insert.
657+
if weights_dtype in ["fp16", "bf16"] and (
658+
is_int4_quantized(model) or is_mxfp8_quantized(model) or is_fp8_quantized(model)
659+
):
660+
onnx_opt_graph = fix_fp16_fp32_mismatches(onnx_opt_graph)
661+
620662
# TensorRT expects all scales to be postive
621663
onnx_opt_graph = replace_zero_scale_with_smallest_nonzero(onnx_opt_graph)
622664

0 commit comments

Comments
 (0)