@@ -156,11 +156,8 @@ def _compute_forward_reachable_nodes(
156156 current_node_idx , distance = queue .popleft ()
157157 if distance >= max_steps :
158158 continue
159- current_node = self .graph .nodes [current_node_idx ]
160- for output in current_node .outputs :
161- if output .name not in self .tensor_users_map :
162- continue
163- for next_node_idx in self .tensor_users_map [output .name ]:
159+ for output in self .graph .nodes [current_node_idx ].outputs :
160+ for next_node_idx in self .tensor_users_map .get (output .name , ()):
164161 if next_node_idx not in reachable :
165162 reachable [next_node_idx ] = distance + 1
166163 queue .append ((next_node_idx , distance + 1 ))
@@ -213,8 +210,7 @@ def _find_converge_nodes(self, node_idx: int) -> tuple[int | None, set[int]]:
213210
214211 branches : list [int ] = []
215212 for output in node .outputs :
216- if output .name in self .tensor_users_map :
217- branches .extend (self .tensor_users_map [output .name ])
213+ branches .extend (self .tensor_users_map .get (output .name , []))
218214
219215 branches = list (dict .fromkeys (branches ))
220216
@@ -250,12 +246,11 @@ def _find_converge_nodes(self, node_idx: int) -> tuple[int | None, set[int]]:
250246 # Evaluate each candidate convergence point
251247 for candidate_idx in common_nodes :
252248 # Define the potential region: nodes between start and candidate
253- region_nodes : set [int ] = set ()
254- region_nodes .update (set (reachable_from_start .keys ()))
249+ region_nodes : set [int ] = reachable_from_start .keys ()
255250 reachable_from_candidate = self .forward_reachable_nodes_map .get (candidate_idx , {})
256- region_nodes . difference_update ( set ( reachable_from_candidate .keys ()) )
251+ region_nodes = region_nodes - reachable_from_candidate .keys ()
257252
258- broken_region = False
253+ valid = True
259254 for rnode_index in region_nodes :
260255 reachable_from_rnode = self .forward_reachable_nodes_map .get (rnode_index , {})
261256 rnode_to_candidate_distance = reachable_from_rnode .get (candidate_idx , float ("inf" ))
@@ -268,24 +263,17 @@ def _find_converge_nodes(self, node_idx: int) -> tuple[int | None, set[int]]:
268263 rnode_to_test_distance = reachable_from_rnode .get (test_node_idx , float ("inf" ))
269264 # If either distance is infinite, region is broken
270265 # (indicates disconnected components or unreachable convergence)
271- if rnode_to_test_distance == float (
272- "inf"
273- ) or rnode_to_candidate_distance == float ("inf" ):
274- broken_region = True
275- break
276- # If test_node is closer than candidate, we have an escape!
277- # This means computation flows OUT of region before converging
278- if rnode_to_test_distance < rnode_to_candidate_distance :
279- broken_region = True
266+ if any (
267+ d == float ("inf" )
268+ for d in (rnode_to_test_distance , rnode_to_candidate_distance )
269+ ):
270+ valid = False
280271 break
281- if broken_region :
272+ if not valid :
282273 break
283- # Skip this candidate if region is invalid
284- if broken_region :
274+ if not valid :
285275 continue
286- # Valid candidate! Check if it's the nearest one
287276 max_distance = max (reachable [candidate_idx ] for reachable in branch_reachable )
288-
289277 if max_distance < min_max_distance :
290278 min_max_distance = max_distance
291279 converge_node_idx = candidate_idx
@@ -384,57 +372,41 @@ def print_tree(
384372 max_items : int = DEFAULT_MAX_NODES_TO_SHOW ,
385373 file = None ,
386374 ) -> None :
387- """Print hierarchical region tree in human-readable text format.
388-
389- Recursively prints the region hierarchy with indentation showing depth.
390- For each region, displays:
391- - ID, level, and type (LEAF/COMPOSITE/ROOT)
392- - Node counts (direct and recursive)
393- - I/O tensor counts
394- - Sample of nodes in the region (up to max_nodes_to_show)
395- - Child regions (recursively)
396- """
375+ """Print hierarchical region tree in human-readable text format."""
397376 region = region or self .root
398-
399377 file = file or sys .stdout
400378 p = " " * indent
401379
402- def print_items (items , label , formatter = str ):
403- """Print a truncated list of items ."""
380+ def truncated (items , fmt = str ):
381+ """Yield formatted items, truncating with count if needed ."""
404382 items = list (items )
405- print (f"{ p } │ ├─ { label } : { len (items )} " , file = file )
406- for item in items [:max_items ]:
407- print (f"{ p } │ │ - { formatter (item )} " , file = file )
383+ yield from (fmt (x ) for x in items [:max_items ])
408384 if len (items ) > max_items :
409- print ( f" { p } │ │ ... and { len (items ) - max_items } more", file = file )
385+ yield f" ... and { len (items ) - max_items } more"
410386
411- # Header
412- print (
413- f"{ p } ├─ Region { region .id } (Level { region .level } , Type: { region .type .value } )" ,
414- file = file ,
415- )
416-
417- # Counts
418387 direct_nodes = region .get_nodes ()
419388 children = region .get_children ()
389+ # Header + counts
390+ print (
391+ f"{ p } ├─ Region { region .id } (Level { region .level } , Type: { region .type .value } )" , file = file
392+ )
420393 print (f"{ p } │ ├─ Direct nodes: { len (direct_nodes )} " , file = file )
421394 print (f"{ p } │ ├─ Total nodes: { len (region .get_region_nodes_and_descendants ())} " , file = file )
422395 print (f"{ p } │ ├─ Children: { len (children )} " , file = file )
423-
424396 # I/O
425- print_items (region .inputs , "Inputs" )
426- print_items (region .outputs , "Outputs" )
427-
397+ for label , items in [("Inputs" , region .inputs ), ("Outputs" , region .outputs )]:
398+ print (f"{ p } │ ├─ { label } : { len (items )} " , file = file )
399+ for line in truncated (items ):
400+ print (f"{ p } │ │ - { line } " , file = file )
428401 # Direct nodes
429402 if direct_nodes :
430403 print (f"{ p } │\n { p } │ Nodes in this region:" , file = file )
431- for node_idx in sorted (direct_nodes )[:max_items ]:
432- if node_idx < len (self .graph .nodes ):
433- node = self .graph .nodes [node_idx ]
434- print (f"{ p } │ - Node { node_idx } : { node .op } ({ node .name } )" , file = file )
435- if len (direct_nodes ) > max_items :
436- print (f"{ p } │ ... and { len (direct_nodes ) - max_items } more" , file = file )
437404
405+ def node_fmt (i : int ) -> str :
406+ return f"Node { i } : { self .graph .nodes [i ].op } ({ self .graph .nodes [i ].name } )"
407+
408+ for line in truncated (sorted (direct_nodes ), node_fmt ):
409+ print (f"{ p } │ - { line } " , file = file )
438410 # Children
439411 if children :
440412 print (f"{ p } │\n { p } │ Child regions:" , file = file )
0 commit comments