Skip to content

Commit 4056395

Browse files
ryan-monroe and facebook-github-bot
authored and committed
Add FuseConcatPass to eliminate redundant concat ops (#18827)
Summary: Concat (torch.cat) in the Gen2 Executorch ARM/Ethos-U stack is lowered to TOSA CONCAT, which Vela then converts to N x MemoryCopy operations — real DMA data movement on the NPU. This pass eliminates concat operations that can be proven unnecessary at the FX graph level, preventing Vela from generating MemoryCopy ops entirely. Inspired by Espresso's concat elimination techniques (bolt/nn/espresso/transforms/remove_nops.py), three patterns are handled: 1. Single-input concat: cat([x]) is a no-op, replaced with x. 2. Concat-then-slice: if every consumer of cat([a, b, ...]) is a slice_copy that extracts exactly one original input, bypass both. 3. Slice-then-concat: if contiguous slices of the same tensor are concatenated back, the result is the original tensor. Differential Revision: D97667069
1 parent 875f7c8 commit 4056395

4 files changed

Lines changed: 537 additions & 0 deletions

File tree

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@
102102
QuantizeClampArgumentsPass,
103103
)
104104
from .fuse_batch_norm2d_pass import FuseBatchNorm2dPass # noqa
105+
from .fuse_concat_pass import FuseConcatPass # noqa
105106
from .fuse_consecutive_concat_shapes import FuseConsecutiveConcatShapesPass # noqa
106107
from .fuse_consecutive_rescales_pass import FuseConsecutiveRescalesPass # noqa
107108
from .fuse_constant_ops_pass import ( # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@
9898
DecorateFp32toInt32CastingPass,
9999
FoldAndAnnotateQParamsPass,
100100
FuseBatchNorm2dPass,
101+
FuseConcatPass,
101102
FuseConsecutiveConcatShapesPass,
102103
FuseConsecutiveRescalesPass,
103104
FuseConstantArgsPass,
@@ -486,6 +487,7 @@ def _tosa_pipeline(
486487
# Aten -> TOSA transformation passes
487488
self.add_passes(
488489
[
490+
FuseConcatPass(),
489491
RewriteUpsamplePass(),
490492
RewriteConvPass(exported_program),
491493
RewriteMatmulPass(),
Lines changed: 341 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,341 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
# Copyright 2026 Arm Limited and/or its affiliates.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import logging
9+
from typing import Set, Type
10+
11+
import torch.fx
12+
from executorch.backends.arm._passes import ArmPass
13+
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
14+
from executorch.exir.dialects._ops import ops as exir_ops
15+
from executorch.exir.pass_base import ExportPass, PassResult
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
def _int_arg(node: torch.fx.Node, index: int, default: int) -> int:
21+
"""Get an integer argument from a node, with a default if missing."""
22+
val = node.args[index] if len(node.args) > index else default
23+
assert isinstance(val, int)
24+
return val
25+
26+
27+
def _slice_params(node: torch.fx.Node, dim_size: int) -> tuple[int, int, int, int]:
    """Extract ``(dim, start, end, step)`` from a slice_copy node.

    Args:
        node: The ``slice_copy`` node to read arguments from.
        dim_size: Size of the source tensor along the slice dimension.

    Returns:
        ``(dim, start, end, step)`` where ``dim`` is normalized to a
        positive index and ``start``/``end`` are normalized like Python
        (and aten) slicing: negative values are offset by ``dim_size``
        and both are clamped into ``[0, dim_size]``. Normalizing here
        lets the pattern matchers recognize slices expressed with
        negative indices instead of silently skipping them.
    """
    rank = len(get_first_fake_tensor(node).shape)
    dim = _int_arg(node, 1, 0)
    dim = (dim + rank) % rank

    start = _int_arg(node, 2, 0)
    if start < 0:  # negative indices count back from the end
        start += dim_size
    start = max(0, min(start, dim_size))

    # Exported graphs often encode "to the end" as INT64_MAX; clamping
    # handles that as well as any explicit out-of-range end.
    end = _int_arg(node, 3, dim_size)
    if end < 0:
        end += dim_size
    end = max(0, min(end, dim_size))

    step = _int_arg(node, 4, 1)
    return dim, start, end, step
41+
42+
43+
# Edge-dialect slice op; every slice matched or produced by this pass uses it.
_SLICE_OP = exir_ops.edge.aten.slice_copy.Tensor
44+
45+
46+
def _is_valid_slice(node: torch.fx.Node, cat_dim: int, dim_size: int) -> bool:
    """Return True iff ``node`` is a slice_copy along ``cat_dim`` with step 1."""
    if node.target != _SLICE_OP:
        return False
    sliced_dim, _start, _end, step = _slice_params(node, dim_size)
    return (sliced_dim, step) == (cat_dim, 1)
52+
53+
54+
def _find_slice_replacement(
    slice_op: torch.fx.Node,
    cat_node: torch.fx.Node,
    cat_dim: int,
    s_start: int,
    s_end: int,
    offsets: list[tuple[int, int, torch.fx.Node]],
) -> torch.fx.Node | None:
    """Find a replacement for a slice that reads from a cat output.

    Each ``offsets`` entry ``(start, end, input_node)`` gives the range a
    concat input occupies in the concatenated output along ``cat_dim``.

    Returns the concat input itself when the slice extracts it exactly, a
    freshly inserted sub-slice of that input when the slice range lies
    entirely within it, or ``None`` when the range crosses input
    boundaries.
    """
    for lo, hi, candidate in offsets:
        # Input ranges are disjoint, so at most one entry can contain
        # the slice range; anything not contained is skipped.
        if s_start < lo or s_end > hi:
            continue
        if (s_start, s_end) == (lo, hi):
            # Exact extraction of one concat input: reuse it directly.
            return candidate
        # Strict sub-range: re-anchor the slice onto the input itself,
        # shifting start/end back by the input's offset in the output.
        fx_graph = cat_node.graph
        with fx_graph.inserting_before(slice_op):
            rebased = fx_graph.call_function(
                _SLICE_OP,
                (candidate, cat_dim, s_start - lo, s_end - lo),
            )
        rebased.meta = slice_op.meta.copy()
        return rebased
    return None
84+
85+
86+
def _find_common_slice_source(
    cat_inputs: list | tuple,
    cat_dim: int,
    dim_size: int,
) -> torch.fx.Node | None:
    """Return the single tensor all ``cat_inputs`` slice from, else None.

    Every input must be a ``slice_copy`` along ``cat_dim`` with step 1,
    and all inputs must read from the same source node.
    """
    common = None
    for candidate in cat_inputs:
        if not isinstance(candidate, torch.fx.Node):
            return None
        if not _is_valid_slice(candidate, cat_dim, dim_size):
            return None
        origin = candidate.args[0]
        if common is None:
            common = origin
            continue
        if origin is not common:
            return None
    assert isinstance(common, torch.fx.Node)
    return common
109+
110+
111+
def _check_contiguous_slices(
    cat_inputs: list | tuple,
    source_dim_size: int,
) -> tuple[int, int] | None:
    """Return ``(first_start, last_end)`` if the slices tile a contiguous range.

    Each slice must begin exactly where the previous one ended; any gap or
    overlap yields ``None``. The range need not start at 0 or cover the
    whole source dimension.
    """
    _, first_start, _, _ = _slice_params(cat_inputs[0], source_dim_size)
    cursor = first_start
    for piece in cat_inputs:
        _, piece_start, piece_end, _ = _slice_params(piece, source_dim_size)
        if piece_start != cursor:
            return None
        cursor = piece_end

    # cursor now holds the end of the final slice.
    return first_start, cursor
130+
131+
132+
class FuseConcatPass(ArmPass):
    """Eliminate redundant concat (cat) operations via graph pattern matching.

    Inspired by Espresso's concat elimination techniques
    (bolt/nn/espresso/transforms/remove_nops.py), this pass recognizes and
    removes concat operations that can be proven to produce no useful data
    movement. Eliminating these at the FX/TOSA level prevents Vela from
    generating MemoryCopy operations on the Ethos-U NPU.

    Five patterns are handled:

    1. Single-input concat: cat([x], dim) is a no-op; replace with x.
    2. Concat-then-slice (exact): if a consumer of cat([a, b, ...], dim) is
       a slice_copy that extracts exactly one original input, replace it
       with the corresponding concat input directly.
    3. Slice-then-concat (full): if cat([slice(x, d, s0, e0),
       slice(x, d, s1, e1), ...], dim) reconstructs x exactly (contiguous
       slices covering the full source dimension), replace with x.
    4. Concat-then-sub-slice: if a consumer of cat([a, b, ...], dim) is a
       slice_copy whose range falls entirely within one original input,
       replace it with an adjusted slice on that input directly.
    5. Slice-then-concat (partial): if contiguous slices of the same tensor
       are concatenated but cover only a sub-range of the source dimension,
       replace with a single slice on the source.

    """

    # No ordering constraints: nothing must run after this pass.
    _passes_required_after: Set[Type[ExportPass]] = set()

    # Concat ops this pass matches on.
    cat_ops = {
        exir_ops.edge.aten.cat.default,
    }
    # NOTE(review): mirrored as a class attribute but never read within this
    # class (the helpers use the module-level _SLICE_OP); presumably exposed
    # for external introspection — confirm before removing.
    slice_op = _SLICE_OP

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Scan for cat nodes, apply the patterns, and retrace if modified.

        Patterns are tried in order (1, then 2&4, then 3&5); the first one
        that fires moves on to the next cat node. Replaced nodes are left
        in place for ``eliminate_dead_code`` to sweep up afterwards.
        """
        modified = False
        graph = graph_module.graph

        # Snapshot the node list: pattern handlers insert/replace nodes
        # while we iterate.
        for node in list(graph.nodes):
            if node.op != "call_function" or node.target not in self.cat_ops:
                continue
            # Defensive: skip nodes already detached from the graph.
            if node.graph is None:
                continue

            if self._eliminate_single_input_cat(node):
                modified = True
                continue

            if self._eliminate_cat_then_slice(node):
                modified = True
                continue

            if self._eliminate_slice_then_cat(node):
                modified = True
                continue

        if modified:
            graph.eliminate_dead_code()
            graph_module.recompile()
            # Re-run via the parent pass to retrace/normalize the rewritten
            # graph — presumably standard ExportPass behavior; confirm.
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, modified)

    # ------------------------------------------------------------------
    # Pattern 1: single-input cat
    # ------------------------------------------------------------------
    @staticmethod
    def _eliminate_single_input_cat(cat_node: torch.fx.Node) -> bool:
        """Replace cat([x], dim) with x. Returns True if rewritten."""
        inputs = cat_node.args[0]
        if not isinstance(inputs, (list, tuple)) or len(inputs) != 1:
            return False
        sole_input = inputs[0]
        assert isinstance(sole_input, torch.fx.Node)
        cat_node.replace_all_uses_with(sole_input)
        logger.debug("Eliminated single-input cat: %s", cat_node.name)
        return True

    # ------------------------------------------------------------------
    # Patterns 2 & 4: cat -> slice (exact input or sub-range of input)
    # ------------------------------------------------------------------
    @staticmethod
    def _eliminate_cat_then_slice(
        cat_node: torch.fx.Node,
    ) -> bool:
        """Bypass the cat for consumers that slice back out of one input.

        Fires only when *every* user of the cat is a step-1 slice_copy on
        the cat dimension; users whose range crosses an input boundary are
        left on the cat. Returns True if at least one slice was redirected.
        """
        cat_inputs = cat_node.args[0]
        if not isinstance(cat_inputs, (list, tuple)) or len(cat_inputs) < 2:
            return False

        # if the dim does not exist as an arg, it defaults to '0'
        cat_dim = _int_arg(cat_node, 1, 0)
        output_rank = len(get_first_fake_tensor(cat_node).shape)
        cat_dim = (cat_dim + output_rank) % output_rank

        users = list(cat_node.users.keys())
        if not users:
            return False

        # Build the offset map for each concat input along cat_dim.
        offsets: list[tuple[int, int, torch.fx.Node]] = []
        offset = 0
        for inp in cat_inputs:
            assert isinstance(inp, torch.fx.Node)
            inp_shape = get_first_fake_tensor(inp).shape
            size = inp_shape[cat_dim]
            offsets.append((offset, offset + size, inp))
            offset += size
        # ``offset`` is now the total size of the cat output along cat_dim.

        # Every user must be a slice_copy on the same dim with step=1.
        # Collect validated (node, start, end) for replacement below.
        validated_slices: list[tuple[torch.fx.Node, int, int]] = []
        for slice_op in users:
            if not _is_valid_slice(slice_op, cat_dim, offset):
                return False
            # The cat must feed the slice's data input, not e.g. a dim arg.
            if slice_op.args[0] is not cat_node:
                return False
            _, s_start, s_end, _ = _slice_params(slice_op, offset)
            validated_slices.append((slice_op, s_start, s_end))

        # For each user, try exact match (Pattern 2) then sub-range (Pattern 4).
        # Users that cross input boundaries are skipped.
        replacements: list[tuple[torch.fx.Node, torch.fx.Node]] = []

        for slice_op, s_start, s_end in validated_slices:
            replacement = _find_slice_replacement(
                slice_op, cat_node, cat_dim, s_start, s_end, offsets
            )
            if replacement is not None:
                replacements.append((slice_op, replacement))

        if not replacements:
            return False

        for old_node, new_node in replacements:
            old_node.replace_all_uses_with(new_node)

        logger.debug(
            "Eliminated cat-then-slice pattern: %s (%d slices redirected)",
            cat_node.name,
            len(replacements),
        )
        return True

    # ------------------------------------------------------------------
    # Patterns 3 & 5: slice -> cat (contiguous slices, full or partial)
    # ------------------------------------------------------------------
    @staticmethod
    def _eliminate_slice_then_cat(
        cat_node: torch.fx.Node,
    ) -> bool:
        """Collapse a cat of contiguous slices of one tensor.

        Full coverage replaces the cat with the source tensor (Pattern 3);
        partial coverage replaces it with a single slice of the source
        (Pattern 5). Returns True if rewritten.
        """
        cat_inputs = cat_node.args[0]
        if not isinstance(cat_inputs, (list, tuple)) or len(cat_inputs) < 2:
            return False

        cat_dim = _int_arg(cat_node, 1, 0)
        output_rank = len(get_first_fake_tensor(cat_node).shape)
        cat_dim = (cat_dim + output_rank) % output_rank

        # All inputs must be slice_copy on the same source tensor and dim,
        # with step=1.
        # NOTE(review): ``output_rank`` is passed where the helper expects a
        # dimension *size*. Harmless today because validity checking only
        # reads dim/step (start/end are discarded), but worth passing the
        # real dimension size if start/end ever become meaningful here.
        source_node = _find_common_slice_source(cat_inputs, cat_dim, output_rank)
        if source_node is None:
            return False

        source_shape = get_first_fake_tensor(source_node).shape
        source_dim_size = source_shape[cat_dim]

        # Verify slices are contiguous (but not necessarily starting at 0).
        bounds = _check_contiguous_slices(cat_inputs, source_dim_size)
        if bounds is None:
            return False
        first_start, last_end = bounds

        # Verify output shape matches expectations.
        cat_shape = get_first_fake_tensor(cat_node).shape

        if first_start == 0 and last_end == source_dim_size:
            # Pattern 3: full coverage — replace with source tensor.
            if list(cat_shape) != list(source_shape):
                return False
            cat_node.replace_all_uses_with(source_node)
            logger.debug(
                "Eliminated slice-then-cat (full): %s -> %s",
                cat_node.name,
                source_node.name,
            )
        else:
            # Pattern 5: partial coverage — replace with single slice.
            expected_dim_size = last_end - first_start
            if cat_shape[cat_dim] != expected_dim_size:
                return False
            for i, (cs, ss) in enumerate(zip(cat_shape, source_shape)):
                if i != cat_dim and cs != ss:  # dims must match except for cat_dim
                    return False
            graph = cat_node.graph
            with graph.inserting_before(cat_node):
                new_slice = graph.call_function(
                    _SLICE_OP,
                    (source_node, cat_dim, first_start, last_end),
                )
            # The new slice produces exactly the cat's output, so its meta
            # (fake tensor, qparams) carries over unchanged.
            new_slice.meta = cat_node.meta.copy()
            cat_node.replace_all_uses_with(new_slice)
            logger.debug(
                "Eliminated slice-then-cat (partial): %s -> slice(%s, %d, %d:%d)",
                cat_node.name,
                source_node.name,
                cat_dim,
                first_start,
                last_end,
            )
        return True

0 commit comments

Comments
 (0)