diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py
index 0f8263fa572..6f5cb10f1b2 100644
--- a/backends/vulkan/_passes/tag_memory_meta_pass.py
+++ b/backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -28,11 +28,13 @@ def insert_transition_node(
     node: torch.fx.Node,
     arg: torch.fx.Node,
     arg_node_repr: utils.TensorRepr,
-) -> None:
+) -> torch.fx.Node:
     """
     Insert a clone node to transition the tensor associated with `arg` to a tensor
     with the requested representation `arg_node_repr`, and use the cloned node as an
     argument to `node` instead of `arg`.
+
+    Returns the newly created clone node.
     """
     with graph_module.graph.inserting_before(node):
         clone_node = graph_module.graph.create_node(
@@ -45,6 +47,7 @@ def insert_transition_node(
         clone_node.meta["spec"].const = False
         utils.set_node_repr(clone_node, arg_node_repr)
         arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)
+        return clone_node
 
 
 def set_arg_node_repr_or_transition(
@@ -53,6 +56,7 @@ def set_arg_node_repr_or_transition(
     arg_i: int,
     arg_node_repr: utils.TensorRepr,
     dirty: bool,
+    transition_cache: dict | None = None,
 ) -> bool:
     """
     Does one of following:
@@ -60,7 +64,8 @@ def set_arg_node_repr_or_transition(
        does not currently have a `node_repr`
     2. No-op if the current `node_repr` is already the same as the requested
        represetnation.
     3. Insert a transition node to create a copy of the argument with the desired `node_repr`
-       if the current `node_repr` is different than what is needed.
+       if the current `node_repr` is different than what is needed. If a transition clone
+       already exists for the same (source, target_repr) pair, reuse it.
     """
     arg_node = op_node.args[arg_i]
@@ -78,15 +83,33 @@ def single_node_impl(node: torch.fx.Node) -> bool:
         if cur_node_repr == arg_node_repr:
             return False
 
+        assert utils.is_single_tensor_node(node)
+
+        # Check if a transition clone already exists for this (source, target_repr).
+        cache_key = (
+            node,
+            arg_node_repr.storage_type,
+            arg_node_repr.memory_layout,
+        )
+        if transition_cache is not None and cache_key in transition_cache:
+            cached_clone = transition_cache[cache_key]
+            node.replace_all_uses_with(cached_clone, lambda x, y=op_node: x == y)
+            if not dirty:
+                logger.info(
+                    f"[Vulkan Delegate] Reusing transition for {op_node.format_node()}:"
+                )
+            logger.info(f"  arg {arg_i} ({node}): reusing {cached_clone}")
+            return True
+
         if not dirty:
             logger.info(
                 f"[Vulkan Delegate] Inserting transition(s) for {op_node.format_node()}:"
             )
 
-        # Existing node representation is different; insert a transition node
-        # Currently, the transition node insertion logic can only handle single tensor nodes
-        assert utils.is_single_tensor_node(node)
-        insert_transition_node(graph_module, op_node, node, arg_node_repr)
+        clone_node = insert_transition_node(graph_module, op_node, node, arg_node_repr)
+
+        if transition_cache is not None:
+            transition_cache[cache_key] = clone_node
 
         logger.info(f"  arg {arg_i} ({node}): ({cur_node_repr}) -> ({arg_node_repr})")
 
@@ -407,7 +430,10 @@ def constrain_op_repsets(self, op_repsets: utils.OpRepSets) -> None:
         self.constrain_op_out_repset(op_repsets)
 
     def set_op_node_tensor_reprs(
-        self, graph_module: torch.fx.GraphModule, op_node: torch.fx.Node
+        self,
+        graph_module: torch.fx.GraphModule,
+        op_node: torch.fx.Node,
+        transition_cache: dict | None = None,
     ) -> None:
         """
         For an operator representated by `op_node`, get the OpRepSets associated with
@@ -458,7 +484,12 @@ def set_op_node_tensor_reprs(
             if isinstance(arg_node, torch.fx.Node):
                 transitions_inserted = (
                     set_arg_node_repr_or_transition(
-                        graph_module, op_node, i, arg_node_repr, transitions_inserted
+                        graph_module,
+                        op_node,
+                        i,
+                        arg_node_repr,
+                        transitions_inserted,
+                        transition_cache,
                     )
                     or transitions_inserted
                 )
@@ -473,12 +504,14 @@ def set_op_node_tensor_reprs(
                             i,
                             arg_node_repr,
                             transitions_inserted,
+                            transition_cache,
                         )
                         or transitions_inserted
                     )
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        transition_cache: dict = {}
         for node in graph_module.graph.nodes:
-            self.set_op_node_tensor_reprs(graph_module, node)
+            self.set_op_node_tensor_reprs(graph_module, node, transition_cache)
 
         return PassResult(graph_module, True)