Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions backends/vulkan/_passes/insert_prepack_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,28 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:
# Vulkan compute graph. This annotation is used in later graph passes.
node.meta["etvk_tensorref"] = True

# Get the list of node users that do not handle their own prepacking
# Get the list of node users that need a prepack node inserted. This
# includes ops that don't handle their own prepacking, as well as ops
# that do handle their own prepacking but use this constant tensor as
# the primary input (since the op expects the primary input to be a GPU
# tensor, not a TensorRef).
nodes_to_replace_input = []
for user in node.users:
if user.op == "call_function" and not handles_own_prepacking(user.target):
if user.op != "call_function":
continue

if not handles_own_prepacking(user.target):
nodes_to_replace_input.append(user)
continue

# Most prepacking ops have the primary input at arg 0, but
# embedding is embedding(weight, indices, ...) where the
# primary input (indices) is at arg 1.
primary_arg_idx = 0
if user.target == exir_ops.edge.aten.embedding.default:
primary_arg_idx = 1

if node in user.args and user.args.index(node) == primary_arg_idx:
nodes_to_replace_input.append(user)

if len(nodes_to_replace_input) == 0:
Expand Down
Loading