Skip to content

Commit fc4d541

Browse files
committed
address cr
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
1 parent 882e3a7 commit fc4d541

File tree

3 files changed

+17
-4
lines changed

3 files changed

+17
-4
lines changed

modelopt/onnx/autocast/convert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def convert_to_mixed_precision(
8585
trt_plugins_precision: List indicating the precision for each custom op.
8686
max_depth_of_reduction: Maximum depth of reduction for node classification.
8787
opset: Target ONNX opset version. If None, uses default minimum opset based on low_precision_type
88-
(22 for bf16, 13 for fp16). The opset may be automatically increased if certain operations
88+
(22 for bf16, 19 for fp16). The opset may be automatically increased if certain operations
8989
require a higher version.
9090
use_standalone_type_inference: If True, use standalone type inference implementation instead of ONNX's
9191
infer_shapes. This is a workaround (WAR) when only type inference is

modelopt/onnx/quantization/quantize.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,17 @@
8080
__all__ = ["quantize"]
8181

8282

83+
def _normalize_quantize_mode_for_opset(quantize_mode: str) -> str:
84+
"""Map variants like "int4_awq", "int4_rtn", "nvfp4" to their base precision types for lookup purposes."""
85+
mode_lower = quantize_mode.lower()
86+
if "int4" in mode_lower:
87+
return "int4"
88+
if "nvfp4" in mode_lower or "float4" in mode_lower:
89+
return "float4_e2m1fn"
90+
# For "int8", "fp8", etc., return as-is (fp8 falls back to BASE_MIN_OPSET which is correct)
91+
return quantize_mode
92+
93+
8394
def _preprocess_onnx(
8495
onnx_path: str,
8596
use_external_data_format: bool,
@@ -126,7 +137,9 @@ def _preprocess_onnx(
126137
original_opset_version = get_opset_version(onnx_model)
127138

128139
# Determine minimum required opset based on quantization mode
129-
mode_min_opset = QDQ_PRECISION_MIN_OPSET.get(quantize_mode, BASE_MIN_OPSET)
140+
# Normalize quantize_mode to handle variants like "int4_awq", "nvfp4", etc.
141+
normalized_mode = _normalize_quantize_mode_for_opset(quantize_mode)
142+
mode_min_opset = QDQ_PRECISION_MIN_OPSET.get(normalized_mode, BASE_MIN_OPSET)
130143

131144
# Determine target opset version
132145
if opset is not None:

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@
4848
"onnx-graphsurgeon",
4949
"onnx~=1.19.0",
5050
"onnxconverter-common~=1.16.0",
51-
"onnxruntime~=1.23.0 ; platform_machine == 'aarch64' or platform_system == 'Darwin'",
52-
"onnxruntime-gpu~=1.23.0 ; platform_machine != 'aarch64' and platform_system != 'Darwin'",
51+
"onnxruntime~=1.22.0 ; platform_machine == 'aarch64' or platform_system == 'Darwin'",
52+
"onnxruntime-gpu~=1.22.0 ; platform_machine != 'aarch64' and platform_system != 'Darwin'",
5353
"onnxscript", # For autocast opset conversion and test_onnx_dynamo_export unit test
5454
"onnxslim>=0.1.76",
5555
"polygraphy>=0.49.22",

0 commit comments

Comments
 (0)