
Commit e55992c

galagam and gcunhase authored and committed
[5525939] Allow user to select target opset in MOQ (#809)
## What does this PR do?

**Type of change:** new feature

**Overview:**
- Allow user to select the target opset
- Minimum opset will be defined according to quantization mode
- Add tests in tests/unit/onnx/test_quantize_api.py

## Testing

Added unit tests:

```
tests/unit/onnx/test_quantize_api.py::test_opset_below_minimum_upgrades_to_minimum[int8] PASSED [ 11%]
tests/unit/onnx/test_quantize_api.py::test_opset_below_minimum_upgrades_to_minimum[fp8] PASSED [ 22%]
tests/unit/onnx/test_quantize_api.py::test_opset_below_minimum_upgrades_to_minimum[int4] PASSED [ 33%]
tests/unit/onnx/test_quantize_api.py::test_opset_below_original_uses_original[int8] PASSED [ 44%]
tests/unit/onnx/test_quantize_api.py::test_opset_below_original_uses_original[fp8] PASSED [ 55%]
tests/unit/onnx/test_quantize_api.py::test_opset_below_original_uses_original[int4] PASSED [ 66%]
tests/unit/onnx/test_quantize_api.py::test_opset_above_minimum[int8] PASSED [ 77%]
tests/unit/onnx/test_quantize_api.py::test_opset_above_minimum[fp8] PASSED [ 88%]
tests/unit/onnx/test_quantize_api.py::test_opset_above_minimum[int4] PASSED [100%]
```

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: Yes
- **Did you add or update any necessary documentation?**: Yes, auto-updated according to the argparser help
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes

## Additional Information

Requested as a WAR for a Windows-onnxruntime issue in 5525939; regardless, it's a useful feature to have.

## Summary by CodeRabbit

* **New Features**
  * Added `--opset` CLI option enabling users to specify the target ONNX opset version when quantizing models.
  * Automatic validation ensures the opset version is compatible with quantization requirements, with warnings when adjustments are made.
* **Tests**
  * Added comprehensive test coverage for opset version handling across quantization workflows.

Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Signed-off-by: Gal Hubara-Agam <96368689+galagam@users.noreply.github.com>
Co-authored-by: Gwena Cunha <4861122+gcunhase@users.noreply.github.com>
1 parent 241f5b7 commit e55992c

File tree

8 files changed: +279 −17 lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
```diff
@@ -14,6 +14,7 @@ NVIDIA Model Optimizer Changelog (Linux)
 - Add support for Kimi K2 Thinking model quantization from the original int4 checkpoint.
 - Add support for ``params`` constraint based automatic neural architecture search in Minitron pruning (``mcore_minitron``) as an alternative to manual pruning (using ``export_config``). See `examples/pruning/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/pruning>`_ for more details on its usage.
 - Add support for calibration data with multiple samples in ``npz`` format in the ONNX Autocast workflow.
+- Add ``--opset`` option to ONNX quantization CLI to specify the target opset version for the quantized model.
 
 0.41 (2026-01-19)
 ^^^^^^^^^^^^^^^^^
```

modelopt/onnx/autocast/convert.py

Lines changed: 37 additions & 3 deletions
```diff
@@ -33,6 +33,7 @@
 from modelopt.onnx.autocast.nodeclassifier import NodeClassifier, NodeRuleBase
 from modelopt.onnx.autocast.precisionconverter import PrecisionConverter
 from modelopt.onnx.autocast.referencerunner import ReferenceRunner
+from modelopt.onnx.utils import get_min_opset_for_precisions, get_qdq_precisions
 
 """
 FP16 accuracy decreases in accordance with the data's magnitude.
@@ -84,7 +85,7 @@ def convert_to_mixed_precision(
         trt_plugins_precision: List indicating the precision for each custom op.
         max_depth_of_reduction: Maximum depth of reduction for node classification.
         opset: Target ONNX opset version. If None, uses default minimum opset based on low_precision_type
-            (22 for bf16, 13 for fp16). The opset may be automatically increased if certain operations
+            (22 for bf16, 19 for fp16). The opset may be automatically increased if certain operations
             require a higher version.
         use_standalone_type_inference: If True, use standalone type inference implementation instead of ONNX's
                                        infer_shapes. This is a workaround (WAR) when only type inference is
@@ -202,6 +203,7 @@ def convert_to_f16(
     tensor_block_dict: dict[str, dict[str, list[int]]] = {},
     trt_plugins: list[str] | None = [],
     use_standalone_type_inference: bool = False,
+    opset: int | None = None,
 ) -> onnx.ModelProto:
     """Convert model to mixed precision, using PrecisionConverter.
 
@@ -217,13 +219,45 @@ def convert_to_f16(
         use_standalone_type_inference: If True, use standalone type inference implementation instead of ONNX's
                                        infer_shapes. This is a workaround (WAR) when only type inference is
                                        needed without shape inference. Default: False.
+        opset: Target ONNX opset version. If None, uses default minimum opset based on precision type
+            (22 for bf16, 19 for fp16) and Q/DQ node requirements. The opset may be automatically
+            increased if Q/DQ nodes in the model require a higher version (e.g., FP8 requires 19,
+            INT4 requires 21, NVFP4 requires 23).
     """
     assert low_precision_type in ["fp16", "bf16"], "low_precision_type must be either fp16 or bf16"
 
-    # Opset 21 is needed for NVFP4 quantization support (DQ with 'block_size' attribute)
+    # Check Q/DQ precision types in the model and determine required opset
+    qdq_precisions = get_qdq_precisions(model)
+    qdq_min_opset = get_min_opset_for_precisions(qdq_precisions)
+
+    # Base minimum opset for FP16/BF16 conversion
+    # Opset 19 is the first to support fp16 scales in Q/DQ nodes
+    base_min_opset = 22 if low_precision_type == "bf16" else 19
+
+    # Determine target opset version
+    if opset is not None:
+        min_opset = opset
+        # Check if Q/DQ nodes require a higher opset
+        if qdq_precisions and qdq_min_opset > min_opset:
+            logger.warning(
+                f"Model contains Q/DQ nodes with precisions {qdq_precisions} that require "
+                f"opset >= {qdq_min_opset}. Upgrading from specified opset {opset} to {qdq_min_opset}."
+            )
+            min_opset = qdq_min_opset
+        # Also ensure we meet base minimum for precision type
+        if min_opset < base_min_opset:
+            logger.warning(
+                f"Opset {min_opset} is below minimum opset {base_min_opset} for {low_precision_type}. "
+                f"Upgrading to opset {base_min_opset}."
+            )
+            min_opset = base_min_opset
+    else:
+        # Use the highest required opset between base and Q/DQ requirements
+        min_opset = max(base_min_opset, qdq_min_opset)
+
     sanitizer = GraphSanitizer(
         model,
-        min_opset=21,
+        min_opset=min_opset,
         trt_plugins=trt_plugins,
         max_ir_version=LATEST_IR_VERSION_SUPPORTED_BY_ORT,
     )
```

modelopt/onnx/quantization/__main__.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -286,6 +286,15 @@ def get_parser() -> argparse.ArgumentParser:
             "The currently supported precisions are {fp16, int8, fp8}."
         ),
     )
+    argparser.add_argument(
+        "--opset",
+        type=int,
+        help=(
+            "Target ONNX opset version for the quantized model. If not specified, uses default minimum opset "
+            "(19 for fp16 scales support, 21 for int4, 23 for nvfp4). The opset may be automatically increased "
+            "if certain operations require a higher version."
+        ),
+    )
     return argparser
 
 
@@ -352,6 +361,7 @@ def main():
         simplify=args.simplify,
         calibrate_per_node=args.calibrate_per_node,
         direct_io_types=args.direct_io_types,
+        opset=args.opset,
     )
```
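Because `--opset` is declared with `type=int` and no default, `args.opset` is `None` when the flag is omitted, which is exactly the sentinel the downstream `quantize(...)` call expects. A minimal stand-in parser (not the full CLI) showing that behavior:

```python
import argparse

# Stand-in parser mirroring only the new flag; the real parser lives in
# modelopt/onnx/quantization/__main__.py and has many more options.
parser = argparse.ArgumentParser(prog="python -m modelopt.onnx.quantization")
parser.add_argument(
    "--opset",
    type=int,
    help="Target ONNX opset version for the quantized model.",
)

args_default = parser.parse_args([])            # flag omitted -> opset stays None
args = parser.parse_args(["--opset", "21"])     # parsed as int, forwarded as opset=args.opset
```

Here `args_default.opset` is `None` and `args.opset` is `21`.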
modelopt/onnx/quantization/fp8.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -182,6 +182,7 @@ def quantize(
     calibrate_per_node: bool = False,
     custom_ops_to_quantize: list[str] = [],
     direct_io_types: bool = False,
+    opset: int | None = None,
     **kwargs,
 ) -> onnx.ModelProto:
     """Applies FP8 GEMM only quantization to an ONNX file.
@@ -328,6 +329,7 @@ def quantize(
             tensor_block_dict=custom_ops_to_cast_fp32 or {},
             low_precision_type=high_precision_dtype,
             trt_plugins=trt_extra_plugin_lib_paths,
+            opset=opset,
         )
 
     current_opsets = {opset.domain: opset.version for opset in onnx_model.opset_import}
```

modelopt/onnx/quantization/int8.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -132,6 +132,7 @@ def quantize(
     calibrate_per_node: bool = False,
     custom_ops_to_quantize: list[str] = [],
     direct_io_types: bool = False,
+    opset: int | None = None,
     **kwargs,
 ) -> onnx.ModelProto:
     """Applies INT8 quantization to an ONNX file using the compiler friendly heuristics.
@@ -289,6 +290,7 @@ def quantize(
             tensor_block_dict=custom_ops_to_cast_fp32 or {},
             low_precision_type=high_precision_dtype,
             trt_plugins=trt_extra_plugin_lib_paths,
+            opset=opset,
         )
 
     if nodes_to_quantize:
```

modelopt/onnx/quantization/quantize.py

Lines changed: 58 additions & 8 deletions
```diff
@@ -69,6 +69,8 @@
 )
 from modelopt.onnx.trt_utils import interpret_trt_plugins_precision_flag, load_onnx_model
 from modelopt.onnx.utils import (
+    BASE_MIN_OPSET,
+    QDQ_PRECISION_MIN_OPSET,
     duplicate_shared_constants,
     get_opset_version,
     name_onnx_nodes,
@@ -78,6 +80,17 @@
 __all__ = ["quantize"]
 
 
+def _normalize_quantize_mode_for_opset(quantize_mode: str) -> str:
+    """Map variants like "int4_awq", "int4_rtn", "nvfp4" to their base precision types for lookup purposes."""
+    mode_lower = quantize_mode.lower()
+    if "int4" in mode_lower:
+        return "int4"
+    if "nvfp4" in mode_lower or "float4" in mode_lower:
+        return "float4_e2m1fn"
+    # For "int8", "fp8", etc., return as-is (fp8 falls back to BASE_MIN_OPSET which is correct)
+    return quantize_mode
+
+
 def _preprocess_onnx(
     onnx_path: str,
     use_external_data_format: bool,
@@ -88,6 +101,7 @@ def _preprocess_onnx(
     override_shapes: str,
     simplify: bool = False,
     quantize_mode: str = "int8",
+    opset: int | None = None,
 ) -> tuple[str, onnx.ModelProto, list[str], bool, bool, bool, dict, dict]:
     logger.info(f"Preprocessing the model {onnx_path}")
     intermediate_generated_files = []
@@ -118,16 +132,45 @@ def _preprocess_onnx(
             " '--trt_plugins' flag (requires TRT 10+)."
         )
 
-    # Per-Channel support with QDQ format requires onnx opset version 13 or above
-    opset_version = get_opset_version(onnx_model)
+    # Opset 19 is the minimum required for fp16 scales in Q/DQ nodes
+    # Higher opsets required for specific quantization modes (int4: 21, nvfp4: 23)
+    original_opset_version = get_opset_version(onnx_model)
+
+    # Determine minimum required opset based on quantization mode
+    # Normalize quantize_mode to handle variants like "int4_awq", "nvfp4", etc.
+    normalized_mode = _normalize_quantize_mode_for_opset(quantize_mode)
+    mode_min_opset = QDQ_PRECISION_MIN_OPSET.get(normalized_mode, BASE_MIN_OPSET)
+
+    # Determine target opset version
+    if opset is not None:
+        target_opset = opset
+        # Warn if user-specified opset is below mode minimum (but still respect it)
+        if opset < mode_min_opset:
+            logger.warning(
+                f"Opset {opset} is below the minimum opset {mode_min_opset} required for "
+                f"{quantize_mode} quantization. Upgrading to opset {mode_min_opset}."
+            )
+            target_opset = mode_min_opset
+        # Warn if user-specified opset is lower than original
+        if opset < original_opset_version:
+            logger.warning(
+                f"Specified opset {opset} is lower than the original model's opset {original_opset_version}. "
+                f"Using original model's opset {original_opset_version}."
+            )
+            target_opset = max(target_opset, original_opset_version)
+    else:
+        # Use model's opset if it's >= mode_min_opset, otherwise upgrade to mode_min_opset
+        target_opset = (
+            max(original_opset_version, mode_min_opset)
+            if original_opset_version != 1
+            else mode_min_opset
+        )
 
-    required_opset_version = 13
-    if opset_version < required_opset_version and opset_version != 1:
-        opset_version = required_opset_version
-        onnx_model = onnx.version_converter.convert_version(onnx_model, opset_version)
-        onnx_path = os.path.join(output_dir, f"{model_name}_opset{opset_version}.onnx")
+    if original_opset_version < target_opset and original_opset_version != 1:
+        onnx_model = onnx.version_converter.convert_version(onnx_model, target_opset)
+        onnx_path = os.path.join(output_dir, f"{model_name}_opset{target_opset}.onnx")
         save_onnx(onnx_model, onnx_path, use_external_data_format)
-        logger.info(f"Model is cloned to {onnx_path} with opset_version {opset_version}")
+        logger.info(f"Model is cloned to {onnx_path} with opset_version {target_opset}")
         intermediate_generated_files.append(onnx_path)
 
     # Simplify model if requested
@@ -231,6 +274,7 @@ def quantize(
     calibrate_per_node: bool = False,
     input_shapes_profile: Sequence[dict[str, str]] | None = None,
     direct_io_types: bool = False,
+    opset: int | None = None,
     **kwargs: Any,
 ) -> None:
     """Quantizes the provided ONNX model.
@@ -350,6 +394,10 @@ def quantize(
         direct_io_types:
             If True, modify the I/O types in the quantized ONNX model to be lower precision whenever possible.
            If False, keep the I/O types in the quantized ONNX model the same as in the given ONNX model.
+        opset:
+            Target ONNX opset version for the quantized model. If None, uses required minimum opset
+            (19 for int8/fp8, 21 for int4, 23 for nvfp4). If the specified opset is lower than the required minimum,
+            a warning will be issued and the opset will be upgraded to the required minimum.
         kwargs:
             Additional keyword arguments for int4 quantization, including:
             - awqlite_alpha_step (float): Alpha step for lite, range [0, 1].
@@ -420,6 +468,7 @@ def quantize(
         override_shapes,  # type: ignore[arg-type]
         simplify,
         quantize_mode,
+        opset,
     )
     trt_plugins = update_trt_ep_support(calibration_eps, has_dds_op, has_custom_op, trt_plugins)  # type: ignore[arg-type]
 
@@ -481,6 +530,7 @@ def quantize(
             calibrate_per_node=calibrate_per_node,
             custom_ops_to_quantize=list(custom_ops_to_quantize.keys()),
             direct_io_types=direct_io_types,
+            opset=opset,
             **kwargs,
         )
     elif "int4" in quantize_mode:
```

modelopt/onnx/utils.py

Lines changed: 63 additions & 6 deletions
```diff
@@ -31,6 +31,9 @@
 
 from modelopt.onnx.logging_config import logger
 
+# Base minimum opset for quantization (opset 19 is the first to support fp16 scales)
+BASE_MIN_OPSET = 19
+
 
 def get_input_names_from_bytes(model_bytes: bytes, external_inputs_only: bool = True) -> list[str]:
     """This function returns the inputs names of the given onnx model in bytes.
@@ -697,18 +700,72 @@ def get_opset_version(model: onnx.ModelProto) -> int:
 
 
 def check_model_uses_external_data(model: onnx.ModelProto) -> bool:
-    """Checks if the model uses external data.
+    """Checks if the model uses external data. True if any initializer tensor has data_location set to EXTERNAL."""
+    return any(
+        init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL
+        for init in model.graph.initializer
+    )
+
+
+def get_qdq_precisions(model: onnx.ModelProto) -> set:
+    """Gets the Q/DQ precision types present in the model.
 
     Args:
         model: Loaded in-memory onnx ModelProto.
 
     Returns:
-        True if any initializer tensor has data_location set to EXTERNAL.
+        set: Set of Q/DQ precision types present in the model (e.g., 'float8_e4m3fn', 'int8',
+            'int4', 'float4_e2m1fn').
     """
-    return any(
-        init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL
-        for init in model.graph.initializer
-    )
+    graph = gs.import_onnx(model)
+    precisions = set()
+
+    # Check for custom 'NVFP4' nodes
+    custom_fp4_q_nodes = [node for node in graph.nodes if node.op == "TRT_FP4DynamicQuantize"]
+    if custom_fp4_q_nodes:
+        precisions.add("float4_e2m1fn")
+
+    # Check for precision in DQ nodes
+    dq_nodes = [node for node in graph.nodes if node.op == "DequantizeLinear"]
+    for dq_node in dq_nodes:
+        if len(dq_node.inputs) >= 3 and dq_node.inputs[2] is not None:
+            # If zero-point is set, return that as the quantization mode
+            if isinstance(dq_node.inputs[2], Constant) and dq_node.inputs[2].values is not None:
+                precisions.add(dq_node.inputs[2].values.dtype.name)
+        elif isinstance(dq_node.inputs[0], Constant) and dq_node.inputs[0].values is not None:
+            # Else, return the node's input precision (ex: 'NVFP4' weight quantization)
+            precisions.add(dq_node.inputs[0].values.dtype.name)
+
+    return precisions
+
+
+# Minimum opset requirements by quantization mode/precision
+# Base minimum is 19 (first opset that allows fp16 scales in Q/DQ nodes)
+# Supports both quantize modes (e.g., "fp8") and dtype prefixes (e.g., "float8" for "float8_e4m3fn")
+QDQ_PRECISION_MIN_OPSET = {
+    "int8": BASE_MIN_OPSET,
+    "float8_e4m3fn": BASE_MIN_OPSET,
+    "int4": 21,
+    "uint4": 21,
+    "float4_e2m1fn": 23,
+}
+
+
+def get_min_opset_for_precisions(precisions: set) -> int:
+    """Gets the minimum required opset version for a set of Q/DQ precision types.
+
+    Args:
+        precisions: Set of precision type strings (e.g., 'float8_e4m3fn', 'int4').
+
+    Returns:
+        int: Minimum required opset version for the given precisions.
+    """
+    min_opset = BASE_MIN_OPSET  # Base minimum for fp16 scales support
+    for precision in precisions:
+        # Direct lookup first
+        if precision in QDQ_PRECISION_MIN_OPSET:
+            min_opset = max(min_opset, QDQ_PRECISION_MIN_OPSET[precision])
+    return min_opset
 
 
 def bfloat16_to_float32(bf16_array):
```

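The lookup table and `get_min_opset_for_precisions` are self-contained, so they can be exercised standalone (values copied from the diff; unknown precision strings simply fall back to the base minimum):

```python
# Precision -> minimum-opset lookup, extracted from modelopt/onnx/utils.py.
BASE_MIN_OPSET = 19  # first opset allowing fp16 scales in Q/DQ nodes

QDQ_PRECISION_MIN_OPSET = {
    "int8": BASE_MIN_OPSET,
    "float8_e4m3fn": BASE_MIN_OPSET,
    "int4": 21,
    "uint4": 21,
    "float4_e2m1fn": 23,
}

def get_min_opset_for_precisions(precisions: set) -> int:
    """Highest minimum opset over all Q/DQ precisions found in a model."""
    min_opset = BASE_MIN_OPSET
    for precision in precisions:
        if precision in QDQ_PRECISION_MIN_OPSET:
            min_opset = max(min_opset, QDQ_PRECISION_MIN_OPSET[precision])
    return min_opset
```

A model mixing int8 and int4 Q/DQ nodes therefore needs opset 21; any NVFP4 content pushes the requirement to 23.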
0 commit comments