From a1b8f494ca1d9ca00da4ca21a1a0cece879af342 Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 16 Mar 2026 13:00:28 -0700 Subject: [PATCH] [ET-VK] Fix exponential blowup in tag_memory_meta_pass repset tracing The trace_node_users_to_constrain_repset DFS previously tracked search depth as a per-branch int counter, allowing each branch of a fan-out to independently explore up to max_trace_search_depth nodes. In transformer-style graphs with heavy fan-out this caused exponential blowup in the number of nodes visited. Replace the int counter with a mutable list containing a single int that is shared by reference across all recursive branches. This limits the TOTAL number of nodes explored per top-level trace call to max_trace_search_depth (16), regardless of fan-out structure. Authored with Claude. Differential Revision: [D96790445](https://our.internmc.facebook.com/intern/diff/D96790445/) [ghstack-poisoned] --- .../vulkan/_passes/tag_memory_meta_pass.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py index 3bdc30feb7c..1d6ff3ab311 100644 --- a/backends/vulkan/_passes/tag_memory_meta_pass.py +++ b/backends/vulkan/_passes/tag_memory_meta_pass.py @@ -132,9 +132,10 @@ def __init__( self.texture_limits = texture_limits self.force_fp16 = force_fp16 - # Magic number to limit "lookahead" when tracing through users of an operator - # to constrain the representation of its arguments/outputs. - self.max_trace_search_depth = None + # Limit the total number of nodes explored when tracing through users of + # an operator to constrain the representation of its arguments/outputs. + # Without a limit, transformer-style graphs cause exponential blowup. + self.max_trace_search_depth = 64 def is_valid_op_node(self, node: Any) -> bool: """ @@ -261,7 +262,7 @@ def constrain_repset_with_user( current_node: torch.fx.Node, arg_i: int, arg_repset: utils.TensorRepSet, - search_depth: int = 0, + search_depth: list[int] | None = None, ) -> utils.TensorRepSet: """ Attempts to constrain `arg_repset` based on the required repset of the argument @@ -305,7 +306,7 @@ def trace_node_users_to_constrain_repset( self, origin_node: torch.fx.Node, repset: utils.TensorRepSet, - search_depth: int = 0, + search_depth: list[int] | None = None, ) -> utils.TensorRepSet: """ For an ambiguous repset, try to constrain the repset by tracing the required @@ -313,9 +314,14 @@ def trace_node_users_to_constrain_repset( that can be used the longest without needing user nodes to insert a transition for its arguments. """ - # Optionally limit the search depth to improve export time + # Optionally limit the total number of nodes explored to improve export + # time. search_depth is a mutable list so that all branches of a fan-out + # share a single counter, preventing exponential blowup. if self.max_trace_search_depth is not None: - if search_depth > self.max_trace_search_depth: + if search_depth is None: + search_depth = [self.max_trace_search_depth] + search_depth[0] -= 1 + if search_depth[0] <= 0: return repset users_to_trace = origin_node.users @@ -339,7 +345,7 @@ def trace_node_users_to_constrain_repset( if arg_i_in_user is not None: repset = self.constrain_repset_with_user( - usage_node, arg_i_in_user, repset, search_depth + 1 + usage_node, arg_i_in_user, repset, search_depth ) if repset.is_constrained():