Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/torch_onnx/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ python torch_quant_to_onnx.py \
| [vit_base_patch16_224](https://huggingface.co/timm/vit_base_patch16_224.augreg_in21k_ft_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [swin_tiny_patch4_window7_224](https://huggingface.co/timm/swin_tiny_patch4_window7_224.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [efficientvit_l2](https://huggingface.co/timm/efficientvit_l2.r224_in1k) | ✅ | ✅ | ✅ | ✅ | | |

## Resources

Expand Down
78 changes: 66 additions & 12 deletions examples/torch_onnx/torch_quant_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import copy
import json
import re
import subprocess
import sys
import warnings
from pathlib import Path
Expand Down Expand Up @@ -96,6 +97,10 @@ def get_quant_config(quantize_mode):
f"Overriding Conv2d quantization to FP8 for '{quantize_mode}' mode."
)
config["quant_cfg"].extend(_FP8_CONV_OVERRIDE)
# The FP8 Conv2d overrides use static quantization which requires
# calibration (amax). Ensure the calibration algorithm is set.
if config.get("algorithm") is None:
config["algorithm"] = "max"
elif quantize_mode == "int4_awq":
warnings.warn(
"TensorRT only supports FP8/INT8 for Conv layers. "
Expand All @@ -109,7 +114,8 @@ def filter_func(name):
"""Filter function to exclude certain layers from quantization."""
pattern = re.compile(
r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|"
r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|downsample).*"
r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|"
r"downsample|global_pool).*"
)
return pattern.match(name) is not None

Expand Down Expand Up @@ -147,6 +153,21 @@ def load_calibration_data(model_name, data_size, batch_size, device, with_labels
)


def _disable_conv2d_dynamic_quantizers(model):
"""Disable dynamic block quantizers (NVFP4/MXFP8) on Conv2d modules.

TRT's FP4/MXFP8 DynamicQuantize only supports 2D/3D input tensors, but Conv2d
layers have 4D inputs. Disable these quantizers to avoid TRT build failures.
"""
for name, module in model.named_modules():
if not isinstance(module, torch.nn.Conv2d):
continue
for qname in ("input_quantizer", "weight_quantizer"):
quantizer = getattr(module, qname, None)
if quantizer is not None and getattr(quantizer, "block_sizes", None):
quantizer.disable()


def quantize_model(model, config, data_loader=None):
"""Quantize the model using the given config and calibration data."""
if data_loader is not None:
Expand All @@ -160,6 +181,7 @@ def forward_loop(model):
quantized_model = mtq.quantize(model, config)

mtq.disable_quantizer(quantized_model, filter_func)
_disable_conv2d_dynamic_quantizers(quantized_model)
return quantized_model


Expand Down Expand Up @@ -235,6 +257,7 @@ def auto_quantize_model(

# Disable quantization for specified layers
mtq.disable_quantizer(quantized_model, filter_func)
_disable_conv2d_dynamic_quantizers(quantized_model)

return quantized_model, search_state

Expand Down Expand Up @@ -320,6 +343,17 @@ def main():
default=128,
help="Number of scoring steps for auto quantization. Default is 128.",
)
parser.add_argument(
"--trt_build",
action="store_true",
help="Build a TRT engine from the exported ONNX model to verify compatibility.",
)
parser.add_argument(
"--trt_builder_opt_level",
type=int,
default=4,
help="TRT builder optimization level (default: 4).",
)
parser.add_argument(
"--no_pretrained",
action="store_true",
Expand Down Expand Up @@ -378,18 +412,19 @@ def main():
args.num_score_steps,
)
else:
# Standard quantization - only load calibration data if needed
# Standard quantization
config = get_quant_config(args.quantize_mode)
if args.quantize_mode == "mxfp8":
data_loader = None
else:
data_loader = load_calibration_data(
args.timm_model_name,
args.calibration_data_size,
args.batch_size,
device,
with_labels=False,
)
# Always load calibration data. Even though MXFP8 uses dynamic quantization
# and doesn't strictly require calibration, the Conv2d FP8 overrides (applied
# by get_quant_config for MXFP8/NVFP4) use static FP8 quantization which
# needs calibration data to compute amax values.
data_loader = load_calibration_data(
args.timm_model_name,
args.calibration_data_size,
args.batch_size,
device,
with_labels=False,
)
Comment on lines +417 to +427
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

--no_pretrained is no longer honored in the standard path.

Because Line 421 now always calls load_calibration_data(), the non-auto flow will still instantiate timm.create_model(..., pretrained=True) from Line 135 even when --no_pretrained is set. That turns a local/offline smoke run into a networked weights fetch.

💡 Suggested fix
-def load_calibration_data(model_name, data_size, batch_size, device, with_labels=False):
+def load_calibration_data(
+    model_name, data_size, batch_size, device, with_labels=False, pretrained=True
+):
@@
-    model = timm.create_model(model_name, pretrained=True, num_classes=1000)
+    model = timm.create_model(model_name, pretrained=pretrained, num_classes=1000)
         data_loader = load_calibration_data(
             args.timm_model_name,
             args.calibration_data_size,
             args.batch_size,
             device,
             with_labels=False,
+            pretrained=not args.no_pretrained,
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/torch_onnx/torch_quant_to_onnx.py` around lines 417 - 427, The code
always calls load_calibration_data which causes timm.create_model(...,
pretrained=True) to be instantiated earlier and ignores --no_pretrained; fix by
guarding the calibration load so it only runs when pretrained weights are
allowed: change the call site of load_calibration_data to run only if not
args.no_pretrained (or if the auto calibration flow explicitly requires
pretrained weights), e.g., wrap the existing load_calibration_data invocation in
a conditional that checks args.no_pretrained (and any auto-mode flag you use) so
timm.create_model is not triggered when --no_pretrained is set.


quantized_model = quantize_model(model, config, data_loader)

Expand Down Expand Up @@ -421,6 +456,25 @@ def main():

print(f"Quantized ONNX model is saved to {args.onnx_save_path}")

if args.trt_build:
print("\n=== Building TRT Engine ===")
cmd = [
"trtexec",
f"--onnx={args.onnx_save_path}",
"--stronglyTyped",
f"--builderOptimizationLevel={args.trt_builder_opt_level}",
]
print(f"Running: {' '.join(cmd)}")
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
print("TRT engine build FAILED:")
for line in result.stderr.splitlines():
if "Error" in line or "FAIL" in line or "error" in line:
print(f" {line.strip()}")
sys.exit(1)
else:
print("TRT engine build succeeded.")


if __name__ == "__main__":
main()
124 changes: 124 additions & 0 deletions modelopt/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1504,6 +1504,130 @@ def remove_redundant_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
return onnx_model


def fix_fp16_fp32_mismatches(model: onnx.ModelProto) -> onnx.ModelProto:
    """Insert Cast nodes to resolve FP32/FP16 type mismatches after blocked-op FP16 conversion.

    After convert_float_to_float16 with an op_block_list, FP32 data from blocked ops
    (e.g., QDQ paths) can flow into nodes whose other inputs are FP16. TensorRT
    --stronglyTyped rejects such mismatches. This function propagates "real" types
    through the graph and inserts FP32->FP16 Cast nodes where needed.

    Note: value_info types are unreliable after convert_float_to_float16 with blocked ops
    (metadata may say FP16 even when actual data is FP32), so this function re-derives
    types by following op semantics.

    Args:
        model: The ONNX model to fix.

    Returns:
        The modified ONNX model with Cast nodes inserted to resolve mismatches.
    """
    FLOAT = onnx.TensorProto.FLOAT
    FLOAT16 = onnx.TensorProto.FLOAT16
    BOOL = onnx.TensorProto.BOOL

    # Ops whose data inputs must all have the same type in TRT stronglyTyped mode.
    _ELEMENTWISE_OPS = {
        "Add", "Sub", "Mul", "Div", "Pow", "Min", "Max", "Equal", "Less",
        "Greater", "Where", "Sum", "Mean", "Concat",
    }

    # Ops that are FP32-only (QDQ) — never cast their I/O.
    _BLOCKED_OPS = {"QuantizeLinear", "DequantizeLinear"}

    # Comparison ops always output tensor(bool) per the ONNX spec, regardless of
    # input element type. Without this special case the fallback below would label
    # their outputs FLOAT/FLOAT16 and mislabel boolean masks (e.g. Where conditions).
    _COMPARISON_OPS = {"Equal", "Less", "Greater", "LessOrEqual", "GreaterOrEqual"}

    # --- Step 1: Propagate real element types through the graph. ---
    real_type: dict[str, int] = {}

    # Seed from graph inputs and initializers (these are authoritative).
    for inp in model.graph.input:
        real_type[inp.name] = inp.type.tensor_type.elem_type
    for init in model.graph.initializer:
        real_type[init.name] = init.data_type

    # Process nodes in topological order.
    for node in model.graph.node:
        if node.op_type == "Constant":
            for attr in node.attribute:
                if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR:
                    for out in node.output:
                        real_type[out] = attr.t.data_type
            continue

        if node.op_type == "Cast":
            cast_to = get_cast_to_type(node)
            for out in node.output:
                real_type[out] = cast_to
            continue

        if node.op_type in _BLOCKED_OPS:
            for out in node.output:
                real_type[out] = FLOAT
            continue

        if node.op_type in _COMPARISON_OPS:
            for out in node.output:
                real_type[out] = BOOL
            continue

        # For other ops: output type matches the predominant data-input type.
        data_types = []
        for inp_name in node.input:
            if inp_name and inp_name in real_type and real_type[inp_name] in (FLOAT, FLOAT16):
                data_types.append(real_type[inp_name])

        if data_types:
            out_type = FLOAT if FLOAT in data_types else FLOAT16
        else:
            out_type = FLOAT16

        for out in node.output:
            real_type[out] = out_type

    # --- Step 2: Find nodes with mixed real types and insert Casts. ---
    nodes_to_insert: list[tuple[int, onnx.NodeProto]] = []

    for node_idx, node in enumerate(model.graph.node):
        if node.op_type not in _ELEMENTWISE_OPS:
            continue

        input_real_types = []
        for inp_name in node.input:
            if inp_name and inp_name in real_type and real_type[inp_name] in (FLOAT, FLOAT16):
                input_real_types.append((inp_name, real_type[inp_name]))

        if not input_real_types:
            continue

        has_fp32 = any(t == FLOAT for _, t in input_real_types)
        has_fp16 = any(t == FLOAT16 for _, t in input_real_types)
        if not (has_fp32 and has_fp16):
            continue

        # Insert Cast(FP32 -> FP16) for each FP32 input.
        # Reuse existing Cast if the same input was already cast (avoids duplicate names).
        for inp_idx, inp_name in enumerate(node.input):
            if not inp_name or inp_name not in real_type:
                continue
            if real_type[inp_name] != FLOAT:
                continue
            cast_out_name = inp_name + "_cast_to_fp16"
            if cast_out_name not in real_type:
                cast_node = onnx.helper.make_node(
                    "Cast",
                    inputs=[inp_name],
                    outputs=[cast_out_name],
                    to=FLOAT16,
                )
                real_type[cast_out_name] = FLOAT16
                nodes_to_insert.append((node_idx, cast_node))
            node.input[inp_idx] = cast_out_name

    # Insert cast nodes in reverse order so positions stay valid.
    for pos, cast_node in sorted(nodes_to_insert, key=lambda x: x[0], reverse=True):
        model.graph.node.insert(pos, cast_node)

    if nodes_to_insert:
        logger.info(
            f"Inserted {len(nodes_to_insert)} Cast node(s) to fix FP32/FP16 mismatches"
        )

    return model


def remove_node_training_mode(onnx_model: onnx.ModelProto, node_op_type: str) -> onnx.ModelProto:
"""Remove `training_mode` attribute and extra training outputs from nodes of a given op type.

Expand Down
72 changes: 57 additions & 15 deletions modelopt/torch/_deploy/utils/torch_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from modelopt.onnx.utils import (
change_casts_to_fp16,
check_model_uses_external_data,
fix_fp16_fp32_mismatches,
get_input_names,
get_input_shapes,
get_node_names,
Expand Down Expand Up @@ -382,22 +383,30 @@ def is_int8_quantized(model: nn.Module) -> bool:
return False


def _is_fp8_quantizer(quantizer) -> bool:
"""Check if a single quantizer is configured for FP8 (not MXFP8)."""
return (
quantizer.is_enabled
and quantizer._num_bits == (4, 3)
and not (
quantizer.block_sizes
and quantizer.block_sizes.get("scale_bits", None) == (8, 0)
)
)


def is_fp8_quantized(model: nn.Module) -> bool:
    """Check if the model is quantized in FP8 mode.

    Returns True if any module has an FP8-configured quantizer (weight or input).
    This covers mixed-precision scenarios (e.g., auto_quantize) where only the
    input_quantizer might be FP8 while the weight_quantizer is disabled or uses
    a different format.
    """
    # Short-circuits on the first FP8-configured quantizer found, checking the
    # weight quantizer before the input quantizer for each module.
    return any(
        hasattr(module, attr) and _is_fp8_quantizer(getattr(module, attr))
        for _, module in model.named_modules()
        for attr in ("weight_quantizer", "input_quantizer")
    )

Expand Down Expand Up @@ -522,7 +531,10 @@ def get_onnx_bytes_and_metadata(
input_none_names = list(set(tree_spec_input.names) - set(input_names))

use_torch_autocast = not (
is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
is_fp4_quantized(model)
or is_mxfp8_quantized(model)
or is_fp8_quantized(model)
or weights_dtype == "fp32"
)
autocast = torch.autocast("cuda") if use_torch_autocast else nullcontext()

Expand Down Expand Up @@ -556,6 +568,22 @@ def get_onnx_bytes_and_metadata(
if is_fp4_quantized(model) or is_mxfp8_quantized(model)
else nullcontext()
)

# Disable Conv2d FP8 weight quantizer for ONNX export.
# FP8 TRT_FP8QuantizeLinear/DequantizeLinear custom ops produce tensors with
# dynamic shapes, and the ONNX Conv exporter requires static kernel shapes.
# Disabling the weight quantizer keeps Conv2d weights as static constants in
# the ONNX graph. Input quantizer remains enabled so TRT still uses FP8 for
# Conv2d activations. Weights are converted to FP16 by post-export processing.
conv_quantizers_to_reenable: list[tuple[nn.Module, str]] = []
if is_fp8_quantized(model):
for module in model.modules():
if not isinstance(module, nn.Conv2d):
continue
quantizer = getattr(module, "weight_quantizer", None)
if quantizer is not None and _is_fp8_quantizer(quantizer):
quantizer.disable()
conv_quantizers_to_reenable.append((module, "weight_quantizer"))
with torch.inference_mode(), autocast, quantizer_context:
additional_kwargs = {}
if not dynamo_export:
Expand All @@ -571,6 +599,10 @@ def get_onnx_bytes_and_metadata(
**additional_kwargs,
)

# Re-enable Conv2d quantizers that were temporarily disabled for FP8 export
for module, qname in conv_quantizers_to_reenable:
getattr(module, qname).enable()

# Check that export worked
assert len(os.listdir(onnx_path)) > 0, "Torch to onnx export failed."

Expand Down Expand Up @@ -617,6 +649,16 @@ def get_onnx_bytes_and_metadata(

onnx_opt_graph = remove_redundant_casts(onnx_opt_graph)

# Fix remaining FP32/FP16 mismatches AFTER remove_redundant_casts.
# Only needed for the convert_float_to_float16 path (FP8/MXFP8/NVFP4) where
# blocked QDQ ops produce FP32 that flows into nodes with FP16 inputs.
# Must run after remove_redundant_casts because that function uses unreliable
# value_info metadata and would incorrectly remove the Cast nodes we insert.
if weights_dtype in ["fp16", "bf16"] and (
is_int4_quantized(model) or is_mxfp8_quantized(model) or is_fp8_quantized(model)
):
onnx_opt_graph = fix_fp16_fp32_mismatches(onnx_opt_graph)

    # TensorRT expects all scales to be positive
onnx_opt_graph = replace_zero_scale_with_smallest_nonzero(onnx_opt_graph)

Expand Down
Loading
Loading