Skip to content

Commit 8ec6e85

Browse files
ssjia (SS-JIA)
authored and committed
[ET-VK] Prefer downstream layout in TagMemoryMetaPass to reduce transitions
Pull Request resolved: #19113 Two changes to the layout assignment pass that together reduce layout transitions by ~89% for transformer-style models (73 → 9 for EdgeTAM ViT-S encoder): 1. BFS instead of DFS for downstream user tracing. The old DFS could exhaust the search budget (64 nodes) on one deep branch before discovering a constraining op on a sibling branch. BFS explores all immediate users at each level first, finding nearby layout-constrained ops (e.g. linear requiring width_packed) more reliably. 2. Prefer downstream consumers' layout over upstream source's layout. Previously, if the upstream source already had a representation (e.g. channels_packed from conv2d), that was applied first and locked in the layout via sync_primary_io_repr before downstream tracing could run. Now, downstream users are traced first to discover what layout they prefer, and the upstream source is only used as a fallback when downstream doesn't constrain. For ViT-style transformers, conv2d (patch embedding) forces channels_packed, which previously propagated through all residual connections via flexible ops (layer_norm, add, mul). With downstream-preferred layout, linear ops' width_packed requirement is discovered first, so the entire transformer stack stays width_packed. Transitions only occur at the conv2d↔transformer boundaries. ghstack-source-id: 373258238 @exported-using-ghexport Differential Revision: [D102360203](https://our.internmc.facebook.com/intern/diff/D102360203/)
1 parent 2084866 commit 8ec6e85

1 file changed

Lines changed: 81 additions & 49 deletions

File tree

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 81 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import logging
88
import operator
9+
10+
from collections import deque
911
from typing import Any
1012

1113
import executorch.backends.vulkan.utils as utils
@@ -332,81 +334,111 @@ def trace_node_users_to_constrain_repset( # noqa: C901
332334
search_depth: list[int] | None = None,
333335
) -> utils.TensorRepSet:
334336
"""
335-
For an ambiguous repset, try to constrain the repset by tracing the required
336-
repsets of the users of `origin_node`. The idea is to try to find a representation
337-
that can be used the longest without needing user nodes to insert a transition
338-
for its arguments.
337+
BFS over downstream users to constrain an ambiguous repset. Explores all
338+
immediate users at each level before going deeper, so that nearby constrained
339+
ops (e.g. linear requiring width_packed) are discovered before the search
340+
budget is spent on a single deep branch.
339341
"""
340-
# Optionally limit the total number of nodes explored to improve export
341-
# time. search_depth is a mutable list so that all branches of a fan-out
342-
# share a single counter, preventing exponential blowup.
343342
if self.max_trace_search_depth is not None:
344343
if search_depth is None:
345344
search_depth = [self.max_trace_search_depth]
346-
search_depth[0] -= 1
347-
if search_depth[0] <= 0:
345+
346+
queue: deque[torch.fx.Node] = deque()
347+
queue.append(origin_node)
348+
349+
while queue:
350+
if repset.is_constrained():
348351
return repset
349352

350-
users_to_trace = origin_node.users
353+
if self.max_trace_search_depth is not None:
354+
search_depth[0] -= 1
355+
if search_depth[0] <= 0:
356+
return repset
357+
358+
node = queue.popleft()
359+
360+
users_to_trace = node.users
361+
362+
sync_outs_repr = True
363+
if self.is_valid_op_node(node):
364+
sync_outs_repr = self.get_node_cached_repsets(node).sync_outs_repr
351365

352-
sync_outs_repr = True
353-
if self.is_valid_op_node(origin_node):
354-
sync_outs_repr = self.get_node_cached_repsets(origin_node).sync_outs_repr
366+
if utils.num_tensors_in_node(node) > 1 and not sync_outs_repr:
367+
users_to_trace = []
368+
for usage_node in node.users:
369+
if (
370+
usage_node.target == operator.getitem
371+
and usage_node.args[1] == 1
372+
):
373+
users_to_trace.append(usage_node)
355374

356-
if utils.num_tensors_in_node(origin_node) > 1 and not sync_outs_repr:
357-
users_to_trace = []
358-
for usage_node in origin_node.users:
359-
if usage_node.target == operator.getitem and usage_node.args[1] == 1:
360-
users_to_trace.append(usage_node)
375+
for usage_node in users_to_trace:
376+
if repset.is_constrained():
377+
return repset
361378

362-
for usage_node in users_to_trace:
363-
arg_i_in_user = None
364-
for i in range(len(usage_node.args)):
365-
if origin_node == usage_node.args[i]:
366-
arg_i_in_user = i
367-
break
379+
arg_i_in_user = None
380+
for i in range(len(usage_node.args)):
381+
if node == usage_node.args[i]:
382+
arg_i_in_user = i
383+
break
368384

369-
if arg_i_in_user is not None:
370-
repset = self.constrain_repset_with_user(
371-
usage_node, arg_i_in_user, repset, search_depth
385+
if arg_i_in_user is None:
386+
continue
387+
388+
if not self.is_valid_op_node(usage_node):
389+
continue
390+
391+
cur_node_repsets = self.get_node_cached_repsets(usage_node)
392+
req_arg_repset = cur_node_repsets.get_arg_repset(arg_i_in_user)
393+
394+
if not req_arg_repset.any_in_common(repset):
395+
continue
396+
397+
repset = repset.make_intersect(req_arg_repset)
398+
399+
repset_propagates_to_output = (
400+
cur_node_repsets.sync_primary_io_repr
401+
and (
402+
cur_node_repsets.sync_args_repr
403+
or arg_i_in_user == cur_node_repsets.primary_arg_idx
404+
)
372405
)
373406

374-
if repset.is_constrained():
375-
return repset
407+
if repset_propagates_to_output:
408+
queue.append(usage_node)
376409

377410
return repset
378411

379412
def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> None:
380413
"""
381414
Attempts to constrain the repset of the argument at index `arg_i` of the op
382-
associated with `op_repsets`. Does this with two stages:
383-
384-
1. First, account for any existing representation that has already been determined
385-
for the argument. If no existing representation has been determined, then use
386-
the output repset of the operator that produces the argument.
387-
2. Then, try to trace through the users of the argument to find a representation
388-
that can be used for as long as possible without needing a transition.
415+
associated with `op_repsets`. Prefers downstream consumers' layout requirements
416+
over the upstream source's existing layout, falling back to the source only when
417+
downstream tracing does not fully constrain the repset.
389418
"""
390-
# If forcing fp16, then try to use texture storage whenever possible. This is
391-
# a temporary stopgap measure until all buffer implementations properly account
392-
# for potential overflow of fp16 representation range when doing math in fp16.
393419
if self.force_fp16:
394420
op_repsets.try_constrain_with_arg_repset(arg_i, utils.ANY_TEXTURE)
395421

396-
arg_source_repset = self.get_arg_tensor_source_repset(op_repsets.op_node, arg_i)
397-
op_repsets.try_constrain_with_arg_repset(arg_i, arg_source_repset)
398-
399-
arg_repset = op_repsets.get_arg_repset(arg_i)
400-
if arg_repset.is_constrained():
401-
return
402-
422+
# First, trace downstream users to discover what layout they prefer.
403423
arg_node = op_repsets.op_node.args[arg_i]
404-
405424
if isinstance(arg_node, list):
406425
arg_node = arg_node[0]
407426

408-
arg_repset = self.trace_node_users_to_constrain_repset(arg_node, arg_repset)
409-
op_repsets.try_constrain_with_arg_repset(arg_i, arg_repset)
427+
arg_repset = op_repsets.get_arg_repset(arg_i)
428+
if not arg_repset.is_constrained():
429+
downstream_repset = self.trace_node_users_to_constrain_repset(
430+
arg_node, arg_repset
431+
)
432+
op_repsets.try_constrain_with_arg_repset(arg_i, downstream_repset)
433+
434+
# Fall back to the upstream source's existing layout only if downstream
435+
# tracing did not fully constrain the repset.
436+
arg_repset = op_repsets.get_arg_repset(arg_i)
437+
if not arg_repset.is_constrained():
438+
arg_source_repset = self.get_arg_tensor_source_repset(
439+
op_repsets.op_node, arg_i
440+
)
441+
op_repsets.try_constrain_with_arg_repset(arg_i, arg_source_repset)
410442

411443
def constrain_op_out_repset(self, op_repsets: utils.OpRepSets) -> None:
412444
"""

0 commit comments

Comments
 (0)