Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions modelopt/onnx/quantization/autotune/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
)
from .region_pattern import RegionPattern
from .region_search import CombinedRegionSearch
from .torch_region_builder import TorchRegionBuilder
except ImportError as e:
from modelopt.onnx.logging_config import logger

Expand Down Expand Up @@ -73,6 +74,7 @@
"RegionType",
"ResolvedInsertionPoint",
"TensorRTPyBenchmark",
"TorchRegionBuilder",
"TrtExecBenchmark",
]
__all__.extend(name for name in _OPTIONAL_EXPORTS if name in globals())
42 changes: 26 additions & 16 deletions modelopt/onnx/quantization/autotune/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
from modelopt.onnx.quantization.autotune.autotuner_base import QDQAutotunerBase
from modelopt.onnx.quantization.autotune.common import Config, PatternCache, Region, RegionType
from modelopt.onnx.quantization.autotune.region_search import CombinedRegionSearch
from modelopt.onnx.quantization.autotune.torch_region_builder import (
TorchRegionBuilder,
check_torch_naming_convention,
)


class QDQAutotuner(QDQAutotunerBase):
Expand Down Expand Up @@ -94,22 +98,28 @@ def _search_regions(self) -> None:
- Phase 2: Top-down refinement creating hierarchical structure
"""
logger.info("Discovering optimization regions")
search = CombinedRegionSearch(
self.graph,
maximum_sequence_region_size=self.config.maximum_sequence_region_size,
minimum_topdown_search_size=self.config.minimum_topdown_search_size,
)
self.regions = search.search_regions()
self._reassign_region_ids(self.regions)
logger.debug(f"Found {len(self.regions)} top-level regions")

# Flatten the hierarchy into a list of all regions
all_regions = []
for region in self.regions:
all_regions.extend(QDQAutotuner._visit_region_recursively(region))

all_regions.sort(key=lambda r: r.type != RegionType.LEAF)
self.regions = all_regions
if check_torch_naming_convention(self.graph):
torch_search = TorchRegionBuilder(self.graph)
# linearize=True returns leaves + innermost composites, leaves first
self.regions = torch_search.build_regions(linearize=True, only_quantizable=True)
for i, region in enumerate(self.regions):
region.id = i
else:
Comment thread
willg-nv marked this conversation as resolved.
default_search = CombinedRegionSearch(
self.graph,
maximum_sequence_region_size=self.config.maximum_sequence_region_size,
minimum_topdown_search_size=self.config.minimum_topdown_search_size,
)
self.regions = default_search.search_regions()
self._reassign_region_ids(self.regions)

# Flatten the hierarchy into a list of all regions
all_regions = []
for region in self.regions:
all_regions.extend(QDQAutotuner._visit_region_recursively(region))

all_regions.sort(key=lambda r: r.type != RegionType.LEAF)
self.regions = all_regions

type_counts = Counter(r.type for r in self.regions)
logger.info(
Expand Down
Loading