From 2d65b78be475c3165baf9501f2b4e90938a8686e Mon Sep 17 00:00:00 2001 From: ssjia Date: Wed, 4 Mar 2026 08:29:40 -0800 Subject: [PATCH] [ET-VK] Insert prepack nodes for constant primary inputs of prepacking ops The insert_prepack_nodes pass was skipping prepack node insertion for all constant tensor args of ops with supports_prepacking=True. However, these ops only handle prepacking for weight/bias tensors internally; the primary input tensor is still expected to be a GPU tensor. If the primary input happens to be a constant tensor (serialized as TensorRef), the op throws an exception at runtime. Fix this by detecting the primary input index directly in insert_prepack_nodes. Most prepacking ops have the primary input at arg 0, but embedding uses arg 1 since its signature is embedding(weight, indices, ...). The pass now checks whether a constant tensor is used as the primary input of a prepacking op, and if so, still inserts a prepack node for it. Differential Revision: [D95217949](https://our.internmc.facebook.com/intern/diff/D95217949/) [ghstack-poisoned] --- .../vulkan/_passes/insert_prepack_nodes.py | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/backends/vulkan/_passes/insert_prepack_nodes.py b/backends/vulkan/_passes/insert_prepack_nodes.py index c45ed4ea25d..373b2a4d135 100644 --- a/backends/vulkan/_passes/insert_prepack_nodes.py +++ b/backends/vulkan/_passes/insert_prepack_nodes.py @@ -37,10 +37,28 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram: # Vulkan compute graph. This annotation is used in later graph passes. node.meta["etvk_tensorref"] = True - # Get the list of node users that do not handle their own prepacking + # Get the list of node users that need a prepack node inserted. This + # includes ops that don't handle their own prepacking, as well as ops + # that do handle their own prepacking but use this constant tensor as + # the primary input (since the op expects the primary input to be a GPU + # tensor, not a TensorRef). nodes_to_replace_input = [] for user in node.users: - if user.op == "call_function" and not handles_own_prepacking(user.target): + if user.op != "call_function": + continue + + if not handles_own_prepacking(user.target): + nodes_to_replace_input.append(user) + continue + + # Most prepacking ops have the primary input at arg 0, but + # embedding is embedding(weight, indices, ...) where the + # primary input (indices) is at arg 1. + primary_arg_idx = 0 + if user.target == exir_ops.edge.aten.embedding.default: + primary_arg_idx = 1 + + if node in user.args and user.args.index(node) == primary_arg_idx: nodes_to_replace_input.append(user) if len(nodes_to_replace_input) == 0: