Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions modelopt/onnx/quantization/autotune/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
)
from .region_pattern import RegionPattern
from .region_search import CombinedRegionSearch
from .torch_region_builder import TorchRegionBuilder
except ImportError as e:
from modelopt.onnx.logging_config import logger

Expand Down Expand Up @@ -73,6 +74,7 @@
"RegionType",
"ResolvedInsertionPoint",
"TensorRTPyBenchmark",
"TorchRegionBuilder",
"TrtExecBenchmark",
]
__all__.extend(name for name in _OPTIONAL_EXPORTS if name in globals())
23 changes: 16 additions & 7 deletions modelopt/onnx/quantization/autotune/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
from modelopt.onnx.quantization.autotune.autotuner_base import QDQAutotunerBase
from modelopt.onnx.quantization.autotune.common import Config, PatternCache, Region, RegionType
from modelopt.onnx.quantization.autotune.region_search import CombinedRegionSearch
from modelopt.onnx.quantization.autotune.torch_region_builder import (
TorchRegionBuilder,
check_torch_naming_convention,
)


class QDQAutotuner(QDQAutotunerBase):
Expand Down Expand Up @@ -94,13 +98,18 @@ def _search_regions(self) -> None:
- Phase 2: Top-down refinement creating hierarchical structure
"""
logger.info("Discovering optimization regions")
search = CombinedRegionSearch(
self.graph,
maximum_sequence_region_size=self.config.maximum_sequence_region_size,
minimum_topdown_search_size=self.config.minimum_topdown_search_size,
)
self.regions = search.search_regions()
self._reassign_region_ids(self.regions)
if check_torch_naming_convention(self.graph):
torch_search = TorchRegionBuilder(self.graph)
self.regions = torch_search.build_regions(linearize=True, only_quantizable=True)
self._reassign_region_ids(self.regions)
else:
Comment thread
willg-nv marked this conversation as resolved.
default_search = CombinedRegionSearch(
self.graph,
maximum_sequence_region_size=self.config.maximum_sequence_region_size,
minimum_topdown_search_size=self.config.minimum_topdown_search_size,
)
self.regions = default_search.search_regions()
self._reassign_region_ids(self.regions)
logger.debug(f"Found {len(self.regions)} top-level regions")

# Flatten the hierarchy into a list of all regions
Expand Down
Loading