
Commit b3afe9f

ryan-monroe authored and facebook-github-bot committed
Add FuseConcatPass to eliminate redundant concat ops (pytorch#18827)
Summary: Concat (torch.cat) in the Gen2 ExecuTorch ARM/Ethos-U stack is lowered to TOSA CONCAT, which Vela then converts to N x MemoryCopy operations, i.e. real DMA data movement on the NPU. This pass eliminates concat operations that can be proven unnecessary at the FX graph level, preventing Vela from generating MemoryCopy ops entirely.

Inspired by Espresso's concat elimination techniques (bolt/nn/espresso/transforms/remove_nops.py), three patterns are handled:

1. Single-input concat: cat([x]) is a no-op; it is replaced with x.
2. Concat-then-slice: if every consumer of cat([a, b, ...]) is a slice_copy that extracts exactly one original input, bypass both.
3. Slice-then-concat: if contiguous slices of the same tensor are concatenated back together, the result is the original tensor.

Differential Revision: D97667069
1 parent cb94506 commit b3afe9f
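For intuition, the three patterns in the summary correspond to the eager-mode identities below (an illustrative sketch only, not part of the commit; tensor names and shapes are assumed):

import torch

x = torch.randn(6, 4)
a, b = torch.randn(2, 4), torch.randn(3, 4)

# Pattern 1: a single-input concat is the identity.
assert torch.equal(torch.cat([x]), x)

# Pattern 2: slicing a concat back out to exactly one input is the identity.
c = torch.cat([a, b], dim=0)   # shape (5, 4)
assert torch.equal(c[0:2], a)  # the slice_copy extracting `a`
assert torch.equal(c[2:5], b)  # the slice_copy extracting `b`

# Pattern 3: concatenating contiguous slices reconstructs the source.
assert torch.equal(torch.cat([x[0:2], x[2:6]], dim=0), x)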

4 files changed

Lines changed: 535 additions & 0 deletions


backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -106,6 +106,7 @@
     QuantizeClampArgumentsPass,
 )
 from .fuse_batch_norm2d_pass import FuseBatchNorm2dPass  # noqa
+from .fuse_concat_pass import FuseConcatPass  # noqa
 from .fuse_consecutive_concat_shapes import FuseConsecutiveConcatShapesPass  # noqa
 from .fuse_consecutive_rescales_pass import FuseConsecutiveRescalesPass  # noqa
 from .fuse_constant_ops_pass import (  # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
@@ -101,6 +101,7 @@
     EnsureUniqueOutputNodesPass,
     FoldAndAnnotateQParamsPass,
     FuseBatchNorm2dPass,
+    FuseConcatPass,
     FuseConsecutiveConcatShapesPass,
     FuseConsecutiveRescalesPass,
     FuseConstantArgsPass,
@@ -528,6 +529,7 @@ def _tosa_pipeline(
         # Aten -> TOSA transformation passes
         self.add_passes(
             [
+                FuseConcatPass(),
                 RewriteUpsamplePass(),
                 RewriteMaxPool2dPass(),
                 RewriteConvPass(exported_program),
backends/arm/_passes/fuse_concat_pass.py

Lines changed: 339 additions & 0 deletions
@@ -0,0 +1,339 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Set, Type

import torch.fx
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

logger = logging.getLogger(__name__)


def _int_arg(node: torch.fx.Node, index: int, default: int) -> int:
    """Get an integer argument from a node, with a default if missing."""
    val = node.args[index] if len(node.args) > index else default
    assert isinstance(val, int)
    return val


def _slice_params(node: torch.fx.Node, dim_size: int) -> tuple[int, int, int, int]:
    """Extract (dim, start, end, step) from a slice_copy node.

    ``dim`` is normalized to a positive index. ``end`` is clamped to
    ``dim_size`` (the size of the source tensor along the slice dimension).
    """
    rank = len(get_first_fake_tensor(node).shape)
    dim = _int_arg(node, 1, 0)
    dim = (dim + rank) % rank
    start = _int_arg(node, 2, 0)
    end = min(_int_arg(node, 3, dim_size), dim_size)
    step = _int_arg(node, 4, 1)
    return dim, start, end, step
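
# Illustrative example (values assumed): for a slice_copy taken along dim=-1
# of a rank-4 tensor, _slice_params normalizes dim to (-1 + 4) % 4 == 3 and
# clamps an oversized `end` (e.g. the INT64_MAX that exporters emit for
# "slice to the end") down to dim_size, returning (3, start, dim_size, 1)
# when step is absent.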


_SLICE_OP = exir_ops.edge.aten.slice_copy.Tensor


def _is_valid_slice(node: torch.fx.Node, cat_dim: int, dim_size: int) -> bool:
    """Check that node is a slice_copy on cat_dim with step=1."""
    if node.target != _SLICE_OP:
        return False
    s_dim, _, _, s_step = _slice_params(node, dim_size)
    return s_dim == cat_dim and s_step == 1


def _find_slice_replacement(
    slice_op: torch.fx.Node,
    cat_node: torch.fx.Node,
    cat_dim: int,
    s_start: int,
    s_end: int,
    offsets: list[tuple[int, int, torch.fx.Node]],
) -> torch.fx.Node | None:
    """Find a replacement for a slice that consumes a cat output.

    ``offsets`` maps each concat input to its range in the concatenated
    output: [(start, end, input_node), ...] along ``cat_dim``.

    Returns the replacement node (exact input match or adjusted sub-slice),
    or None if the slice crosses input boundaries.
    """
    for o_start, o_end, inp in offsets:
        if s_start == o_start and s_end == o_end:
            return inp
        if s_start >= o_start and s_end <= o_end:
            graph = cat_node.graph
            with graph.inserting_before(slice_op):
                new_slice = graph.call_function(
                    _SLICE_OP,
                    (inp, cat_dim, s_start - o_start, s_end - o_start),
                )
            new_slice.meta = slice_op.meta.copy()
            return new_slice
    return None
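
# Illustrative example (shapes assumed): for cat([a, b], dim=0) with
# a.shape[0] == 2 and b.shape[0] == 3, offsets == [(0, 2, a), (2, 5, b)].
# A consumer slice [0:2] returns `a` directly (exact match), while a slice
# [3:5] falls entirely inside `b` and is rewritten to slice_copy(b, 0, 1, 3).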


def _find_common_slice_source(
    cat_inputs: list | tuple,
    cat_dim: int,
    dim_size: int,
) -> torch.fx.Node | None:
    """Check that all inputs are valid slices of the same source.

    Returns the common source node, or None if any input is not a valid
    slice or the inputs slice different sources.
    """
    source_node = None
    for inp in cat_inputs:
        if not isinstance(inp, torch.fx.Node):
            return None
        if not _is_valid_slice(inp, cat_dim, dim_size):
            return None
        slice_source = inp.args[0]
        if source_node is None:
            source_node = slice_source
        elif slice_source is not source_node:
            return None
    assert isinstance(source_node, torch.fx.Node)
    return source_node


def _check_contiguous_slices(
    cat_inputs: list | tuple,
    source_dim_size: int,
) -> tuple[int, int] | None:
    """Check that the slices in ``cat_inputs`` are contiguous.

    Returns (first_start, last_end) if each slice starts where the previous
    one ended, or None otherwise.
    """
    _, first_start, _, _ = _slice_params(cat_inputs[0], source_dim_size)
    expected_start = first_start
    for inp in cat_inputs:
        _, s_start, s_end, _ = _slice_params(inp, source_dim_size)
        if s_start != expected_start:
            return None
        expected_start = s_end

    # expected_start is now the end of the last slice.
    return first_start, expected_start
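
# Illustrative example (bounds assumed): slices [1:3] and [3:7] of the same
# source are contiguous, so the check returns (1, 7); slices [1:3] and [4:7]
# leave a gap at index 3, so the check returns None.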


class FuseConcatPass(ArmPass):
    """Eliminate redundant concat (cat) operations via graph pattern matching.

    This pass recognizes and removes concat operations that can be proven to
    produce no useful data movement. Eliminating these at the FX/TOSA level
    prevents Vela from generating MemoryCopy operations on the Ethos-U NPU.

    Five patterns are handled:

    1. Single-input concat: cat([x], dim) is a no-op; replace with x.
    2. Concat-then-slice (exact): if a consumer of cat([a, b, ...], dim) is
       a slice_copy that extracts exactly one original input, replace it
       with the corresponding concat input directly.
    3. Slice-then-concat (full): if cat([slice(x, d, s0, e0),
       slice(x, d, s1, e1), ...], dim) reconstructs x exactly (contiguous
       slices covering the full source dimension), replace with x.
    4. Concat-then-sub-slice: if a consumer of cat([a, b, ...], dim) is a
       slice_copy whose range falls entirely within one original input,
       replace it with an adjusted slice on that input directly.
    5. Slice-then-concat (partial): if contiguous slices of the same tensor
       are concatenated but cover only a sub-range of the source dimension,
       replace with a single slice on the source.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    cat_ops = {
        exir_ops.edge.aten.cat.default,
    }
    slice_op = _SLICE_OP
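
    # Illustrative eager-mode identities for the two partial patterns
    # (tensor names and sizes assumed for this example):
    #   Pattern 4: torch.cat([a, b], 0)[3:5]  ==  b[1:3]   when a.shape[0] == 2
    #   Pattern 5: torch.cat([x[1:3], x[3:5]], 0)  ==  x[1:5]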

    def call(self, graph_module: torch.fx.GraphModule):
        modified = False
        graph = graph_module.graph

        for node in list(graph.nodes):
            if node.op != "call_function" or node.target not in self.cat_ops:
                continue
            if node.graph is None:
                continue

            if self._eliminate_single_input_cat(node):
                modified = True
                continue

            if self._eliminate_cat_then_slice(node):
                modified = True
                continue

            if self._eliminate_slice_then_cat(node):
                modified = True
                continue

        if modified:
            graph.eliminate_dead_code()
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, modified)

    # ------------------------------------------------------------------
    # Pattern 1: single-input cat
    # ------------------------------------------------------------------
    @staticmethod
    def _eliminate_single_input_cat(cat_node: torch.fx.Node) -> bool:
        inputs = cat_node.args[0]
        if not isinstance(inputs, (list, tuple)) or len(inputs) != 1:
            return False
        sole_input = inputs[0]
        assert isinstance(sole_input, torch.fx.Node)
        cat_node.replace_all_uses_with(sole_input)
        logger.debug("Eliminated single-input cat: %s", cat_node.name)
        return True

    # ------------------------------------------------------------------
    # Patterns 2 & 4: cat -> slice (exact input or sub-range of input)
    # ------------------------------------------------------------------
    @staticmethod
    def _eliminate_cat_then_slice(
        cat_node: torch.fx.Node,
    ) -> bool:
        cat_inputs = cat_node.args[0]
        if not isinstance(cat_inputs, (list, tuple)) or len(cat_inputs) < 2:
            return False

        # If the dim does not exist as an arg, it defaults to 0.
        cat_dim = _int_arg(cat_node, 1, 0)
        output_rank = len(get_first_fake_tensor(cat_node).shape)
        cat_dim = (cat_dim + output_rank) % output_rank

        users = list(cat_node.users.keys())
        if not users:
            return False

        # Build the offset map for each concat input along cat_dim.
        offsets = []
        offset = 0
        for inp in cat_inputs:
            assert isinstance(inp, torch.fx.Node)
            inp_shape = get_first_fake_tensor(inp).shape
            size = inp_shape[cat_dim]
            offsets.append((offset, offset + size, inp))
            offset += size

        # Every user must be a slice_copy on the same dim with step=1.
        # Collect validated (node, start, end) for replacement below.
        validated_slices: list[tuple[torch.fx.Node, int, int]] = []
        for slice_op in users:
            if not _is_valid_slice(slice_op, cat_dim, offset):
                return False
            if slice_op.args[0] is not cat_node:
                return False
            _, s_start, s_end, _ = _slice_params(slice_op, offset)
            validated_slices.append((slice_op, s_start, s_end))

        # For each user, try exact match (Pattern 2) then sub-range (Pattern 4).
        # Users that cross input boundaries are skipped.
        replacements: list[tuple[torch.fx.Node, torch.fx.Node]] = []

        for slice_op, s_start, s_end in validated_slices:
            replacement = _find_slice_replacement(
                slice_op, cat_node, cat_dim, s_start, s_end, offsets
            )
            if replacement is not None:
                replacements.append((slice_op, replacement))

        if not replacements:
            return False

        for old_node, new_node in replacements:
            old_node.replace_all_uses_with(new_node)

        logger.debug(
            "Eliminated cat-then-slice pattern: %s (%d slices redirected)",
            cat_node.name,
            len(replacements),
        )
        return True

    # ------------------------------------------------------------------
    # Patterns 3 & 5: slice -> cat (contiguous slices, full or partial)
    # ------------------------------------------------------------------
    @staticmethod
    def _eliminate_slice_then_cat(
        cat_node: torch.fx.Node,
    ) -> bool:
        cat_inputs = cat_node.args[0]
        if not isinstance(cat_inputs, (list, tuple)) or len(cat_inputs) < 2:
            return False

        cat_dim = _int_arg(cat_node, 1, 0)
        output_rank = len(get_first_fake_tensor(cat_node).shape)
        cat_dim = (cat_dim + output_rank) % output_rank

        # All inputs must be slice_copy on the same source tensor and dim,
        # with step=1.
        source_node = _find_common_slice_source(cat_inputs, cat_dim, output_rank)
        if source_node is None:
            return False

        source_shape = get_first_fake_tensor(source_node).shape
        source_dim_size = source_shape[cat_dim]

        # Verify slices are contiguous (but not necessarily starting at 0).
        bounds = _check_contiguous_slices(cat_inputs, source_dim_size)
        if bounds is None:
            return False
        first_start, last_end = bounds

        # Verify output shape matches expectations.
        cat_shape = get_first_fake_tensor(cat_node).shape

        if first_start == 0 and last_end == source_dim_size:
            # Pattern 3: full coverage - replace with the source tensor.
            if list(cat_shape) != list(source_shape):
                return False
            cat_node.replace_all_uses_with(source_node)
            logger.debug(
                "Eliminated slice-then-cat (full): %s -> %s",
                cat_node.name,
                source_node.name,
            )
        else:
            # Pattern 5: partial coverage - replace with a single slice.
            expected_dim_size = last_end - first_start
            if cat_shape[cat_dim] != expected_dim_size:
                return False
            for i, (cs, ss) in enumerate(zip(cat_shape, source_shape)):
                if i != cat_dim and cs != ss:  # dims must match except cat_dim
                    return False
            graph = cat_node.graph
            with graph.inserting_before(cat_node):
                new_slice = graph.call_function(
                    _SLICE_OP,
                    (source_node, cat_dim, first_start, last_end),
                )
            new_slice.meta = cat_node.meta.copy()
            cat_node.replace_all_uses_with(new_slice)
            logger.debug(
                "Eliminated slice-then-cat (partial): %s -> slice(%s, %d, %d:%d)",
                cat_node.name,
                source_node.name,
                cat_dim,
                first_start,
                last_end,
            )
        return True
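
Usage note: the pass is registered in ArmPassManager's _tosa_pipeline (see the arm_pass_manager.py hunk above), so it runs automatically when lowering for TOSA/Ethos-U. As a standalone illustration, a minimal sketch follows; it is not part of the commit and assumes a working ExecuTorch install where torch.export plus to_edge lower the indexing below to the edge-dialect aten.slice_copy.Tensor op:

import torch

from executorch.backends.arm._passes import FuseConcatPass
from executorch.exir import to_edge


class CatThenSlice(torch.nn.Module):
    def forward(self, a, b):
        c = torch.cat([a, b], dim=0)  # lowers to TOSA CONCAT -> Vela MemoryCopy
        return c[0:2]                 # slice_copy that extracts exactly `a`


example_inputs = (torch.randn(2, 4), torch.randn(3, 4))
edge = to_edge(torch.export.export(CatThenSlice(), example_inputs))
gm = edge.exported_program().graph_module

result = FuseConcatPass()(gm)  # returns PassResult(graph_module, modified)
assert result.modified         # pattern 2 fires: the slice is redirected to `a`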
