1616"""Hierarchical region discovery and partitioning for ONNX graphs."""
1717
1818import sys
19- from collections import deque
19+ from collections import defaultdict , deque
2020
2121import onnx_graphsurgeon as gs
2222
@@ -190,6 +190,74 @@ def _build_forward_reachable_nodes_map(self, max_steps: int) -> dict[int, dict[i
190190 logger .debug (f"Reachability map complete: avg { avg_reachable :.1f} reachable nodes per node" )
191191 return forward_reachable_nodes_map
192192
193+ def _find_common_reachable_nodes (
194+ self , node_idx : int , branches : list [int ]
195+ ) -> tuple [list [dict ], set [int ]] | None :
196+ """Find common reachable nodes from all branches (potential convergence points).
197+
198+ Used as STEP 1 of convergence detection in _find_converge_nodes.
199+
200+ Args:
201+ node_idx: Index of the divergent node (excluded from common_nodes).
202+ branches: List of branch head node indices.
203+
204+ Returns:
205+ (branch_reachable, common_nodes) if valid; None if no convergence candidates.
206+ """
207+ branch_reachable = [self .forward_reachable_nodes_map .get (b , {}) for b in branches ]
208+
209+ if not branch_reachable :
210+ logger .debug (" No reachable nodes from branches" )
211+ return [], set ()
212+
213+ common_nodes = set .intersection (* [set (r .keys ()) for r in branch_reachable ])
214+ logger .debug (f" { len (common_nodes )} common nodes found" )
215+ common_nodes .discard (node_idx )
216+
217+ if not common_nodes :
218+ logger .debug (" No valid convergence candidates" )
219+ return [], set ()
220+
221+ return branch_reachable , common_nodes
222+
223+ def _evaluate_convergence_candidate (
224+ self ,
225+ candidate_idx : int ,
226+ reachable_from_start : dict ,
227+ branch_reachable : list ,
228+ ) -> tuple [bool , int ]:
229+ r"""Check if a candidate convergence node forms a valid region and return its max distance.
230+
231+ A valid region has no \"escaping\" edges: no node inside the region may reach a node
232+ outside the region before reaching the candidate convergence point.
233+
234+ Args:
235+ candidate_idx: Candidate convergence node index.
236+ reachable_from_start: Forward reachability from the divergent node.
237+ branch_reachable: Per-branch reachability dicts (for max distance).
238+
239+ Returns:
240+ (is_valid, max_distance). max_distance is only meaningful when is_valid is True.
241+ """
242+ region_nodes : set [int ] = set (reachable_from_start .keys ())
243+ reachable_from_candidate = self .forward_reachable_nodes_map .get (candidate_idx , {})
244+ region_nodes = region_nodes - set (reachable_from_candidate .keys ())
245+
246+ for rnode_index in region_nodes :
247+ reachable_from_rnode = self .forward_reachable_nodes_map .get (rnode_index , {})
248+ rnode_to_candidate_distance = reachable_from_rnode .get (candidate_idx , float ("inf" ))
249+ for test_node_idx in reachable_from_rnode :
250+ if test_node_idx in region_nodes :
251+ continue
252+ rnode_to_test_distance = reachable_from_rnode .get (test_node_idx , float ("inf" ))
253+ if any (
254+ d == float ("inf" ) for d in (rnode_to_test_distance , rnode_to_candidate_distance )
255+ ):
256+ return False , 0
257+
258+ max_distance = max (reachable [candidate_idx ] for reachable in branch_reachable )
259+ return True , max_distance
260+
193261 def _find_converge_nodes (self , node_idx : int ) -> tuple [int | None , set [int ]]:
194262 """Find convergence point and intermediate nodes for a divergent node.
195263
@@ -216,67 +284,30 @@ def _find_converge_nodes(self, node_idx: int) -> tuple[int | None, set[int]]:
216284
217285 logger .debug (f" { len (branches )} unique branches found" )
218286
219- # Need at least 2 branches for convergence to be meaningful
220287 if len (branches ) <= 1 :
221288 logger .debug (" Insufficient branches for convergence" )
222289 return None , set ()
223290
224- # STEP 1: Find Common Reachable Nodes (Potential Convergence Points)
225- branch_reachable = [self .forward_reachable_nodes_map .get (b , {}) for b in branches ]
226-
227- if not branch_reachable :
228- logger .debug (" No reachable nodes from branches" )
291+ branch_reachable , common_nodes = self ._find_common_reachable_nodes (node_idx , branches )
292+ if not branch_reachable or not common_nodes :
229293 return None , set ()
230294
231- common_nodes = set .intersection (* [set (r .keys ()) for r in branch_reachable ])
232- logger .debug (f" { len (common_nodes )} common nodes found" )
233- # Remove the divergent node itself (not a convergence point)
234- common_nodes .discard (node_idx )
235-
236- if not common_nodes :
237- logger .debug (" No valid convergence candidates" )
238- return None , set ()
239-
240- # STEP 2: Select Best Convergence Node with Region Validity Check
295+ # Select Best Convergence Node with Region Validity Check
241296 converge_node_idx : int | None = None
242297 min_max_distance = float ("inf" )
243298
244299 reachable_from_start = self .forward_reachable_nodes_map .get (node_idx , {})
245300
246- # Evaluate each candidate convergence point
247301 for candidate_idx in common_nodes :
248- # Define the potential region: nodes between start and candidate
249- region_nodes : set [int ] = reachable_from_start .keys ()
250- reachable_from_candidate = self .forward_reachable_nodes_map .get (candidate_idx , {})
251- region_nodes = region_nodes - reachable_from_candidate .keys ()
252-
253- valid = True
254- for rnode_index in region_nodes :
255- reachable_from_rnode = self .forward_reachable_nodes_map .get (rnode_index , {})
256- rnode_to_candidate_distance = reachable_from_rnode .get (candidate_idx , float ("inf" ))
257- for test_node_idx in reachable_from_rnode :
258- # Skip nodes that are inside the region (they're fine)
259- if test_node_idx in region_nodes :
260- continue
261- # test_node is OUTSIDE the region. Check if it's "escaping"
262- # An escaping edge: region_node reaches test_node BEFORE candidate
263- rnode_to_test_distance = reachable_from_rnode .get (test_node_idx , float ("inf" ))
264- # If either distance is infinite, region is broken
265- # (indicates disconnected components or unreachable convergence)
266- if any (
267- d == float ("inf" )
268- for d in (rnode_to_test_distance , rnode_to_candidate_distance )
269- ):
270- valid = False
271- break
272- if not valid :
273- break
302+ valid , max_distance = self ._evaluate_convergence_candidate (
303+ candidate_idx , reachable_from_start , branch_reachable
304+ )
274305 if not valid :
275306 continue
276- max_distance = max (reachable [candidate_idx ] for reachable in branch_reachable )
277307 if max_distance < min_max_distance :
278308 min_max_distance = max_distance
279309 converge_node_idx = candidate_idx
310+
280311 # If no valid convergence found, this divergence has no convergence
281312 if converge_node_idx is None :
282313 logger .debug (" No valid convergence found" )
@@ -286,7 +317,8 @@ def _find_converge_nodes(self, node_idx: int) -> tuple[int | None, set[int]]:
286317 logger .debug (
287318 f" Convergence at node { converge_node_idx } ({ converge_node .op } ), distance { min_max_distance } "
288319 )
289- # STEP 3: Compute All Nodes Between Divergence and Convergence
320+
321+ # Compute All Nodes Between Divergence and Convergence
290322 visited_nodes : set [int ] = set ()
291323 for candidate_idx in reachable_from_start :
292324 if candidate_idx == converge_node_idx :
@@ -604,27 +636,12 @@ def _build_small_converged_region(
604636 def _build_region_from_node (self , node_idx : int ):
605637 """Process a single node and create appropriate region(s) based on its pattern.
606638
607- This is the core dispatch method that determines how to handle each node
608- based on whether it's divergent (branches) or sequential. Implements the
609- three pattern recognition strategies described in the class documentation.
639+ This is the core dispatch method that determines how to handle each node based on whether
640+ it's divergent (branches) or sequential.
610641
611- **Pattern 1: Divergent with Convergence (Ideal Case)**
612- Creates a complete "funnel" region capturing parallel branches:
613- - Example: ResNet skip connection (Conv branch + identity → Add)
614- - Condition: converge_node found AND distance < DEFAULT_MAX_STEPS
615- - Result: One region containing all nodes between divergence and convergence
616-
617- **Pattern 2: Divergent without Convergence (Boundary Case)**
618- Creates a single-node "orphan" region:
619- - Example: Final layer that branches to multiple outputs
620- - Condition: No convergence found OR convergence too far away
621- - Result: Region containing only the divergent node
622-
623- **Pattern 3: Sequential Chain (Common Case)**
624- Creates a region containing linear sequence:
625- - Example: Conv → BN → ReLU → MaxPool
626- - Condition: Node is not divergent
627- - Result: Region containing the full non-divergent chain
642+ - Pattern 1: Divergent with Convergence (Ideal Case)
643+ - Pattern 2: Divergent without Convergence (Boundary Case)
644+ - Pattern 3: Sequential Chain (Common Case)
628645
629646 Args:
630647 node_idx: Index of node to process
@@ -790,12 +807,10 @@ def _build_region_usage_map(self, regions: list[Region]) -> dict[str, list[Regio
790807 Returns:
791808 Mapping from tensor names to regions that consume them
792809 """
793- region_usage_map : dict [str , list [Region ]] = {}
810+ region_usage_map : dict [str , list [Region ]] = defaultdict ( list )
794811 for region in regions :
795- for tensor_name in region .inputs :
796- if tensor_name not in region_usage_map :
797- region_usage_map [tensor_name ] = []
798- region_usage_map [tensor_name ].append (region )
812+ for input_tensor in region .inputs :
813+ region_usage_map [input_tensor ].append (region )
799814 return region_usage_map
800815
801816 def _split_sequence_regions (self , root : Region ) -> list [Region ]:
@@ -954,29 +969,30 @@ def _merge_converged_regions(self, root: Region):
954969 def build_composite_region (self ) -> Region :
955970 """Refine a flat region into a hierarchical COMPOSITE region."""
956971 # merge converged regions into composite regions
957- self . regions = self ._merge_converged_regions (self .root )
972+ regions = self ._merge_converged_regions (self .root )
958973 # split sequence regions into smaller regions
959974 result_regions : list [Region ] = []
960- for region in self . regions :
975+ for region in regions :
961976 result_regions .extend (self ._split_sequence_regions (region ))
962977 for region in result_regions :
963978 self .compute_region_boundaries (region , include_constant = True )
964- self . regions = result_regions
979+ regions = result_regions
965980 # merge all regions into a single composite region
966- if len (self . regions ) > 1 :
981+ if len (regions ) > 1 :
967982 composite = Region (
968983 region_id = self .next_region_id ,
969984 level = self .root .level ,
970985 region_type = RegionType .COMPOSITE ,
971986 )
972987 self .next_region_id += 1
973- self . regions = sorted (
974- self . regions , key = lambda x : RegionPattern .from_region (x , self .graph ).signature
988+ regions = sorted (
989+ regions , key = lambda x : RegionPattern .from_region (x , self .graph ).signature
975990 )
976- for region in self . regions :
991+ for region in regions :
977992 composite .add_child (region )
978993 self .compute_region_boundaries (composite )
979- self .regions = [composite ]
994+ regions = [composite ]
995+ self .regions = regions
980996 return self .regions [0 ]
981997
982998
0 commit comments