|
20 | 20 | region analysis to efficiently explore and optimize Q/DQ insertion strategies. |
21 | 21 | """ |
22 | 22 |
|
23 | | -# Expose Autotune modes |
24 | | -from .__main__ import MODE_PRESETS |
| 23 | +# Expose Autotune modes and CLI utilities |
| 24 | +from .utils import MODE_PRESETS, StoreWithExplicitFlag, get_node_filter_list |
25 | 25 |
|
26 | | -# Core data structures |
27 | | -from .autotuner import QDQAutotuner |
28 | | -from .benchmark import TensorRTPyBenchmark, TrtExecBenchmark |
29 | | -from .common import ( |
30 | | - AutotunerError, |
31 | | - AutotunerNotInitializedError, |
32 | | - Config, |
33 | | - InsertionScheme, |
34 | | - InvalidSchemeError, |
35 | | - PatternCache, |
36 | | - PatternSchemes, |
37 | | - Region, |
38 | | - RegionType, |
39 | | -) |
40 | | -from .insertion_points import ( |
41 | | - ChildRegionInputInsertionPoint, |
42 | | - ChildRegionOutputInsertionPoint, |
43 | | - NodeInputInsertionPoint, |
44 | | - ResolvedInsertionPoint, |
45 | | -) |
46 | | -from .region_pattern import RegionPattern |
47 | | -from .region_search import CombinedRegionSearch |
48 | | -from .utils import StoreWithExplicitFlag, get_node_filter_list |
| 26 | +# Core data structures (requires TensorRT) |
| 27 | +try: |
| 28 | + from .autotuner import QDQAutotuner |
| 29 | + from .benchmark import TensorRTPyBenchmark, TrtExecBenchmark |
| 30 | + from .common import ( |
| 31 | + AutotunerError, |
| 32 | + AutotunerNotInitializedError, |
| 33 | + Config, |
| 34 | + InsertionScheme, |
| 35 | + InvalidSchemeError, |
| 36 | + PatternCache, |
| 37 | + PatternSchemes, |
| 38 | + Region, |
| 39 | + RegionType, |
| 40 | + ) |
| 41 | + from .insertion_points import ( |
| 42 | + ChildRegionInputInsertionPoint, |
| 43 | + ChildRegionOutputInsertionPoint, |
| 44 | + NodeInputInsertionPoint, |
| 45 | + ResolvedInsertionPoint, |
| 46 | + ) |
| 47 | + from .region_pattern import RegionPattern |
| 48 | + from .region_search import CombinedRegionSearch |
| 49 | +except ImportError as e: |
| 50 | + from modelopt.onnx.logging_config import logger |
49 | 51 |
|
50 | | -__all__ = [ |
51 | | - "MODE_PRESETS", |
| 52 | + logger.warning( |
| 53 | + f"Failed to import Autotune dependencies: '{e}'. Ignore if Autotune is not being used." |
| 54 | + ) |
| 55 | + |
| 56 | +__all__ = ["MODE_PRESETS", "StoreWithExplicitFlag", "get_node_filter_list"] |
| 57 | + |
| 58 | +_OPTIONAL_EXPORTS = [ |
52 | 59 | "AutotunerError", |
53 | 60 | "AutotunerNotInitializedError", |
54 | 61 | "ChildRegionInputInsertionPoint", |
|
65 | 72 | "RegionPattern", |
66 | 73 | "RegionType", |
67 | 74 | "ResolvedInsertionPoint", |
68 | | - "StoreWithExplicitFlag", |
69 | 75 | "TensorRTPyBenchmark", |
70 | 76 | "TrtExecBenchmark", |
71 | | - "get_node_filter_list", |
72 | 77 | ] |
| 78 | +__all__.extend(name for name in _OPTIONAL_EXPORTS if name in globals()) |
0 commit comments