Skip to content

Commit 23278a4

Browse files
authored
Integrate Automated QDQ placement tool - Part 1 (#701)
## What does this PR do? **Type of change:** new feature **Overview:** This PR integrates an automatic QDQ placement tool into ModelOpt. This PR is part 1 of 4 of the change; it contains the following changes: 1. Defines common types: Region, RegionType, Error types 2. Defines InsertionPoints (the logical location to place QDQ pairs), InsertionScheme (a set of insertion points) 3. Unit tests for new types Part 1: #701 Part 2: #702 Part 3: #703 Part 4: #704 ## Usage ```python # Region type usage: region = Region(region_id=1, level=0, region_type=RegionType.LEAF) assert region.get_id() == 1 assert region.get_level() == 0 region.add_node(1) # 1 is the index of ONNX graph node ... point = NodeInputInsertionPoint(node_index=0, input_index=2) assert point.node_index == 0 # relative node index in region assert point.input_index == 2 # relative input tensor index in specific node resolved = point.resolve(region, graph) ... ``` ## Testing Implemented unit tests; all tests pass. ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes - **Did you write any new necessary tests?**: Yes - **Did you add or update any necessary documentation?**: No, documentation changes will be included in part 4. - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: No, this will be done when all parts of the change are merged. ## Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added foundational autotuner infrastructure for quantization optimization, including region hierarchies and insertion scheme management. 
* Introduced insertion point system for managing quantize/dequantize operation placement across ONNX graph regions. * Added utility functions for tensor consumer mapping and boolean operation identification. * **Tests** * Added comprehensive test coverage for autotuner components, insertion points, and region management. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Will Guo <willg@nvidia.com>
1 parent 2a46753 commit 23278a4

File tree

6 files changed

+1995
-1
lines changed

6 files changed

+1995
-1
lines changed

modelopt/onnx/op_types.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def is_fusible_scaling_op(op_type: str):
9696
]
9797

9898

99-
def get_copy_ops():
99+
def get_copy_ops() -> list[str]:
100100
"""Returns list of copy operators."""
101101
return [
102102
"Flatten",
@@ -303,3 +303,67 @@ def is_data_dependent_shape_op(op_type: str):
303303
"NonZero",
304304
"RoiAlign",
305305
]
306+
307+
308+
def get_bool_ops() -> set[str]:
    """Returns the set of boolean logic ONNX operator types.

    Return annotation added for consistency with ``get_copy_ops() -> list[str]``.
    """
    return {
        "Not",
        "And",
        "Or",
        "Xor",
    }
316+
317+
318+
def get_bitwise_ops() -> set[str]:
    """Returns the set of bitwise ONNX operator types.

    Return annotation added for consistency with ``get_copy_ops() -> list[str]``.
    """
    return {
        "BitwiseAnd",
        "BitwiseOr",
        "BitwiseXor",
        "BitShift",
    }
326+
327+
328+
def get_value_check_ops() -> set[str]:
    """Returns the set of value-checking ONNX operator types.

    Return annotation added for consistency with ``get_copy_ops() -> list[str]``.
    """
    return {
        "IsNaN",
        "IsInf",
        "Sign",
        "Abs",
    }
336+
337+
338+
def get_comparison_ops() -> set[str]:
    """Returns the set of comparison ONNX operator types.

    Return annotation added for consistency with ``get_copy_ops() -> list[str]``.
    """
    return {
        "Equal",
        "Greater",
        "GreaterOrEqual",
        "Less",
        "LessOrEqual",
    }
347+
348+
349+
def get_conditional_ops() -> set[str]:
    """Returns the set of conditional (element-select) ONNX operator types.

    Return annotation added for consistency with ``get_copy_ops() -> list[str]``.
    """
    return {
        "Where",
    }
354+
355+
356+
def get_aggregation_ops() -> set[str]:
    """Returns the set of boolean aggregation ONNX operator types.

    Return annotation added for consistency with ``get_copy_ops() -> list[str]``.
    """
    return {
        "All",
        "Any",
    }
362+
363+
364+
def get_set_ops() -> set[str]:
    """Returns the set of set/search ONNX operator types.

    Return annotation added for consistency with ``get_copy_ops() -> list[str]``.
    """
    return {
        "Unique",
        "NonZero",
    }
Lines changed: 317 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,317 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Common data structures and types for the QDQ Autotuner."""
17+
18+
import hashlib
19+
from dataclasses import dataclass, field
20+
from enum import Enum
21+
from typing import Any
22+
23+
from modelopt.onnx.logging_config import logger
24+
from modelopt.onnx.quantization.autotune.insertion_points import (
25+
ChildRegionInputInsertionPoint,
26+
ChildRegionOutputInsertionPoint,
27+
NodeInputInsertionPoint,
28+
)
29+
30+
31+
class AutotunerError(Exception):
    """Base exception for autotuner-related errors.

    Root of the autotuner exception hierarchy: callers can catch every
    autotuner failure with a single ``except AutotunerError`` clause.
    """
33+
34+
35+
class AutotunerNotInitializedError(AutotunerError):
    """Raised when an autotuner operation is attempted before initialization."""
37+
38+
39+
class InvalidSchemeError(AutotunerError):
    """Raised when an invalid or unknown insertion scheme is referenced."""
41+
42+
43+
class RegionType(Enum):
    """Region type enumeration for hierarchical graph structure.

    - LEAF: Atomic region containing direct nodes with no child regions
    - COMPOSITE: Hierarchical region containing child regions (and optionally direct nodes)
    - ROOT: Top-level region encompassing the entire computation graph
    """

    LEAF = "LEAF"  # atomic region: direct nodes only, no children
    COMPOSITE = "COMPOSITE"  # contains child regions (and possibly direct nodes)
    ROOT = "ROOT"  # single top-level region covering the whole graph
54+
55+
56+
class Region:
    """A subgraph region in an ONNX graph, used as the unit for Q/DQ insertion.

    Regions form a hierarchy: ROOT contains the entire graph, COMPOSITE regions
    contain child regions, and LEAF regions contain only nodes. Each region tracks
    its direct nodes, input/output tensors, and a pattern signature for matching
    regions with identical structure.
    """

    def __init__(self, region_id: int, level: int, region_type: RegionType):
        """Initialize a new region.

        Args:
            region_id: Unique identifier within the region hierarchy
            level: Hierarchical level (0 = leaf, higher = more composite)
            region_type: Type classification (LEAF, COMPOSITE, or ROOT)
        """
        self.id = region_id
        self.level = level
        self.type = region_type
        self.parent: Region | None = None
        self.children: list[Region] = []
        # Indices of graph nodes owned directly by this region (children's nodes excluded).
        self.nodes: set[int] = set()
        # Names of tensors crossing the region boundary.
        self.inputs: list[str] = []
        self.outputs: list[str] = []
        self.metadata: dict[str, str] = {}

    def get_children(self, *, sort: bool = False) -> list["Region"]:
        """Get all child regions. If sort is True, sort the children by level and size.

        Args:
            sort: Whether to sort the children by level and size

        Returns:
            List of child regions
        """
        if sort:
            # Deeper (higher-level) children first; ties broken by ascending subtree size.
            return sorted(
                self.children, key=lambda r: (-r.level, r.get_size_of_region_and_descendants())
            )
        return self.children

    def remove_child(self, child: "Region") -> bool:
        """Remove a child region from this region's children list.

        Returns:
            True if the child was removed, False if it was not a child of this region.
        """
        if child not in self.children:
            return False
        self.children.remove(child)
        # Only clear the back-reference if it actually points at this region.
        if child.parent and child.parent.id == self.id:
            child.parent = None
        return True

    def add_child(self, child: "Region") -> None:
        """Add a child sub-region.

        Rejects self-parenting and cycles, re-parents the child if it currently
        belongs to a different region, and ignores duplicate additions.
        """
        if child.id == self.id:
            logger.warning(f"Cannot add region {self.id} as its own child")
            return

        if self.is_descendant_of(child):
            logger.warning(
                f"Cycle detected: region {self.id} is already a descendant of region {child.id}"
            )
            return

        if child.parent is not None and child.parent.id != self.id:
            old_parent_id = child.parent.id
            logger.debug(
                f"Re-parenting region {child.id}: moving from parent {old_parent_id} to {self.id}"
            )
            child.parent.remove_child(child)

        if any(c.id == child.id for c in self.children):
            logger.debug(f"Region {child.id} already child of {self.id}")
            return

        self.children.append(child)
        child.parent = self

    def is_descendant_of(self, potential_ancestor: "Region") -> bool:
        """Check if this region is a descendant of potential_ancestor."""
        visited = set()
        current = self.parent
        while current:
            # Defensive: stop instead of looping forever on a corrupted hierarchy.
            if current.id in visited:
                return False
            visited.add(current.id)
            if current.id == potential_ancestor.id:
                return True
            current = current.parent
        return False

    def get_nodes(self, *, sort: bool = False) -> list[int]:
        """Get direct node indices in this region only."""
        if sort:
            return sorted(self.nodes)
        return list(self.nodes)

    def get_region_nodes_and_descendants(self, _visited: set[int] | None = None) -> set[int]:
        """Get all node indices recursively, including descendants.

        Raises:
            AutotunerError: If a cycle is detected in the region hierarchy.
        """
        if _visited is None:
            _visited = set()

        # Explicit raise rather than assert: assert is stripped under `python -O`,
        # which would turn a hierarchy cycle into unbounded recursion.
        if self.id in _visited:
            raise AutotunerError(f"Cycle detected in region {self.id} during node traversal")

        _visited.add(self.id)
        all_nodes = set(self.nodes)
        for child in self.children:
            all_nodes.update(child.get_region_nodes_and_descendants(_visited))
        return all_nodes

    def contains_node(self, node_index: int) -> bool:
        """Check if region contains a specific node (direct only)."""
        return node_index in self.nodes

    def contains_node_within_region_and_descendants(self, node_index: int) -> bool:
        """Check if region contains a node recursively."""
        return node_index in self.get_region_nodes_and_descendants()

    def get_size_of_region_and_descendants(self, _visited: set[int] | None = None) -> int:
        """Get total node count recursively including all descendants.

        Raises:
            AutotunerError: If a cycle is detected in the region hierarchy.
        """
        if _visited is None:
            _visited = set()

        # Explicit raise rather than assert (see get_region_nodes_and_descendants).
        if self.id in _visited:
            raise AutotunerError(f"Cycle detected in region {self.id} during size calculation")

        _visited.add(self.id)
        total = len(self.nodes)
        for child in self.children:
            total += child.get_size_of_region_and_descendants(_visited)
        return total

    def merge(self, other: "Region") -> None:
        """Merge another region into this one.

        Absorbs the other region's direct nodes and re-parents its children here.
        NOTE(review): inputs/outputs/metadata are not merged — presumably callers
        recompute boundary tensors after merging; confirm.
        """
        if not other:
            return
        self.nodes.update(other.nodes)
        # Iterate over a snapshot: add_child() re-parents each child, which removes it
        # from other.children; iterating the live list would skip every other child.
        for child in list(other.children):
            self.add_child(child)

    def __repr__(self) -> str:
        """Compact debug representation with counts rather than full contents."""
        type_str = self.type.value
        return (
            f"Region[id={self.id}, level={self.level}, type={type_str}, "
            f"nodes={len(self.nodes)}, children={len(self.children)}, "
            f"inputs={len(self.inputs)}, outputs={len(self.outputs)}]"
        )
205+
206+
207+
@dataclass
class InsertionScheme:
    """Complete Q/DQ insertion specification for a region pattern.

    An InsertionScheme defines a complete Q/DQ configuration for a pattern,
    combining both node-level and region-level insertion points. The scheme
    is applied to all regions matching the pattern.
    """

    node_inputs: list[NodeInputInsertionPoint] = field(default_factory=list)
    child_region_inputs: list[ChildRegionInputInsertionPoint] = field(default_factory=list)
    region_outputs: list[ChildRegionOutputInsertionPoint] = field(default_factory=list)
    latency_ms: float = float("inf")
    error: bool = False
    profile_timestamp: str | None = None

    @property
    def hash(self) -> str:
        """Compute deterministic hash for scheme identity.

        The hash uniquely identifies this scheme configuration based on its
        insertion points. Two schemes with identical insertion points produce
        the same hash, regardless of their measured latencies.
        """
        # Sort each point set so the digest is order-independent.
        node_keys = sorted((pt.node_index, pt.input_index) for pt in self.node_inputs)
        region_keys = sorted((pt.region_index, pt.input_index) for pt in self.child_region_inputs)
        output_keys = sorted(
            (pt.region_index, pt.node_index, pt.output_index) for pt in self.region_outputs
        )
        digest = hashlib.sha256(f"{node_keys}|{region_keys}|{output_keys}".encode("utf-8"))
        return digest.hexdigest()[:32]

    @property
    def is_empty(self) -> bool:
        """Check if this is a baseline scheme with no Q/DQ insertions."""
        return not (self.node_inputs or self.child_region_inputs or self.region_outputs)

    @property
    def is_profiled(self) -> bool:
        """Check if this scheme has been profiled (measured).

        A scheme is considered profiled if it has been measured (has non-infinite latency)
        or has encountered an error during measurement.
        """
        if self.error:
            return True
        return self.latency_ms != float("inf")

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for serialization."""
        result: dict[str, Any] = {
            "latency_ms": self.latency_ms,
            "error": self.error,
            "profile_timestamp": self.profile_timestamp,
        }
        result["nodes_insertion_points"] = [pt.to_dict() for pt in self.node_inputs]
        result["child_region_inputs"] = [pt.to_dict() for pt in self.child_region_inputs]
        result["region_outputs"] = [pt.to_dict() for pt in self.region_outputs]
        result["hash"] = self.hash
        return result

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "InsertionScheme":
        """Create InsertionScheme from serialized dictionary."""
        return cls(
            node_inputs=[
                NodeInputInsertionPoint.from_dict(pt)
                for pt in data.get("nodes_insertion_points", [])
            ],
            child_region_inputs=[
                ChildRegionInputInsertionPoint.from_dict(pt)
                for pt in data.get("child_region_inputs", [])
            ],
            region_outputs=[
                ChildRegionOutputInsertionPoint.from_dict(pt)
                for pt in data.get("region_outputs", [])
            ],
            latency_ms=data.get("latency_ms", float("inf")),
            error=data.get("error", False),
            profile_timestamp=data.get("profile_timestamp"),
        )

    def distance(self, other: "InsertionScheme") -> int:
        """Compute edit distance between this scheme and another scheme.

        The edit distance is the minimum number of add/remove operations needed
        to transform this scheme into the other scheme. This is computed as the
        symmetric difference between the insertion point sets.

        Args:
            other: InsertionScheme to compare against

        Returns:
            Total edit distance (number of add + remove operations)
        """
        total = 0
        for mine, theirs in (
            (self.node_inputs, other.node_inputs),
            (self.child_region_inputs, other.child_region_inputs),
            (self.region_outputs, other.region_outputs),
        ):
            # ^ on sets is the symmetric difference: points present in exactly one scheme.
            total += len(set(mine) ^ set(theirs))
        return total

    def __str__(self) -> str:
        """String representation for debugging."""
        suffix = ", error=True" if self.error else ""
        return (
            f"InsertionScheme(node_insertions={len(self.node_inputs)}, "
            f"region_insertions={len(self.child_region_inputs)}, "
            f"region_output_insertions={len(self.region_outputs)}, "
            f"latency={self.latency_ms:.3f}ms{suffix})"
        )

0 commit comments

Comments
 (0)