Skip to content

Commit e5de5ec

Browse files
gcunhase and kevalmorabia97
authored and committed
[6056809] Fix TRT dependency in ModelOpt ONNX quantization (#1189)
### What does this PR do? Type of change: Bug fix Regression bug introduced by the Autotune integration into ModelOpt ONNX quantization (#951), making ModelOpt dependent on TensorRT in all scenarios. This PR fixes this issue by requiring TensorRT only when `--autotune` is enabled. ### Usage ```bash $ python -m modelopt.onnx.quantization --onnx_path=${MODEL_NAME}.onnx ``` ### Testing See bug 6056809. ### Before your PR is "*Ready for review*" Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.). - Is this change backward compatible?: ✅ - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A - Did you write any new necessary tests?: ✅ - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: N/A <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Autotune dependency failures now surface as clearer runtime errors instead of only logging warnings. * **Chores** * Centralized autotune presets and numeric defaults into a shared configuration. * Core autotune components are conditionally exposed so initialization succeeds when optional acceleration libraries are absent. * Deferred autotune imports to runtime to improve failure handling. * **Tests** * Added a test ensuring the quantization CLI/parser initializes correctly without optional acceleration libraries. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com>
1 parent f3151d2 commit e5de5ec

File tree

6 files changed

+104
-50
lines changed

6 files changed

+104
-50
lines changed

modelopt/onnx/quantization/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import numpy as np
2222

23-
from modelopt.onnx.quantization.autotune import (
23+
from modelopt.onnx.quantization.autotune.utils import (
2424
MODE_PRESETS,
2525
StoreWithExplicitFlag,
2626
get_node_filter_list,

modelopt/onnx/quantization/autotune/__init__.py

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,35 +20,42 @@
2020
region analysis to efficiently explore and optimize Q/DQ insertion strategies.
2121
"""
2222

23-
# Expose Autotune modes
24-
from .__main__ import MODE_PRESETS
23+
# Expose Autotune modes and CLI utilities
24+
from .utils import MODE_PRESETS, StoreWithExplicitFlag, get_node_filter_list
2525

26-
# Core data structures
27-
from .autotuner import QDQAutotuner
28-
from .benchmark import TensorRTPyBenchmark, TrtExecBenchmark
29-
from .common import (
30-
AutotunerError,
31-
AutotunerNotInitializedError,
32-
Config,
33-
InsertionScheme,
34-
InvalidSchemeError,
35-
PatternCache,
36-
PatternSchemes,
37-
Region,
38-
RegionType,
39-
)
40-
from .insertion_points import (
41-
ChildRegionInputInsertionPoint,
42-
ChildRegionOutputInsertionPoint,
43-
NodeInputInsertionPoint,
44-
ResolvedInsertionPoint,
45-
)
46-
from .region_pattern import RegionPattern
47-
from .region_search import CombinedRegionSearch
48-
from .utils import StoreWithExplicitFlag, get_node_filter_list
26+
# Core data structures (requires TensorRT)
27+
try:
28+
from .autotuner import QDQAutotuner
29+
from .benchmark import TensorRTPyBenchmark, TrtExecBenchmark
30+
from .common import (
31+
AutotunerError,
32+
AutotunerNotInitializedError,
33+
Config,
34+
InsertionScheme,
35+
InvalidSchemeError,
36+
PatternCache,
37+
PatternSchemes,
38+
Region,
39+
RegionType,
40+
)
41+
from .insertion_points import (
42+
ChildRegionInputInsertionPoint,
43+
ChildRegionOutputInsertionPoint,
44+
NodeInputInsertionPoint,
45+
ResolvedInsertionPoint,
46+
)
47+
from .region_pattern import RegionPattern
48+
from .region_search import CombinedRegionSearch
49+
except ImportError as e:
50+
from modelopt.onnx.logging_config import logger
4951

50-
__all__ = [
51-
"MODE_PRESETS",
52+
logger.warning(
53+
f"Failed to import Autotune dependencies: '{e}'. Ignore if Autotune is not being used."
54+
)
55+
56+
__all__ = ["MODE_PRESETS", "StoreWithExplicitFlag", "get_node_filter_list"]
57+
58+
_OPTIONAL_EXPORTS = [
5259
"AutotunerError",
5360
"AutotunerNotInitializedError",
5461
"ChildRegionInputInsertionPoint",
@@ -65,8 +72,7 @@
6572
"RegionPattern",
6673
"RegionType",
6774
"ResolvedInsertionPoint",
68-
"StoreWithExplicitFlag",
6975
"TensorRTPyBenchmark",
7076
"TrtExecBenchmark",
71-
"get_node_filter_list",
7277
]
78+
__all__.extend(name for name in _OPTIONAL_EXPORTS if name in globals())

modelopt/onnx/quantization/autotune/__main__.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@
2222

2323
from modelopt.onnx.logging_config import logger
2424
from modelopt.onnx.quantization.autotune.utils import (
25+
DEFAULT_NUM_SCHEMES,
26+
DEFAULT_TIMING_RUNS,
27+
DEFAULT_WARMUP_RUNS,
28+
MODE_PRESETS,
2529
StoreWithExplicitFlag,
2630
get_node_filter_list,
2731
validate_file_path,
@@ -32,21 +36,9 @@
3236
)
3337

3438
DEFAULT_OUTPUT_DIR = "./autotuner_output"
35-
DEFAULT_NUM_SCHEMES = 50
3639
DEFAULT_QUANT_TYPE = "int8"
3740
DEFAULT_DQ_DTYPE = "float32"
3841
DEFAULT_TIMING_CACHE = str(Path(tempfile.gettempdir()) / "trtexec_timing.cache")
39-
DEFAULT_WARMUP_RUNS = 50
40-
DEFAULT_TIMING_RUNS = 100
41-
MODE_PRESETS = {
42-
"quick": {"schemes_per_region": 30, "warmup_runs": 10, "timing_runs": 50},
43-
"default": {
44-
"schemes_per_region": DEFAULT_NUM_SCHEMES,
45-
"warmup_runs": DEFAULT_WARMUP_RUNS,
46-
"timing_runs": DEFAULT_TIMING_RUNS,
47-
},
48-
"extensive": {"schemes_per_region": 200, "warmup_runs": 50, "timing_runs": 200},
49-
}
5042

5143

5244
def apply_mode_presets(args) -> None:

modelopt/onnx/quantization/autotune/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,20 @@
2121

2222
from modelopt.onnx.logging_config import logger
2323

24+
DEFAULT_NUM_SCHEMES = 50
25+
DEFAULT_WARMUP_RUNS = 50
26+
DEFAULT_TIMING_RUNS = 100
27+
28+
MODE_PRESETS = {
29+
"quick": {"schemes_per_region": 30, "warmup_runs": 10, "timing_runs": 50},
30+
"default": {
31+
"schemes_per_region": DEFAULT_NUM_SCHEMES,
32+
"warmup_runs": DEFAULT_WARMUP_RUNS,
33+
"timing_runs": DEFAULT_TIMING_RUNS,
34+
},
35+
"extensive": {"schemes_per_region": 200, "warmup_runs": 50, "timing_runs": 200},
36+
}
37+
2438

2539
class StoreWithExplicitFlag(argparse.Action):
2640
"""Store the value and set an 'explicit' flag on the namespace so mode presets do not override."""

modelopt/onnx/quantization/quantize.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,6 @@
4646

4747
from modelopt.onnx.logging_config import configure_logging, logger
4848
from modelopt.onnx.op_types import is_data_dependent_shape_op
49-
50-
try:
51-
from modelopt.onnx.quantization.autotune.workflows import (
52-
init_benchmark_instance,
53-
region_pattern_autotuning_workflow,
54-
)
55-
except ImportError:
56-
logger.warning("Failed to import Autotune dependencies")
5749
from modelopt.onnx.quantization.calib_utils import (
5850
CalibrationDataProvider,
5951
CalibrationDataType,
@@ -286,6 +278,17 @@ def _find_nodes_to_quantize_autotune(
286278
"""Extracts quantization information from Autotune to provide ORT quantization."""
287279
logger.info("Running Auto Q/DQ with TensorRT")
288280

281+
try:
282+
from modelopt.onnx.quantization.autotune.workflows import (
283+
init_benchmark_instance,
284+
region_pattern_autotuning_workflow,
285+
)
286+
except ImportError as e:
287+
raise RuntimeError(
288+
f"Failed to import Autotune dependencies: '{e}'."
289+
"Make sure that all Autotune requirements are installed (i.e., TensorRT)."
290+
)
291+
289292
benchmark_instance = init_benchmark_instance(
290293
use_trtexec=use_trtexec,
291294
plugin_libraries=trt_plugins,
@@ -294,6 +297,7 @@ def _find_nodes_to_quantize_autotune(
294297
timing_runs=timing_runs,
295298
trtexec_args=trtexec_args.split() if trtexec_args else None,
296299
)
300+
297301
if benchmark_instance is None:
298302
raise RuntimeError("Failed to initialize TensorRT benchmark")
299303

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import importlib
17+
import sys
18+
19+
import pytest
20+
21+
22+
def test_quantization_cli_parser_imports_without_tensorrt():
23+
"""Verify the CLI parser can be constructed without TensorRT installed."""
24+
with pytest.MonkeyPatch.context() as mp:
25+
# Force tensorrt import to fail, even if it's actually installed
26+
mp.setitem(sys.modules, "tensorrt", None)
27+
28+
# Reload the autotune package so it picks up the blocked import
29+
import modelopt.onnx.quantization.autotune
30+
31+
importlib.reload(modelopt.onnx.quantization.autotune)
32+
33+
from modelopt.onnx.quantization.__main__ import get_parser
34+
35+
parser = get_parser()
36+
args = parser.parse_args(["--onnx_path", "dummy.onnx"])
37+
assert args.onnx_path == "dummy.onnx"
38+
assert args.quantize_mode == "int8"

0 commit comments

Comments (0)