Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/torch_onnx/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ python torch_quant_to_onnx.py \
| [vit_base_patch16_224](https://huggingface.co/timm/vit_base_patch16_224.augreg_in21k_ft_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [swin_tiny_patch4_window7_224](https://huggingface.co/timm/swin_tiny_patch4_window7_224.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [swinv2_tiny_window8_256](https://huggingface.co/timm/swinv2_tiny_window8_256.ms_in1k) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [resnet50](https://huggingface.co/timm/resnet50.a1_in1k) | ✅ | ✅ | ✅ | ✅ | | ✅ |

## Resources

Expand Down
139 changes: 122 additions & 17 deletions examples/torch_onnx/torch_quant_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import copy
import json
import re
import subprocess
import sys
import warnings
from pathlib import Path
Expand All @@ -35,13 +36,17 @@
import modelopt.torch.quantization as mtq

"""
This script is used to quantize a timm model using dynamic quantization like MXFP8 or NVFP4,
or using auto quantization for optimal per-layer quantization.
Quantize a timm vision model and export to ONNX for TensorRT deployment.

Supports FP8, INT8, MXFP8, NVFP4, INT4_AWQ, and AUTO (mixed-precision) quantization modes.

The script will:
1. Given the model name, create a timm torch model.
2. Quantize the torch model in MXFP8, NVFP4, INT4_AWQ, or AUTO mode.
3. Export the quantized torch model to ONNX format.
1. Load a pretrained timm model (e.g., ViT, Swin, ResNet).
2. Quantize the model using the specified mode. For models with Conv2d layers,
Conv2d quantization is automatically overridden for TensorRT compatibility
(FP8 for MXFP8/NVFP4, INT8 for INT4_AWQ).
Comment on lines +39 to +47
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Don’t advertise INT4_AWQ as supported end-to-end here.

The PR objectives still call out INT4_AWQ as a known limitation, but this docstring now groups it with the working modes. Please caveat or remove it here so users do not assume this example is expected to succeed in that mode.

✏️ Suggested wording
-Supports FP8, INT8, MXFP8, NVFP4, INT4_AWQ, and AUTO (mixed-precision) quantization modes.
+Supports FP8, INT8, MXFP8, NVFP4, and AUTO (mixed-precision) quantization modes.
+`INT4_AWQ` remains a known limitation for this example.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/torch_onnx/torch_quant_to_onnx.py` around lines 38 - 46, The
docstring in torch_quant_to_onnx.py incorrectly groups INT4_AWQ with fully
supported quantization modes; update the top-level description to either remove
INT4_AWQ from the supported list or add a clear caveat that INT4_AWQ is a known
limitation and may not work end-to-end (e.g., "INT4_AWQ is experimental/limited
— see PR objectives for current limitations"), ensuring references to the script
name and the quantization modes (FP8, INT8, MXFP8, NVFP4, INT4_AWQ, AUTO) are
adjusted so users won't assume INT4_AWQ is fully supported.

3. Export the quantized model to ONNX with FP16 weights.
4. Optionally evaluate accuracy on ImageNet-1k before and after quantization.
"""


Expand Down Expand Up @@ -109,7 +114,8 @@ def filter_func(name):
"""Filter function to exclude certain layers from quantization."""
pattern = re.compile(
r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|"
r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|downsample).*"
r"pos_embed|time_text_embed|context_embedder|norm_out|x_embedder|patch_embed|cpb_mlp|"
r"downsample|maxpool|global_pool).*"
)
return pattern.match(name) is not None

Expand Down Expand Up @@ -147,6 +153,40 @@ def load_calibration_data(model_name, data_size, batch_size, device, with_labels
)


def _calibrate_uncalibrated_quantizers(model, data_loader):
"""Calibrate FP8 quantizers that weren't calibrated by mtq.quantize().

When MXFP8/NVFP4 modes override Conv2d to FP8, the FP8 quantizers may not
be calibrated because the MXFP8/NVFP4 quantization pipeline skips standard
calibration. This function explicitly calibrates those uncalibrated quantizers.
"""
uncalibrated = []
for _, module in model.named_modules():
for attr_name in ("input_quantizer", "weight_quantizer"):
if not hasattr(module, attr_name):
continue
quantizer = getattr(module, attr_name)
if (
quantizer.is_enabled
and not quantizer.block_sizes
and not hasattr(quantizer, "_amax")
):
quantizer.enable_calib()
uncalibrated.append(quantizer)

if not uncalibrated:
return

model.eval()
with torch.no_grad():
for batch in data_loader:
model(batch)

for quantizer in uncalibrated:
quantizer.disable_calib()
quantizer.load_calib_amax()


def quantize_model(model, config, data_loader=None):
"""Quantize the model using the given config and calibration data."""
if data_loader is not None:
Expand All @@ -159,6 +199,10 @@ def forward_loop(model):
else:
quantized_model = mtq.quantize(model, config)

# Calibrate any FP8 override quantizers that weren't calibrated by mtq.quantize()
if data_loader is not None:
_calibrate_uncalibrated_quantizers(quantized_model, data_loader)

mtq.disable_quantizer(quantized_model, filter_func)
return quantized_model

Expand All @@ -185,6 +229,38 @@ def _disable_inplace_relu(model):
module.inplace = False


def _override_conv2d_to_fp8(model, data_loader):
"""Override Conv2d layers with NVFP4/MXFP8 block quantization to FP8.

TRT DynamicQuantize requires 2D/3D input, but Conv2d operates on 4D tensors.
This overrides Conv2d block quantizers to FP8 per-tensor and calibrates them.
"""
overridden = []
for _, module in model.named_modules():
if not isinstance(module, torch.nn.Conv2d):
continue
for attr_name in ("input_quantizer", "weight_quantizer"):
if not hasattr(module, attr_name):
continue
quantizer = getattr(module, attr_name)
if quantizer.is_enabled and quantizer.block_sizes:
# Override to FP8 per-tensor
quantizer.block_sizes = None
quantizer._num_bits = (4, 3)
quantizer._axis = None
Comment on lines +248 to +250
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if there's a public API for setting quantizer num_bits/axis
rg -n "def.*num_bits|def.*axis" --type=py modelopt/torch/quantization/

Repository: NVIDIA/Model-Optimizer

Length of output: 3259


🏁 Script executed:

#!/bin/bash
# Get the file and examine the problematic lines with context
cat -n examples/torch_onnx/torch_quant_to_onnx.py | sed -n '235,260p'

Repository: NVIDIA/Model-Optimizer

Length of output: 1289


🏁 Script executed:

#!/bin/bash
# Examine the public num_bits and axis property implementations
cat -n modelopt/torch/quantization/nn/modules/tensor_quantizer.py | sed -n '272,285p'
cat -n modelopt/torch/quantization/nn/modules/tensor_quantizer.py | sed -n '372,385p'

Repository: NVIDIA/Model-Optimizer

Length of output: 967


Use public property setters for num_bits and axis to maintain calibrator synchronization.

At lines 249-250, directly assigning to _num_bits and _axis bypasses the property setters, which synchronize these values with the calibrator's internal state. This creates an inconsistency where the quantizer and calibrator have mismatched configuration.

Change:

quantizer._num_bits = (4, 3)
quantizer._axis = None

To:

quantizer.num_bits = (4, 3)
quantizer.axis = None

Line 248 correctly uses the public block_sizes property; apply the same pattern here.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/torch_onnx/torch_quant_to_onnx.py` around lines 248 - 250, Replace
the direct internal assignments to quantizer._num_bits and quantizer._axis with
the public property setters so the calibrator stays synchronized: use
quantizer.num_bits = (4, 3) and quantizer.axis = None instead of writing to
_num_bits and _axis (keeping quantizer.block_sizes assignment as-is); update
references in the same snippet to call the num_bits and axis properties on the
quantizer object to ensure proper calibrator/internal state updates.

quantizer.enable_calib()
overridden.append(quantizer)

if overridden:
model.eval()
with torch.no_grad():
for batch in data_loader:
model(batch["image"])
for quantizer in overridden:
quantizer.disable_calib()
quantizer.load_calib_amax()
Comment on lines +232 to +261
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Apply the Conv2d TRT restriction before the auto-quant search.

mtq.auto_quantize() optimizes under the requested effective_bits budget and returns a search_state for the formats it actually evaluated. Rewriting any block-quantized Conv2d to FP8 afterwards means the exported model no longer matches that search space: Conv layers selected as NVFP4/MXFP8/INT4 now pay FP8 cost instead, so the final model can overshoot the target budget and the returned search_state becomes stale. Please push the Conv2d compatibility rule into the candidate configs before calling mtq.auto_quantize(), or recompute/validate the post-override budget before returning.

Also applies to: 307-309

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/torch_onnx/torch_quant_to_onnx.py` around lines 227 - 256, The
override of Conv2d block quantizers to FP8 in _override_conv2d_to_fp8 must not
be applied after mtq.auto_quantize() because mtq.auto_quantize() returns a
search_state based on the candidate configs; either incorporate the
Conv2d->FP8/disable-block-quantization rule into the candidate generation passed
into mtq.auto_quantize() (so those Conv2d layers are treated as FP8 during the
search) or, if you must keep the override path, recompute and validate the
effective_bits/budget and update/refresh the returned search_state after running
_override_conv2d_to_fp8 so the final model’s budget reflects FP8 costs for
Conv2d (reference symbols: _override_conv2d_to_fp8, mtq.auto_quantize,
search_state).



def auto_quantize_model(
model,
data_loader,
Expand Down Expand Up @@ -233,6 +309,10 @@ def auto_quantize_model(
verbose=True,
)

# Override Conv2d layers that got NVFP4/MXFP8 to FP8 for TRT compatibility.
# TRT DynamicQuantize requires 2D/3D input, but Conv2d operates on 4D tensors.
_override_conv2d_to_fp8(quantized_model, data_loader)

# Disable quantization for specified layers
mtq.disable_quantizer(quantized_model, filter_func)

Expand Down Expand Up @@ -320,6 +400,11 @@ def main():
default=128,
help="Number of scoring steps for auto quantization. Default is 128.",
)
parser.add_argument(
"--trt_build",
action="store_true",
help="Build a TensorRT engine from the exported ONNX model using trtexec.",
)
parser.add_argument(
"--no_pretrained",
action="store_true",
Expand Down Expand Up @@ -378,18 +463,18 @@ def main():
args.num_score_steps,
)
else:
# Standard quantization - only load calibration data if needed
# Standard quantization - load calibration data
# Note: MXFP8 is dynamic and does not need calibration itself, but when
# Conv2d layers are overridden to FP8 (for TRT compatibility), those FP8
# quantizers require calibration data.
config = get_quant_config(args.quantize_mode)
if args.quantize_mode == "mxfp8":
data_loader = None
else:
data_loader = load_calibration_data(
args.timm_model_name,
args.calibration_data_size,
args.batch_size,
device,
with_labels=False,
)
data_loader = load_calibration_data(
args.timm_model_name,
args.calibration_data_size,
args.batch_size,
device,
with_labels=False,
)

quantized_model = quantize_model(model, config, data_loader)

Expand Down Expand Up @@ -421,6 +506,26 @@ def main():

print(f"Quantized ONNX model is saved to {args.onnx_save_path}")

if args.trt_build:
build_trt_engine(args.onnx_save_path)


def build_trt_engine(onnx_path, timeout=600):
    """Build a TensorRT engine from the exported ONNX model using trtexec.

    Args:
        onnx_path: Path to the exported ONNX model.
        timeout: Maximum number of seconds to wait for trtexec (default 600).

    Raises:
        RuntimeError: If trtexec is not installed, the build times out, or
            trtexec exits with a non-zero return code.
    """
    cmd = [
        "trtexec",
        f"--onnx={onnx_path}",
        "--stronglyTyped",
        "--builderOptimizationLevel=4",
    ]
    print(f"\nBuilding TensorRT engine: {' '.join(cmd)}")
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
    except FileNotFoundError as err:
        # trtexec ships with TensorRT; give a clear hint instead of a bare traceback.
        raise RuntimeError(
            "trtexec not found in PATH; install TensorRT or add trtexec to PATH."
        ) from err
    except subprocess.TimeoutExpired as err:
        raise RuntimeError(
            f"TensorRT engine build timed out after {timeout}s for {onnx_path}."
        ) from err
    if result.returncode != 0:
        raise RuntimeError(
            f"TensorRT engine build failed for {onnx_path}:\n{result.stdout}\n{result.stderr}"
        )
    print("TensorRT engine build succeeded.")


if __name__ == "__main__":
main()
44 changes: 41 additions & 3 deletions modelopt/torch/_deploy/utils/torch_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""Utility functions related to Onnx."""

import base64
import contextlib
import inspect
import json
import logging
Expand Down Expand Up @@ -402,6 +403,29 @@ def is_fp8_quantized(model: nn.Module) -> bool:
return False


@contextlib.contextmanager
def _disable_fp8_conv_weight_quantizers(model: nn.Module):
"""Temporarily disable FP8 weight quantizers on Conv layers during ONNX export.

The TorchScript ONNX exporter requires static kernel shapes for Conv operations,
but FP8 weight quantization (TRT_FP8QuantizeLinear -> TRT_FP8DequantizeLinear)
produces dynamic-shape outputs that break this requirement. Disabling Conv weight
quantizers during export allows the Conv to export with static-shape FP16/FP32
weights. Conv activations still have FP8 QDQ nodes (input quantizers remain enabled).
"""
disabled = []
for _, module in model.named_modules():
if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
if hasattr(module, "weight_quantizer") and module.weight_quantizer.is_enabled:
module.weight_quantizer.disable()
disabled.append(module)
try:
yield
finally:
for module in disabled:
module.weight_quantizer.enable()


def quantize_weights(model: nn.Module, onnx_model: onnx.ModelProto) -> onnx.ModelProto:
"""Real quantizes the weights in the onnx model.

Expand Down Expand Up @@ -522,7 +546,11 @@ def get_onnx_bytes_and_metadata(
input_none_names = list(set(tree_spec_input.names) - set(input_names))

use_torch_autocast = not (
is_fp4_quantized(model) or is_mxfp8_quantized(model) or weights_dtype == "fp32"
is_fp4_quantized(model)
or is_mxfp8_quantized(model)
or is_fp8_quantized(model)
or is_int8_quantized(model)
or weights_dtype == "fp32"
)
autocast = torch.autocast("cuda") if use_torch_autocast else nullcontext()

Expand Down Expand Up @@ -556,7 +584,12 @@ def get_onnx_bytes_and_metadata(
if is_fp4_quantized(model) or is_mxfp8_quantized(model)
else nullcontext()
)
with torch.inference_mode(), autocast, quantizer_context:
# Disable FP8 Conv weight quantizers: TorchScript ONNX exporter requires static
# kernel shapes, but FP8 DequantizeLinear produces dynamic shapes.
conv_wq_context = (
_disable_fp8_conv_weight_quantizers(model) if is_fp8_quantized(model) else nullcontext()
)
with torch.inference_mode(), autocast, quantizer_context, conv_wq_context:
additional_kwargs = {}
if not dynamo_export:
additional_kwargs["dynamic_axes"] = dynamic_axes
Expand Down Expand Up @@ -598,7 +631,12 @@ def get_onnx_bytes_and_metadata(
onnx_opt_graph = qdq_to_dq(onnx_opt_graph)

if weights_dtype in ["fp16", "bf16"]:
if is_int4_quantized(model) or is_mxfp8_quantized(model) or is_fp8_quantized(model):
if (
is_int4_quantized(model)
or is_mxfp8_quantized(model)
or is_fp8_quantized(model)
or is_int8_quantized(model)
):
assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet"
onnx_opt_graph = convert_float_to_float16(
onnx_opt_graph,
Expand Down
15 changes: 13 additions & 2 deletions modelopt/torch/quantization/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,9 +656,20 @@ def export_fp4(

@contextlib.contextmanager
def configure_linear_module_onnx_quantizers(model):
    """Set the ONNX export attributes on all block-quantized modules.

    For modules with block quantization (NVFP4/MXFP8):
    - Weight quantizers use "static" export (TRT_FP4QDQ for NVFP4, DQ-only
      for MXFP8).
    - Input/activation quantizers use "dynamic" export
      (TRT_FP4DynamicQuantize, etc.).

    This is applied to ALL modules with block quantization, not just
    nn.Linear, because models like ResNet have non-Linear modules (e.g.
    MaxPool2d) with NVFP4/MXFP8 input quantizers that would otherwise default
    to the static path and produce TRT_FP4QDQ nodes on activations (which the
    NVFP4 exporter cannot handle).

    NOTE(review): the export attributes are not restored after the ``yield``;
    presumably the model is export-only at this point — confirm if reuse after
    export is expected.
    """
    # (attribute name, export type) pairs applied to every module.
    export_modes = (("input_quantizer", "dynamic"), ("weight_quantizer", "static"))
    for _, module in model.named_modules():
        for attr_name, onnx_type in export_modes:
            if not hasattr(module, attr_name):
                continue
            quantizer = getattr(module, attr_name)
            if quantizer.block_sizes:
                quantizer._onnx_quantizer_type = onnx_type
    yield
1 change: 1 addition & 0 deletions tests/_test_utils/torch/vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def get_model_and_input(on_gpu: bool = False):
# "dm_nfnet_f0",
"efficientnet_b0",
"swin_tiny_patch4_window7_224",
"resnet50",
],
_create_timm_fn,
),
Expand Down
36 changes: 2 additions & 34 deletions tests/examples/torch_onnx/test_torch_quant_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
# limitations under the License.


import os
import subprocess

import pytest
from _test_utils.examples.run_command import extend_cmd_parts, run_example_command

Expand All @@ -28,42 +25,16 @@
"vit_tiny": ("vit_tiny_patch16_224", '{"depth": 1}'),
"swin_tiny": ("swin_tiny_patch4_window7_224", '{"depths": [1, 1, 1, 1]}'),
"swinv2_tiny": ("swinv2_tiny_window8_256", '{"depths": [1, 1, 1, 1]}'),
"resnet50": ("resnet50", None),
}

# Builder optimization level: 4 for low-bit modes, 3 otherwise
_LOW_BIT_MODES = {"fp8", "int8", "nvfp4"}


def _verify_trt_engine_build(onnx_save_path, quantize_mode):
"""Verify the exported ONNX model can be compiled into a TensorRT engine."""
example_dir = os.path.join(
os.path.dirname(__file__), "..", "..", "..", "examples", "torch_onnx"
)
onnx_path = os.path.join(example_dir, onnx_save_path)
assert os.path.exists(onnx_path), f"ONNX file not found: {onnx_path}"

opt_level = "4" if quantize_mode in _LOW_BIT_MODES else "3"
cmd = [
"trtexec",
f"--onnx={onnx_path}",
"--stronglyTyped",
f"--builderOptimizationLevel={opt_level}",
]

result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
assert result.returncode == 0, (
f"TensorRT engine build failed for {onnx_save_path} "
f"(mode={quantize_mode}):\n{result.stdout}\n{result.stderr}"
)


@pytest.mark.parametrize("quantize_mode", _QUANT_MODES)
@pytest.mark.parametrize("model_key", list(_MODELS))
def test_torch_onnx(model_key, quantize_mode):
timm_model_name, model_kwargs = _MODELS[model_key]
onnx_save_path = f"{model_key}.{quantize_mode}.onnx"

# Step 1: Quantize and export to ONNX
cmd_parts = extend_cmd_parts(
["python", "torch_quant_to_onnx.py"],
timm_model_name=timm_model_name,
Expand All @@ -73,8 +44,5 @@ def test_torch_onnx(model_key, quantize_mode):
calibration_data_size="1",
num_score_steps="1",
)
cmd_parts.append("--no_pretrained")
cmd_parts.extend(["--no_pretrained", "--trt_build"])
run_example_command(cmd_parts, "torch_onnx")

# Step 2: Verify the exported ONNX model builds a TensorRT engine
_verify_trt_engine_build(onnx_save_path, quantize_mode)
Loading