Skip to content

Commit f67d957

Browse files
committed
resolve comments
Signed-off-by: Will Guo <willg@nvidia.com>
1 parent 54e3ddc commit f67d957

2 files changed

Lines changed: 145 additions & 197 deletions

File tree

modelopt/onnx/quantization/autotune/region_search.py

Lines changed: 95 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
"""Hierarchical region discovery and partitioning for ONNX graphs."""
1717

1818
import sys
19-
from collections import deque
19+
from collections import defaultdict, deque
2020

2121
import onnx_graphsurgeon as gs
2222

@@ -190,6 +190,74 @@ def _build_forward_reachable_nodes_map(self, max_steps: int) -> dict[int, dict[i
190190
logger.debug(f"Reachability map complete: avg {avg_reachable:.1f} reachable nodes per node")
191191
return forward_reachable_nodes_map
192192

193+
def _find_common_reachable_nodes(
    self, node_idx: int, branches: list[int]
) -> tuple[list[dict], set[int]]:
    """Find nodes reachable from every branch (potential convergence points).

    Used as STEP 1 of convergence detection in _find_converge_nodes.

    Args:
        node_idx: Index of the divergent node (excluded from common_nodes).
        branches: List of branch head node indices.

    Returns:
        A ``(branch_reachable, common_nodes)`` pair. ``branch_reachable`` holds one
        reachability dict per branch, looked up in ``self.forward_reachable_nodes_map``
        (an empty dict for a branch that has no entry). ``common_nodes`` is the set of
        node indices present in every one of those dicts, minus ``node_idx``.

        When there are no branches, or no convergence candidate remains, the result
        is ``([], set())`` -- never ``None`` -- so callers can unpack it unconditionally
        and test either element for emptiness.
    """
    branch_reachable = [self.forward_reachable_nodes_map.get(b, {}) for b in branches]

    # Guard the empty case: set.intersection() needs at least one operand.
    if not branch_reachable:
        logger.debug(" No reachable nodes from branches")
        return [], set()

    common_nodes = set.intersection(*[set(r.keys()) for r in branch_reachable])
    logger.debug(f" {len(common_nodes)} common nodes found")
    # The divergent node itself is not a convergence point.
    common_nodes.discard(node_idx)

    if not common_nodes:
        logger.debug(" No valid convergence candidates")
        return [], set()

    return branch_reachable, common_nodes
222+
223+
def _evaluate_convergence_candidate(
    self,
    candidate_idx: int,
    reachable_from_start: dict,
    branch_reachable: list,
) -> tuple[bool, int]:
    """Check if a candidate convergence node forms a valid region and return its max distance.

    The candidate region is every node reachable from the divergent node that is NOT
    reachable from the candidate. The region is rejected when some region node reaches
    a node outside the region while the candidate itself is missing from that region
    node's reachability map (an "escaping" path that bypasses the candidate).

    Args:
        candidate_idx: Candidate convergence node index.
        reachable_from_start: Forward reachability from the divergent node.
        branch_reachable: Per-branch reachability dicts (for max distance). Each dict
            must contain ``candidate_idx``; a missing key raises ``KeyError``.

    Returns:
        (is_valid, max_distance). max_distance is the largest value stored for
        ``candidate_idx`` across ``branch_reachable``; it is only meaningful when
        is_valid is True (it is 0 when is_valid is False).
    """
    unreachable = float("inf")

    reachable_from_candidate = self.forward_reachable_nodes_map.get(candidate_idx, {})
    region_nodes: set[int] = set(reachable_from_start).difference(reachable_from_candidate)

    for rnode_index in region_nodes:
        reachable_from_rnode = self.forward_reachable_nodes_map.get(rnode_index, {})
        # Loop-invariant for the inner scan: can this region node reach the candidate?
        candidate_missing = reachable_from_rnode.get(candidate_idx, unreachable) == unreachable
        for test_node_idx, rnode_to_test_distance in reachable_from_rnode.items():
            # Nodes inside the region are fine; only outside nodes can be "escapes".
            if test_node_idx in region_nodes:
                continue
            # An outside node is reached but one of the two distances is infinite:
            # the region is broken.
            if candidate_missing or rnode_to_test_distance == unreachable:
                return False, 0

    max_distance = max(reachable[candidate_idx] for reachable in branch_reachable)
    return True, max_distance
260+
193261
def _find_converge_nodes(self, node_idx: int) -> tuple[int | None, set[int]]:
194262
"""Find convergence point and intermediate nodes for a divergent node.
195263
@@ -216,67 +284,30 @@ def _find_converge_nodes(self, node_idx: int) -> tuple[int | None, set[int]]:
216284

217285
logger.debug(f" {len(branches)} unique branches found")
218286

219-
# Need at least 2 branches for convergence to be meaningful
220287
if len(branches) <= 1:
221288
logger.debug(" Insufficient branches for convergence")
222289
return None, set()
223290

224-
# STEP 1: Find Common Reachable Nodes (Potential Convergence Points)
225-
branch_reachable = [self.forward_reachable_nodes_map.get(b, {}) for b in branches]
226-
227-
if not branch_reachable:
228-
logger.debug(" No reachable nodes from branches")
291+
branch_reachable, common_nodes = self._find_common_reachable_nodes(node_idx, branches)
292+
if not branch_reachable or not common_nodes:
229293
return None, set()
230294

231-
common_nodes = set.intersection(*[set(r.keys()) for r in branch_reachable])
232-
logger.debug(f" {len(common_nodes)} common nodes found")
233-
# Remove the divergent node itself (not a convergence point)
234-
common_nodes.discard(node_idx)
235-
236-
if not common_nodes:
237-
logger.debug(" No valid convergence candidates")
238-
return None, set()
239-
240-
# STEP 2: Select Best Convergence Node with Region Validity Check
295+
# Select Best Convergence Node with Region Validity Check
241296
converge_node_idx: int | None = None
242297
min_max_distance = float("inf")
243298

244299
reachable_from_start = self.forward_reachable_nodes_map.get(node_idx, {})
245300

246-
# Evaluate each candidate convergence point
247301
for candidate_idx in common_nodes:
248-
# Define the potential region: nodes between start and candidate
249-
region_nodes: set[int] = reachable_from_start.keys()
250-
reachable_from_candidate = self.forward_reachable_nodes_map.get(candidate_idx, {})
251-
region_nodes = region_nodes - reachable_from_candidate.keys()
252-
253-
valid = True
254-
for rnode_index in region_nodes:
255-
reachable_from_rnode = self.forward_reachable_nodes_map.get(rnode_index, {})
256-
rnode_to_candidate_distance = reachable_from_rnode.get(candidate_idx, float("inf"))
257-
for test_node_idx in reachable_from_rnode:
258-
# Skip nodes that are inside the region (they're fine)
259-
if test_node_idx in region_nodes:
260-
continue
261-
# test_node is OUTSIDE the region. Check if it's "escaping"
262-
# An escaping edge: region_node reaches test_node BEFORE candidate
263-
rnode_to_test_distance = reachable_from_rnode.get(test_node_idx, float("inf"))
264-
# If either distance is infinite, region is broken
265-
# (indicates disconnected components or unreachable convergence)
266-
if any(
267-
d == float("inf")
268-
for d in (rnode_to_test_distance, rnode_to_candidate_distance)
269-
):
270-
valid = False
271-
break
272-
if not valid:
273-
break
302+
valid, max_distance = self._evaluate_convergence_candidate(
303+
candidate_idx, reachable_from_start, branch_reachable
304+
)
274305
if not valid:
275306
continue
276-
max_distance = max(reachable[candidate_idx] for reachable in branch_reachable)
277307
if max_distance < min_max_distance:
278308
min_max_distance = max_distance
279309
converge_node_idx = candidate_idx
310+
280311
# If no valid convergence found, this divergence has no convergence
281312
if converge_node_idx is None:
282313
logger.debug(" No valid convergence found")
@@ -286,7 +317,8 @@ def _find_converge_nodes(self, node_idx: int) -> tuple[int | None, set[int]]:
286317
logger.debug(
287318
f" Convergence at node {converge_node_idx} ({converge_node.op}), distance {min_max_distance}"
288319
)
289-
# STEP 3: Compute All Nodes Between Divergence and Convergence
320+
321+
# Compute All Nodes Between Divergence and Convergence
290322
visited_nodes: set[int] = set()
291323
for candidate_idx in reachable_from_start:
292324
if candidate_idx == converge_node_idx:
@@ -604,27 +636,12 @@ def _build_small_converged_region(
604636
def _build_region_from_node(self, node_idx: int):
605637
"""Process a single node and create appropriate region(s) based on its pattern.
606638
607-
This is the core dispatch method that determines how to handle each node
608-
based on whether it's divergent (branches) or sequential. Implements the
609-
three pattern recognition strategies described in the class documentation.
639+
This is the core dispatch method that determines how to handle each node based on whether
640+
it's divergent (branches) or sequential.
610641
611-
**Pattern 1: Divergent with Convergence (Ideal Case)**
612-
Creates a complete "funnel" region capturing parallel branches:
613-
- Example: ResNet skip connection (Conv branch + identity → Add)
614-
- Condition: converge_node found AND distance < DEFAULT_MAX_STEPS
615-
- Result: One region containing all nodes between divergence and convergence
616-
617-
**Pattern 2: Divergent without Convergence (Boundary Case)**
618-
Creates a single-node "orphan" region:
619-
- Example: Final layer that branches to multiple outputs
620-
- Condition: No convergence found OR convergence too far away
621-
- Result: Region containing only the divergent node
622-
623-
**Pattern 3: Sequential Chain (Common Case)**
624-
Creates a region containing linear sequence:
625-
- Example: Conv → BN → ReLU → MaxPool
626-
- Condition: Node is not divergent
627-
- Result: Region containing the full non-divergent chain
642+
- Pattern 1: Divergent with Convergence (Ideal Case)
643+
- Pattern 2: Divergent without Convergence (Boundary Case)
644+
- Pattern 3: Sequential Chain (Common Case)
628645
629646
Args:
630647
node_idx: Index of node to process
@@ -790,12 +807,10 @@ def _build_region_usage_map(self, regions: list[Region]) -> dict[str, list[Regio
790807
Returns:
791808
Mapping from tensor names to regions that consume them
792809
"""
793-
region_usage_map: dict[str, list[Region]] = {}
810+
region_usage_map: dict[str, list[Region]] = defaultdict(list)
794811
for region in regions:
795-
for tensor_name in region.inputs:
796-
if tensor_name not in region_usage_map:
797-
region_usage_map[tensor_name] = []
798-
region_usage_map[tensor_name].append(region)
812+
for input_tensor in region.inputs:
813+
region_usage_map[input_tensor].append(region)
799814
return region_usage_map
800815

801816
def _split_sequence_regions(self, root: Region) -> list[Region]:
@@ -954,29 +969,30 @@ def _merge_converged_regions(self, root: Region):
954969
def build_composite_region(self) -> Region:
955970
"""Refine a flat region into a hierarchical COMPOSITE region."""
956971
# merge converged regions into composite regions
957-
self.regions = self._merge_converged_regions(self.root)
972+
regions = self._merge_converged_regions(self.root)
958973
# split sequence regions into smaller regions
959974
result_regions: list[Region] = []
960-
for region in self.regions:
975+
for region in regions:
961976
result_regions.extend(self._split_sequence_regions(region))
962977
for region in result_regions:
963978
self.compute_region_boundaries(region, include_constant=True)
964-
self.regions = result_regions
979+
regions = result_regions
965980
# merge all regions into a single composite region
966-
if len(self.regions) > 1:
981+
if len(regions) > 1:
967982
composite = Region(
968983
region_id=self.next_region_id,
969984
level=self.root.level,
970985
region_type=RegionType.COMPOSITE,
971986
)
972987
self.next_region_id += 1
973-
self.regions = sorted(
974-
self.regions, key=lambda x: RegionPattern.from_region(x, self.graph).signature
988+
regions = sorted(
989+
regions, key=lambda x: RegionPattern.from_region(x, self.graph).signature
975990
)
976-
for region in self.regions:
991+
for region in regions:
977992
composite.add_child(region)
978993
self.compute_region_boundaries(composite)
979-
self.regions = [composite]
994+
regions = [composite]
995+
self.regions = regions
980996
return self.regions[0]
981997

982998

0 commit comments

Comments
 (0)