Skip to content

Commit 54e3ddc

Browse files
committed
resolve comments
Signed-off-by: Will Guo <willg@nvidia.com>
1 parent f676401 commit 54e3ddc

1 file changed

Lines changed: 31 additions & 59 deletions

File tree

modelopt/onnx/quantization/autotune/region_search.py

Lines changed: 31 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,8 @@ def _compute_forward_reachable_nodes(
156156
current_node_idx, distance = queue.popleft()
157157
if distance >= max_steps:
158158
continue
159-
current_node = self.graph.nodes[current_node_idx]
160-
for output in current_node.outputs:
161-
if output.name not in self.tensor_users_map:
162-
continue
163-
for next_node_idx in self.tensor_users_map[output.name]:
159+
for output in self.graph.nodes[current_node_idx].outputs:
160+
for next_node_idx in self.tensor_users_map.get(output.name, ()):
164161
if next_node_idx not in reachable:
165162
reachable[next_node_idx] = distance + 1
166163
queue.append((next_node_idx, distance + 1))
@@ -213,8 +210,7 @@ def _find_converge_nodes(self, node_idx: int) -> tuple[int | None, set[int]]:
213210

214211
branches: list[int] = []
215212
for output in node.outputs:
216-
if output.name in self.tensor_users_map:
217-
branches.extend(self.tensor_users_map[output.name])
213+
branches.extend(self.tensor_users_map.get(output.name, []))
218214

219215
branches = list(dict.fromkeys(branches))
220216

@@ -250,12 +246,11 @@ def _find_converge_nodes(self, node_idx: int) -> tuple[int | None, set[int]]:
250246
# Evaluate each candidate convergence point
251247
for candidate_idx in common_nodes:
252248
# Define the potential region: nodes between start and candidate
253-
region_nodes: set[int] = set()
254-
region_nodes.update(set(reachable_from_start.keys()))
249+
region_nodes: set[int] = reachable_from_start.keys()
255250
reachable_from_candidate = self.forward_reachable_nodes_map.get(candidate_idx, {})
256-
region_nodes.difference_update(set(reachable_from_candidate.keys()))
251+
region_nodes = region_nodes - reachable_from_candidate.keys()
257252

258-
broken_region = False
253+
valid = True
259254
for rnode_index in region_nodes:
260255
reachable_from_rnode = self.forward_reachable_nodes_map.get(rnode_index, {})
261256
rnode_to_candidate_distance = reachable_from_rnode.get(candidate_idx, float("inf"))
@@ -268,24 +263,17 @@ def _find_converge_nodes(self, node_idx: int) -> tuple[int | None, set[int]]:
268263
rnode_to_test_distance = reachable_from_rnode.get(test_node_idx, float("inf"))
269264
# If either distance is infinite, region is broken
270265
# (indicates disconnected components or unreachable convergence)
271-
if rnode_to_test_distance == float(
272-
"inf"
273-
) or rnode_to_candidate_distance == float("inf"):
274-
broken_region = True
275-
break
276-
# If test_node is closer than candidate, we have an escape!
277-
# This means computation flows OUT of region before converging
278-
if rnode_to_test_distance < rnode_to_candidate_distance:
279-
broken_region = True
266+
if any(
267+
d == float("inf")
268+
for d in (rnode_to_test_distance, rnode_to_candidate_distance)
269+
):
270+
valid = False
280271
break
281-
if broken_region:
272+
if not valid:
282273
break
283-
# Skip this candidate if region is invalid
284-
if broken_region:
274+
if not valid:
285275
continue
286-
# Valid candidate! Check if it's the nearest one
287276
max_distance = max(reachable[candidate_idx] for reachable in branch_reachable)
288-
289277
if max_distance < min_max_distance:
290278
min_max_distance = max_distance
291279
converge_node_idx = candidate_idx
@@ -384,57 +372,41 @@ def print_tree(
384372
max_items: int = DEFAULT_MAX_NODES_TO_SHOW,
385373
file=None,
386374
) -> None:
387-
"""Print hierarchical region tree in human-readable text format.
388-
389-
Recursively prints the region hierarchy with indentation showing depth.
390-
For each region, displays:
391-
- ID, level, and type (LEAF/COMPOSITE/ROOT)
392-
- Node counts (direct and recursive)
393-
- I/O tensor counts
394-
- Sample of nodes in the region (up to max_nodes_to_show)
395-
- Child regions (recursively)
396-
"""
375+
"""Print hierarchical region tree in human-readable text format."""
397376
region = region or self.root
398-
399377
file = file or sys.stdout
400378
p = " " * indent
401379

402-
def print_items(items, label, formatter=str):
403-
"""Print a truncated list of items."""
380+
def truncated(items, fmt=str):
381+
"""Yield formatted items, truncating with count if needed."""
404382
items = list(items)
405-
print(f"{p}│ ├─ {label}: {len(items)}", file=file)
406-
for item in items[:max_items]:
407-
print(f"{p}│ │ - {formatter(item)}", file=file)
383+
yield from (fmt(x) for x in items[:max_items])
408384
if len(items) > max_items:
409-
print(f"{p}│ │ ... and {len(items) - max_items} more", file=file)
385+
yield f"... and {len(items) - max_items} more"
410386

411-
# Header
412-
print(
413-
f"{p}├─ Region {region.id} (Level {region.level}, Type: {region.type.value})",
414-
file=file,
415-
)
416-
417-
# Counts
418387
direct_nodes = region.get_nodes()
419388
children = region.get_children()
389+
# Header + counts
390+
print(
391+
f"{p}├─ Region {region.id} (Level {region.level}, Type: {region.type.value})", file=file
392+
)
420393
print(f"{p}│ ├─ Direct nodes: {len(direct_nodes)}", file=file)
421394
print(f"{p}│ ├─ Total nodes: {len(region.get_region_nodes_and_descendants())}", file=file)
422395
print(f"{p}│ ├─ Children: {len(children)}", file=file)
423-
424396
# I/O
425-
print_items(region.inputs, "Inputs")
426-
print_items(region.outputs, "Outputs")
427-
397+
for label, items in [("Inputs", region.inputs), ("Outputs", region.outputs)]:
398+
print(f"{p}│ ├─ {label}: {len(items)}", file=file)
399+
for line in truncated(items):
400+
print(f"{p}│ │ - {line}", file=file)
428401
# Direct nodes
429402
if direct_nodes:
430403
print(f"{p}\n{p}│ Nodes in this region:", file=file)
431-
for node_idx in sorted(direct_nodes)[:max_items]:
432-
if node_idx < len(self.graph.nodes):
433-
node = self.graph.nodes[node_idx]
434-
print(f"{p}│ - Node {node_idx}: {node.op} ({node.name})", file=file)
435-
if len(direct_nodes) > max_items:
436-
print(f"{p}│ ... and {len(direct_nodes) - max_items} more", file=file)
437404

405+
def node_fmt(i: int) -> str:
406+
return f"Node {i}: {self.graph.nodes[i].op} ({self.graph.nodes[i].name})"
407+
408+
for line in truncated(sorted(direct_nodes), node_fmt):
409+
print(f"{p}│ - {line}", file=file)
438410
# Children
439411
if children:
440412
print(f"{p}\n{p}│ Child regions:", file=file)

0 commit comments

Comments
 (0)