Skip to content

Commit bdc04f1

Browse files
authored
[6056809] Fix TRT dependency in ModelOpt ONNX quantization (#1189)
### What does this PR do? Type of change: Bug fix Regression bug introduced by the Autotune integration into ModelOpt ONNX quantization (#951), making ModelOpt dependent on TensorRT in all scenarios. This PR fixes this issue by requiring TensorRT only when `--autotune` is enabled. ### Usage ```python $ python -m modelopt.onnx.quantization --onnx_path=${MODEL_NAME}.onnx ``` ### Testing See bug 6056809. ### Before your PR is "*Ready for review*" Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.). - Is this change backward compatible?: ✅ - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A - Did you write any new necessary tests?: ✅ - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: N/A <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Autotune dependency failures now surface as clearer runtime errors instead of only logging warnings. * **Chores** * Centralized autotune presets and numeric defaults into a shared configuration. * Core autotune components are conditionally exposed so initialization succeeds when optional acceleration libraries are absent. * Deferred autotune imports to runtime to improve failure handling. * **Tests** * Added a test ensuring the quantization CLI/parser initializes correctly without optional acceleration libraries. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com>
1 parent 4255bc6 commit bdc04f1

6 files changed

Lines changed: 104 additions & 50 deletions

File tree

modelopt/onnx/quantization/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import numpy as np
2222

23-
from modelopt.onnx.quantization.autotune import (
23+
from modelopt.onnx.quantization.autotune.utils import (
2424
MODE_PRESETS,
2525
StoreWithExplicitFlag,
2626
get_node_filter_list,

modelopt/onnx/quantization/autotune/__init__.py

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,35 +20,42 @@
2020
region analysis to efficiently explore and optimize Q/DQ insertion strategies.
2121
"""
2222

23-
# Expose Autotune modes
24-
from .__main__ import MODE_PRESETS
23+
# Expose Autotune modes and CLI utilities
24+
from .utils import MODE_PRESETS, StoreWithExplicitFlag, get_node_filter_list
2525

26-
# Core data structures
27-
from .autotuner import QDQAutotuner
28-
from .benchmark import TensorRTPyBenchmark, TrtExecBenchmark
29-
from .common import (
30-
AutotunerError,
31-
AutotunerNotInitializedError,
32-
Config,
33-
InsertionScheme,
34-
InvalidSchemeError,
35-
PatternCache,
36-
PatternSchemes,
37-
Region,
38-
RegionType,
39-
)
40-
from .insertion_points import (
41-
ChildRegionInputInsertionPoint,
42-
ChildRegionOutputInsertionPoint,
43-
NodeInputInsertionPoint,
44-
ResolvedInsertionPoint,
45-
)
46-
from .region_pattern import RegionPattern
47-
from .region_search import CombinedRegionSearch
48-
from .utils import StoreWithExplicitFlag, get_node_filter_list
26+
# Core data structures (requires TensorRT)
27+
try:
28+
from .autotuner import QDQAutotuner
29+
from .benchmark import TensorRTPyBenchmark, TrtExecBenchmark
30+
from .common import (
31+
AutotunerError,
32+
AutotunerNotInitializedError,
33+
Config,
34+
InsertionScheme,
35+
InvalidSchemeError,
36+
PatternCache,
37+
PatternSchemes,
38+
Region,
39+
RegionType,
40+
)
41+
from .insertion_points import (
42+
ChildRegionInputInsertionPoint,
43+
ChildRegionOutputInsertionPoint,
44+
NodeInputInsertionPoint,
45+
ResolvedInsertionPoint,
46+
)
47+
from .region_pattern import RegionPattern
48+
from .region_search import CombinedRegionSearch
49+
except ImportError as e:
50+
from modelopt.onnx.logging_config import logger
4951

50-
__all__ = [
51-
"MODE_PRESETS",
52+
logger.warning(
53+
f"Failed to import Autotune dependencies: '{e}'. Ignore if Autotune is not being used."
54+
)
55+
56+
__all__ = ["MODE_PRESETS", "StoreWithExplicitFlag", "get_node_filter_list"]
57+
58+
_OPTIONAL_EXPORTS = [
5259
"AutotunerError",
5360
"AutotunerNotInitializedError",
5461
"ChildRegionInputInsertionPoint",
@@ -65,8 +72,7 @@
6572
"RegionPattern",
6673
"RegionType",
6774
"ResolvedInsertionPoint",
68-
"StoreWithExplicitFlag",
6975
"TensorRTPyBenchmark",
7076
"TrtExecBenchmark",
71-
"get_node_filter_list",
7277
]
78+
__all__.extend(name for name in _OPTIONAL_EXPORTS if name in globals())

modelopt/onnx/quantization/autotune/__main__.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@
2222

2323
from modelopt.onnx.logging_config import logger
2424
from modelopt.onnx.quantization.autotune.utils import (
25+
DEFAULT_NUM_SCHEMES,
26+
DEFAULT_TIMING_RUNS,
27+
DEFAULT_WARMUP_RUNS,
28+
MODE_PRESETS,
2529
StoreWithExplicitFlag,
2630
get_node_filter_list,
2731
validate_file_path,
@@ -32,21 +36,9 @@
3236
)
3337

3438
DEFAULT_OUTPUT_DIR = "./autotuner_output"
35-
DEFAULT_NUM_SCHEMES = 50
3639
DEFAULT_QUANT_TYPE = "int8"
3740
DEFAULT_DQ_DTYPE = "float32"
3841
DEFAULT_TIMING_CACHE = str(Path(tempfile.gettempdir()) / "trtexec_timing.cache")
39-
DEFAULT_WARMUP_RUNS = 50
40-
DEFAULT_TIMING_RUNS = 100
41-
MODE_PRESETS = {
42-
"quick": {"schemes_per_region": 30, "warmup_runs": 10, "timing_runs": 50},
43-
"default": {
44-
"schemes_per_region": DEFAULT_NUM_SCHEMES,
45-
"warmup_runs": DEFAULT_WARMUP_RUNS,
46-
"timing_runs": DEFAULT_TIMING_RUNS,
47-
},
48-
"extensive": {"schemes_per_region": 200, "warmup_runs": 50, "timing_runs": 200},
49-
}
5042

5143

5244
def apply_mode_presets(args) -> None:

modelopt/onnx/quantization/autotune/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,20 @@
2121

2222
from modelopt.onnx.logging_config import logger
2323

24+
DEFAULT_NUM_SCHEMES = 50
25+
DEFAULT_WARMUP_RUNS = 50
26+
DEFAULT_TIMING_RUNS = 100
27+
28+
MODE_PRESETS = {
29+
"quick": {"schemes_per_region": 30, "warmup_runs": 10, "timing_runs": 50},
30+
"default": {
31+
"schemes_per_region": DEFAULT_NUM_SCHEMES,
32+
"warmup_runs": DEFAULT_WARMUP_RUNS,
33+
"timing_runs": DEFAULT_TIMING_RUNS,
34+
},
35+
"extensive": {"schemes_per_region": 200, "warmup_runs": 50, "timing_runs": 200},
36+
}
37+
2438

2539
class StoreWithExplicitFlag(argparse.Action):
2640
"""Store the value and set an 'explicit' flag on the namespace so mode presets do not override."""

modelopt/onnx/quantization/quantize.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,6 @@
4747

4848
from modelopt.onnx.logging_config import configure_logging, logger
4949
from modelopt.onnx.op_types import is_data_dependent_shape_op
50-
51-
try:
52-
from modelopt.onnx.quantization.autotune.workflows import (
53-
init_benchmark_instance,
54-
region_pattern_autotuning_workflow,
55-
)
56-
except ImportError:
57-
logger.warning("Failed to import Autotune dependencies")
5850
from modelopt.onnx.quantization.calib_utils import (
5951
CalibrationDataProvider,
6052
CalibrationDataType,
@@ -287,6 +279,17 @@ def _find_nodes_to_quantize_autotune(
287279
"""Extracts quantization information from Autotune to provide ORT quantization."""
288280
logger.info("Running Auto Q/DQ with TensorRT")
289281

282+
try:
283+
from modelopt.onnx.quantization.autotune.workflows import (
284+
init_benchmark_instance,
285+
region_pattern_autotuning_workflow,
286+
)
287+
except ImportError as e:
288+
raise RuntimeError(
289+
f"Failed to import Autotune dependencies: '{e}'. "
290+
"Make sure that all Autotune requirements are installed (i.e., TensorRT)."
291+
)
292+
290293
benchmark_instance = init_benchmark_instance(
291294
use_trtexec=use_trtexec,
292295
plugin_libraries=trt_plugins,
@@ -295,6 +298,7 @@ def _find_nodes_to_quantize_autotune(
295298
timing_runs=timing_runs,
296299
trtexec_args=trtexec_args.split() if trtexec_args else None,
297300
)
301+
298302
if benchmark_instance is None:
299303
raise RuntimeError("Failed to initialize TensorRT benchmark")
300304

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import importlib
17+
import sys
18+
19+
import pytest
20+
21+
22+
def test_quantization_cli_parser_imports_without_tensorrt():
23+
"""Verify the CLI parser can be constructed without TensorRT installed."""
24+
with pytest.MonkeyPatch.context() as mp:
25+
# Force tensorrt import to fail, even if it's actually installed
26+
mp.setitem(sys.modules, "tensorrt", None)
27+
28+
# Reload the autotune package so it picks up the blocked import
29+
import modelopt.onnx.quantization.autotune
30+
31+
importlib.reload(modelopt.onnx.quantization.autotune)
32+
33+
from modelopt.onnx.quantization.__main__ import get_parser
34+
35+
parser = get_parser()
36+
args = parser.parse_args(["--onnx_path", "dummy.onnx"])
37+
assert args.onnx_path == "dummy.onnx"
38+
assert args.quantize_mode == "int8"

0 commit comments

Comments (0)