diff --git a/backends/vulkan/_passes/insert_prepack_nodes.py b/backends/vulkan/_passes/insert_prepack_nodes.py index c45ed4ea25d..373b2a4d135 100644 --- a/backends/vulkan/_passes/insert_prepack_nodes.py +++ b/backends/vulkan/_passes/insert_prepack_nodes.py @@ -37,10 +37,28 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram: # Vulkan compute graph. This annotation is used in later graph passes. node.meta["etvk_tensorref"] = True - # Get the list of node users that do not handle their own prepacking + # Get the list of node users that need a prepack node inserted. This + # includes ops that don't handle their own prepacking, as well as ops + # that do handle their own prepacking but use this constant tensor as + # the primary input (since the op expects the primary input to be a GPU + # tensor, not a TensorRef). nodes_to_replace_input = [] for user in node.users: - if user.op == "call_function" and not handles_own_prepacking(user.target): + if user.op != "call_function": + continue + + if not handles_own_prepacking(user.target): + nodes_to_replace_input.append(user) + continue + + # Most prepacking ops have the primary input at arg 0, but + # embedding is embedding(weight, indices, ...) where the + # primary input (indices) is at arg 1. + primary_arg_idx = 0 + if user.target == exir_ops.edge.aten.embedding.default: + primary_arg_idx = 1 + + if node in user.args and user.args.index(node) == primary_arg_idx: nodes_to_replace_input.append(user) if len(nodes_to_replace_input) == 0: