
Add EfficientViT support for torch_onnx quantization workflow#1254

Open
ajrasane wants to merge 1 commit into main from ajrasane/efficient_vit

Conversation


@ajrasane ajrasane commented Apr 14, 2026

Summary

  • Add end-to-end support for efficientvit_l2 (Conv2d-heavy timm model) in the torch_onnx quantization → ONNX export → TRT engine pipeline
  • Fix several ONNX export and TRT compatibility issues for models with FP8-quantized Conv2d layers
  • Add fix_fp16_fp32_mismatches() utility to resolve FP32/FP16 type mismatches after blocked-op FP16 conversion

Key changes

modelopt infrastructure fixes

  • modelopt/torch/_deploy/utils/torch_onnx.py: Disable FP8 autocast and Conv2d weight quantizer during ONNX export (FP8 custom ops produce dynamic shapes incompatible with Conv kernel shape requirement); add fix_fp16_fp32_mismatches() call after FP16 conversion; relax is_fp8_quantized() to detect mixed-precision models
  • modelopt/onnx/utils.py: Add fix_fp16_fp32_mismatches() function that propagates real element types through the graph and inserts Cast nodes to resolve FP32/FP16 mismatches
  • modelopt/torch/quantization/export_onnx.py: Extend configure_linear_module_onnx_quantizers() to handle non-Linear modules with block-quantized input quantizers (e.g., pooling layers in EfficientViT)
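The type-propagation idea behind fix_fp16_fp32_mismatches() can be sketched in pure Python. This is a simplified stand-in (nodes as tuples, string type tags) rather than the real ONNX protobuf API, and propagate_and_fix is a hypothetical name for illustration:

```python
# Toy sketch of FP32/FP16 propagation plus Cast-insertion decisions.
# Real implementation walks onnx.ModelProto nodes; here a node is
# (op_type, inputs, outputs) and types are plain strings.
FLOAT, FLOAT16 = "fp32", "fp16"

def propagate_and_fix(nodes, seed_types):
    """Return (tensor_types, casts), where casts lists (tensor, target_type)
    pairs for which a Cast node would need to be inserted."""
    real = dict(seed_types)
    casts = []
    for op, inputs, outputs in nodes:
        data = [real[i] for i in inputs if real.get(i) in (FLOAT, FLOAT16)]
        # Output follows the predominant float type; default to FP16.
        out = FLOAT if FLOAT in data else FLOAT16
        # Mixed inputs violate TRT stronglyTyped: cast FP16 inputs up to FP32.
        if FLOAT in data and FLOAT16 in data:
            casts += [(i, FLOAT) for i in inputs if real.get(i) == FLOAT16]
        for o in outputs:
            real[o] = out
    return real, casts

types, casts = propagate_and_fix(
    [("Mul", ["a", "b"], ["c"]), ("Add", ["c", "half"], ["d"])],
    {"a": FLOAT, "b": FLOAT, "half": FLOAT16},
)
print(types["d"], casts)  # fp32 [('half', 'fp32')]
```

The actual utility additionally special-cases Constant, Cast, and QDQ nodes, as shown in the quoted source later in this thread.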

Example script fixes

  • examples/torch_onnx/torch_quant_to_onnx.py: Add _disable_conv2d_dynamic_quantizers() for TRT 4D tensor limitation; set calibration algorithm for MXFP8 Conv2d FP8 overrides; add global_pool to filter_func; always load calibration data

Supported quantization modes for efficientvit_l2

Mode       Status
FP8        ✅
INT8       ✅
MXFP8      ✅
NVFP4      ✅
INT4_AWQ   ❌ (pre-existing limitation)
Auto       ❌ (Conv2d FP8 input/weight type mismatch)

Test plan

  • All 4 supported efficientvit_l2 modes pass (fp8, int8, mxfp8, nvfp4); auto is skipped
  • All 15 existing model tests pass (vit_tiny, swin_tiny, swinv2_tiny × 5 modes)
  • Full regression suite: 19 passed, 1 skipped, 0 failed (~25 min)

🤖 Generated with Claude Code

Summary by CodeRabbit

  • New Features

    • Added efficientvit_l2 model support with FP8, INT8, MXFP8, and NVFP4 quantization modes.
    • Introduced --trt_build and --trt_builder_opt_level CLI flags for TensorRT engine build verification.
    • Added automatic FP16/FP32 type mismatch correction in ONNX model exports.
  • Improvements

    • Enhanced FP8 quantization detection and handling for improved export compatibility.
    • Expanded block quantization support for non-Linear modules.
    • Improved calibration handling for low-bit quantization modes.

Add end-to-end support for efficientvit_l2 (Conv2d-heavy timm model) in
the torch_onnx quantization-to-ONNX-to-TRT pipeline. This required
several fixes to handle Conv2d layers with FP8 quantization:

- Disable FP8 autocast during ONNX export to avoid dynamic shape issues
- Disable Conv2d FP8 weight quantizer during ONNX export (TRT_FP8 custom
  ops produce dynamic shapes incompatible with ONNX Conv kernel shape
  requirement)
- Add fix_fp16_fp32_mismatches() to insert Cast nodes resolving FP32/FP16
  type mismatches after blocked-op FP16 conversion
- Extend configure_linear_module_onnx_quantizers() to handle non-Linear
  modules with block-quantized input quantizers (e.g., pooling layers)
- Add _disable_conv2d_dynamic_quantizers() to disable NVFP4/MXFP8 dynamic
  quantizers on Conv2d (TRT dynamic quantize requires 2D/3D, Conv2d is 4D)
- Set calibration algorithm for MXFP8 Conv2d FP8 overrides
- Add global_pool to filter_func exclusions
- Relax is_fp8_quantized() to detect models with only input_quantizer FP8

Supported modes: FP8, INT8, MXFP8, NVFP4. Auto mode excluded due to
Conv2d FP8 input/weight type mismatch in TRT stronglyTyped.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
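The disable-during-export pattern described in the commit message can be sketched as below. FakeQuantizer is a hypothetical stand-in for modelopt's TensorQuantizer (which exposes enable()/disable()); the try/finally guarantees quantizers are restored even if export raises, which is also what the review below recommends:

```python
from contextlib import contextmanager

class FakeQuantizer:
    """Hypothetical stand-in for a quantizer with enable()/disable()."""
    def __init__(self):
        self.enabled = True
    def disable(self):
        self.enabled = False
    def enable(self):
        self.enabled = True

@contextmanager
def quantizers_disabled(quantizers):
    # Temporarily disable quantizers around export; re-enable in finally
    # so the model is never left mutated on failure.
    for q in quantizers:
        q.disable()
    try:
        yield
    finally:
        for q in quantizers:
            q.enable()

qs = [FakeQuantizer(), FakeQuantizer()]
try:
    with quantizers_disabled(qs):
        assert not any(q.enabled for q in qs)
        raise RuntimeError("simulated export failure")
except RuntimeError:
    pass
print(all(q.enabled for q in qs))  # True
```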
@ajrasane ajrasane requested review from a team as code owners April 14, 2026 07:01
@ajrasane ajrasane requested a review from cjluo-nv April 14, 2026 07:01

coderabbitai bot commented Apr 14, 2026

📝 Walkthrough


This PR extends ONNX export support for quantized models, particularly focusing on FP8 quantization and type compatibility. Changes include adding TensorRT build verification capabilities, fixing FP16/FP32 type mismatches in ONNX graphs, improving FP8 quantization detection, and adding support for the efficientvit_l2 vision model across the quantization and export pipeline.

Changes

Cohort / File(s) and Summary

  • Model Support Documentation (examples/torch_onnx/README.md, tests/_test_utils/torch/vision_models.py):
    Added efficientvit_l2 model to the supported ONNX export table and the test vision models list, with FP8/INT8/MXFP8/NVFP4 quantization support.
  • ONNX Export Script & Integration (examples/torch_onnx/torch_quant_to_onnx.py):
    Introduced TRT engine build verification via --trt_build and --trt_builder_opt_level CLI flags; added subprocess-based trtexec invocation with error detection. Updated get_quant_config() to enforce algorithm="max" for MXFP8/NVFP4, expanded quantization filters to exclude global_pool, and added _disable_conv2d_dynamic_quantizers() to handle Conv2d quantizer block sizes. Changed non-auto quantization to always load calibration data.
  • ONNX Utilities (modelopt/onnx/utils.py):
    Added new fix_fp16_fp32_mismatches() utility to propagate and correct tensor types through ONNX graphs, detecting and fixing mixed FP32/FP16 inputs to elementwise operations via inserted Cast nodes.
  • Quantization Export Logic (modelopt/torch/_deploy/utils/torch_onnx.py, modelopt/torch/quantization/export_onnx.py):
    Refactored FP8 quantization detection with new _is_fp8_quantizer() helper; broadened is_fp8_quantized() to recognize any enabled FP8-configured quantizer. Updated ONNX export to disable Conv2d weight quantizers during export, added autocast logic for FP8 models, and integrated fix_fp16_fp32_mismatches() post-processing for INT4/MXFP8/FP8 quantized models. Enhanced configure_linear_module_onnx_quantizers() to handle block-quantized non-Linear modules.
  • Test Updates (tests/examples/torch_onnx/test_torch_quant_to_onnx.py):
    Removed local TRT verification helper; added efficientvit_l2 to the test matrix with auto-quantization exclusion via _AUTO_EXCLUDED_MODELS. Updated the test harness to invoke the --trt_build flag and pass trt_builder_opt_level computed from bit-width modes to the quantization script.
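The subprocess-based trtexec invocation described above can be sketched as follows. The exact flags the script passes are not shown in this summary, so the arguments here are illustrative (though the trtexec flag spellings themselves are real):

```python
import subprocess

def trtexec_command(onnx_path, engine_path, builder_opt_level=3):
    # Structured argv with no shell=True, matching the security note in
    # the pre-merge checks: argument lists prevent command injection.
    return [
        "trtexec",
        f"--onnx={onnx_path}",
        f"--saveEngine={engine_path}",
        f"--builderOptimizationLevel={builder_opt_level}",
        "--stronglyTyped",
    ]

def build_trt_engine(onnx_path, engine_path, builder_opt_level=3):
    """Run trtexec and raise on a non-zero exit (simplified error detection)."""
    cmd = trtexec_command(onnx_path, engine_path, builder_opt_level)
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"TRT engine build failed: {' '.join(cmd)}")
    return engine_path

print(trtexec_command("model.onnx", "model.engine", 4))
```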

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 4

✅ Passed checks (4 passed)

  • Description Check ✅ Passed: Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title 'Add EfficientViT support for torch_onnx quantization workflow' directly and clearly summarizes the main change: adding support for the efficientvit_l2 model through the quantization and ONNX export pipeline.
  • Docstring Coverage ✅ Passed: Docstring coverage is 87.50%, which is sufficient; the required threshold is 80.00%.
  • Security Anti-Patterns ✅ Passed: PR contains no critical security anti-patterns. Subprocess usage correctly employs subprocess.run() with a structured argument list (no shell=True), preventing command injection vulnerabilities.



@github-actions

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1254/

Built to branch gh-pages at 2026-04-14 07:06 UTC.
Preview will be ready when the GitHub Pages deployment is complete.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
modelopt/torch/_deploy/utils/torch_onnx.py (1)

572-605: ⚠️ Potential issue | 🟠 Major

Re-enable Conv2d quantizers in a finally block.

If torch.onnx.export() fails, Lines 603-604 never run and the caller gets back a mutated model with FP8 Conv2d weight quantizers still disabled.

🔒 Suggested fix
-    with torch.inference_mode(), autocast, quantizer_context:
-        additional_kwargs = {}
-        if not dynamo_export:
-            additional_kwargs["dynamic_axes"] = dynamic_axes
-        torch.onnx.export(
-            model,
-            dummy_input,
-            onnx_save_path,
-            input_names=input_names,
-            output_names=output_names,
-            opset_version=onnx_opset,
-            dynamo=dynamo_export,
-            **additional_kwargs,
-        )
-
-    # Re-enable Conv2d quantizers that were temporarily disabled for FP8 export
-    for module, qname in conv_quantizers_to_reenable:
-        getattr(module, qname).enable()
+    try:
+        with torch.inference_mode(), autocast, quantizer_context:
+            additional_kwargs = {}
+            if not dynamo_export:
+                additional_kwargs["dynamic_axes"] = dynamic_axes
+            torch.onnx.export(
+                model,
+                dummy_input,
+                onnx_save_path,
+                input_names=input_names,
+                output_names=output_names,
+                opset_version=onnx_opset,
+                dynamo=dynamo_export,
+                **additional_kwargs,
+            )
+    finally:
+        for module, qname in conv_quantizers_to_reenable:
+            getattr(module, qname).enable()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/_deploy/utils/torch_onnx.py` around lines 572 - 605, The code
temporarily disables Conv2d weight quantizers (see conv_quantizers_to_reenable
and _is_fp8_quantizer) but re-enables them only after torch.onnx.export(), so if
export raises the model remains mutated; wrap the export and the re-enable loop
in a try/finally (or move the re-enable into a finally block) so that for every
(module, qname) collected you call getattr(module, qname).enable() in the
finally regardless of exceptions, ensuring conv_quantizers_to_reenable is always
restored even when torch.onnx.export() fails.
tests/examples/torch_onnx/test_torch_quant_to_onnx.py (1)

40-61: ⚠️ Potential issue | 🟠 Major

This matrix is too heavy for tests/examples/.

Line 60 now enables a full TensorRT build for every model/mode pair, and the added EfficientViT case expands that to 19 example runs. That matches the PR’s ~25 minute regression time, which is well past the expected budget for this test directory. Please keep a small smoke matrix here and move the exhaustive TRT-build coverage to a slower/nightly path.

As per coding guidelines, tests/examples/**/*.py: Integration tests in tests/examples/ should not take more than a few minutes to run.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/examples/torch_onnx/test_torch_quant_to_onnx.py` around lines 40 - 61,
The test currently runs a full TensorRT build for every combination in
test_torch_onnx (function test_torch_onnx using _MODELS, _QUANT_MODES,
_LOW_BIT_MODES, _AUTO_EXCLUDED_MODELS and run_example_command) which makes the
example tests too slow; restrict the matrix by either (A) limiting the
parametrization to a small smoke set (e.g., define a small SMOKE_MODELS list and
use that instead of _MODELS) or (B) gate the TensorRT build so it only runs in
slow/nightly runs (e.g., add a check around appending "--trt_build" that only
sets it when an env var like RUN_TRT or pytest marker e.g.,
pytestconfig.getoption("--run-trt")/os.environ["RUN_TRT"] is true), and keep the
default test path fast by leaving calibration_data_size/num_score_steps small;
update test_torch_onnx accordingly and ensure run_example_command is invoked for
the reduced set or gated TRT flag so regular example tests finish quickly.

📥 Commits

Reviewing files that changed from the base of the PR and between b6c6ec3 and eaa16a6.

📒 Files selected for processing (7)
  • examples/torch_onnx/README.md
  • examples/torch_onnx/torch_quant_to_onnx.py
  • modelopt/onnx/utils.py
  • modelopt/torch/_deploy/utils/torch_onnx.py
  • modelopt/torch/quantization/export_onnx.py
  • tests/_test_utils/torch/vision_models.py
  • tests/examples/torch_onnx/test_torch_quant_to_onnx.py

Comment on lines +417 to +427
        # Always load calibration data. Even though MXFP8 uses dynamic quantization
        # and doesn't strictly require calibration, the Conv2d FP8 overrides (applied
        # by get_quant_config for MXFP8/NVFP4) use static FP8 quantization which
        # needs calibration data to compute amax values.
        data_loader = load_calibration_data(
            args.timm_model_name,
            args.calibration_data_size,
            args.batch_size,
            device,
            with_labels=False,
        )

⚠️ Potential issue | 🟠 Major

--no_pretrained is no longer honored in the standard path.

Because Line 421 now always calls load_calibration_data(), the non-auto flow will still instantiate timm.create_model(..., pretrained=True) from Line 135 even when --no_pretrained is set. That turns a local/offline smoke run into a networked weights fetch.

💡 Suggested fix
-def load_calibration_data(model_name, data_size, batch_size, device, with_labels=False):
+def load_calibration_data(
+    model_name, data_size, batch_size, device, with_labels=False, pretrained=True
+):
@@
-    model = timm.create_model(model_name, pretrained=True, num_classes=1000)
+    model = timm.create_model(model_name, pretrained=pretrained, num_classes=1000)
         data_loader = load_calibration_data(
             args.timm_model_name,
             args.calibration_data_size,
             args.batch_size,
             device,
             with_labels=False,
+            pretrained=not args.no_pretrained,
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/torch_onnx/torch_quant_to_onnx.py` around lines 417 - 427, The code
always calls load_calibration_data which causes timm.create_model(...,
pretrained=True) to be instantiated earlier and ignores --no_pretrained; fix by
guarding the calibration load so it only runs when pretrained weights are
allowed: change the call site of load_calibration_data to run only if not
args.no_pretrained (or if the auto calibration flow explicitly requires
pretrained weights), e.g., wrap the existing load_calibration_data invocation in
a conditional that checks args.no_pretrained (and any auto-mode flag you use) so
timm.create_model is not triggered when --no_pretrained is set.

Comment on lines +1528 to +1578
    # Ops whose data inputs must all have the same type in TRT stronglyTyped mode.
    _ELEMENTWISE_OPS = {
        "Add", "Sub", "Mul", "Div", "Pow", "Min", "Max", "Equal", "Less",
        "Greater", "Where", "Sum", "Mean", "Concat",
    }

    # Ops that are FP32-only (QDQ) — never cast their I/O.
    _BLOCKED_OPS = {"QuantizeLinear", "DequantizeLinear"}

    # --- Step 1: Propagate real element types through the graph. ---
    real_type: dict[str, int] = {}

    # Seed from graph inputs and initializers (these are authoritative).
    for inp in model.graph.input:
        real_type[inp.name] = inp.type.tensor_type.elem_type
    for init in model.graph.initializer:
        real_type[init.name] = init.data_type

    # Process nodes in topological order.
    for node in model.graph.node:
        if node.op_type == "Constant":
            for attr in node.attribute:
                if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR:
                    for out in node.output:
                        real_type[out] = attr.t.data_type
            continue

        if node.op_type == "Cast":
            cast_to = get_cast_to_type(node)
            for out in node.output:
                real_type[out] = cast_to
            continue

        if node.op_type in _BLOCKED_OPS:
            for out in node.output:
                real_type[out] = FLOAT
            continue

        # For other ops: output type matches the predominant data-input type.
        data_types = []
        for inp_name in node.input:
            if inp_name and inp_name in real_type and real_type[inp_name] in (FLOAT, FLOAT16):
                data_types.append(real_type[inp_name])

        if data_types:
            out_type = FLOAT if FLOAT in data_types else FLOAT16
        else:
            out_type = FLOAT16

        for out in node.output:
            real_type[out] = out_type

⚠️ Potential issue | 🟠 Major


Special-case comparison outputs as BOOL.

The type propagation in Step 1 treats Equal, Less, and Greater outputs using the default fallback (predominant input type), but ONNX comparison operators always output tensor(bool). This causes their outputs to be incorrectly labeled as FLOAT or FLOAT16, which mislabels boolean masks and can cause downstream Where conditions to receive an invalid cast insertion.

🐛 Suggested fix
     for node in model.graph.node:
         if node.op_type == "Constant":
@@
         if node.op_type in _BLOCKED_OPS:
             for out in node.output:
                 real_type[out] = FLOAT
             continue
+
+        if node.op_type in {"Equal", "Less", "Greater"}:
+            for out in node.output:
+                real_type[out] = onnx.TensorProto.BOOL
+            continue
 
         # For other ops: output type matches the predominant data-input type.
         data_types = []
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/utils.py` around lines 1528 - 1578, The propagation code
incorrectly treats comparison ops ("Equal", "Less", "Greater") like numeric
elementwise ops and assigns FLOAT/FLOAT16; update the Step 1 propagation so that
when node.op_type is one of "Equal", "Less", "Greater" you set real_type[out] =
BOOL for each node.output (similar to how "Cast" and _BLOCKED_OPS are handled)
to ensure comparison outputs are recorded as boolean tensors and downstream
nodes like "Where" receive correct mask types.

Comment on lines 657 to 675
@contextlib.contextmanager
def configure_linear_module_onnx_quantizers(model):
    """Sets the onnx export attributes for the given model.

    For Linear modules, sets both input and weight quantizer types.
    For other modules with block-quantized input_quantizers (e.g., pooling layers
    in models like EfficientViT), sets the input quantizer to "dynamic" to prevent
    TRT_FP4QDQ static export for activations.
    """
    for _, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            module.input_quantizer._onnx_quantizer_type = "dynamic"
            module.weight_quantizer._onnx_quantizer_type = "static"
        elif (
            hasattr(module, "input_quantizer")
            and getattr(module.input_quantizer, "block_sizes", None)
        ):
            module.input_quantizer._onnx_quantizer_type = "dynamic"
    yield

⚠️ Potential issue | 🟠 Major

Restore _onnx_quantizer_type on exit.

Lines 670-674 now mutate more quantizers, but the context manager still never puts their previous values back. _onnx_quantizer_type is read during quantizer forward, so this leaks export-only behavior into later inference/export passes.

♻️ Suggested fix
 `@contextlib.contextmanager`
 def configure_linear_module_onnx_quantizers(model):
+    previous_values = []
     """Sets the onnx export attributes for the given model.
@@
     for _, module in model.named_modules():
         if isinstance(module, torch.nn.Linear):
+            previous_values.extend(
+                [
+                    (
+                        module.input_quantizer,
+                        getattr(module.input_quantizer, "_onnx_quantizer_type", None),
+                        hasattr(module.input_quantizer, "_onnx_quantizer_type"),
+                    ),
+                    (
+                        module.weight_quantizer,
+                        getattr(module.weight_quantizer, "_onnx_quantizer_type", None),
+                        hasattr(module.weight_quantizer, "_onnx_quantizer_type"),
+                    ),
+                ]
+            )
             module.input_quantizer._onnx_quantizer_type = "dynamic"
             module.weight_quantizer._onnx_quantizer_type = "static"
         elif (
             hasattr(module, "input_quantizer")
             and getattr(module.input_quantizer, "block_sizes", None)
         ):
+            previous_values.append(
+                (
+                    module.input_quantizer,
+                    getattr(module.input_quantizer, "_onnx_quantizer_type", None),
+                    hasattr(module.input_quantizer, "_onnx_quantizer_type"),
+                )
+            )
             module.input_quantizer._onnx_quantizer_type = "dynamic"
-    yield
+    try:
+        yield
+    finally:
+        for quantizer, old_value, existed in previous_values:
+            if existed:
+                quantizer._onnx_quantizer_type = old_value
+            elif hasattr(quantizer, "_onnx_quantizer_type"):
+                delattr(quantizer, "_onnx_quantizer_type")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/export_onnx.py` around lines 657 - 675, The
context manager configure_linear_module_onnx_quantizers mutates
module.input_quantizer._onnx_quantizer_type and
module.weight_quantizer._onnx_quantizer_type (and other modules'
input_quantizer._onnx_quantizer_type) but does not restore prior values on exit;
modify it to record the previous _onnx_quantizer_type for each quantizer you
change (use the module and quantizer identity to store originals) and restore
those values in the finally/after-yield block so the original quantizer state is
preserved outside the context manager; ensure you handle missing attributes
safely and restore both input_quantizer and weight_quantizer entries for
instances found in configure_linear_module_onnx_quantizers.

@codecov

codecov bot commented Apr 14, 2026

Codecov Report

❌ Patch coverage is 98.82353% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 77.50%. Comparing base (202c3d3) to head (eaa16a6).
⚠️ Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/onnx/utils.py 98.46% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1254      +/-   ##
==========================================
+ Coverage   76.90%   77.50%   +0.59%     
==========================================
  Files         350      350              
  Lines       40524    40610      +86     
==========================================
+ Hits        31166    31473     +307     
+ Misses       9358     9137     -221     
Flag Coverage Δ
examples 43.93% <98.82%> (+1.30%) ⬆️
gpu 57.42% <9.41%> (-0.11%) ⬇️
unit 55.51% <9.41%> (-0.09%) ⬇️

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.