
Commit ca6843d

ryan-monroe authored and facebook-github-bot committed
Add FuseConcatPass to eliminate redundant concat ops (pytorch#18827)
Summary:
Concat (torch.cat) in the Gen2 ExecuTorch ARM/Ethos-U stack is lowered to TOSA CONCAT, which Vela then converts to N x MemoryCopy operations: real DMA data movement on the NPU. This pass eliminates concat operations that can be proven unnecessary at the FX graph level, preventing Vela from generating MemoryCopy ops entirely.

Inspired by Espresso's concat elimination techniques (bolt/nn/espresso/transforms/remove_nops.py), three patterns are handled:

1. Single-input concat: cat([x]) is a no-op; replace it with x.
2. Concat-then-slice: if every consumer of cat([a, b, ...]) is a slice_copy that extracts exactly one original input, bypass both.
3. Slice-then-concat: if contiguous slices of the same tensor are concatenated back together, the result is the original tensor.

Differential Revision: D97667069
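
The rewrites rely only on the value semantics of torch.cat and slicing, so each pattern can be checked numerically with plain torch. A minimal sanity-check sketch (illustrative, not part of this commit):

    import torch

    x = torch.randn(2, 8)
    a, b = x[:, :3], x[:, 3:]

    # Pattern 1: a single-input concat is the identity.
    assert torch.equal(torch.cat([x]), x)

    # Pattern 2: slicing a concat at an input boundary returns that input.
    y = torch.cat([a, b], dim=1)
    assert torch.equal(y[:, 0:3], a)
    assert torch.equal(y[:, 3:8], b)

    # Pattern 3: concatenating contiguous slices reconstructs the source.
    assert torch.equal(torch.cat([x[:, 0:3], x[:, 3:8]], dim=1), x)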
1 parent 875f7c8 commit ca6843d

4 files changed

Lines changed: 525 additions & 0 deletions


backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -102,6 +102,7 @@
     QuantizeClampArgumentsPass,
 )
 from .fuse_batch_norm2d_pass import FuseBatchNorm2dPass  # noqa
+from .fuse_concat_pass import FuseConcatPass  # noqa
 from .fuse_consecutive_concat_shapes import FuseConsecutiveConcatShapesPass  # noqa
 from .fuse_consecutive_rescales_pass import FuseConsecutiveRescalesPass  # noqa
 from .fuse_constant_ops_pass import (  # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
@@ -98,6 +98,7 @@
     DecorateFp32toInt32CastingPass,
     FoldAndAnnotateQParamsPass,
     FuseBatchNorm2dPass,
+    FuseConcatPass,
     FuseConsecutiveConcatShapesPass,
     FuseConsecutiveRescalesPass,
     FuseConstantArgsPass,
@@ -486,6 +487,7 @@ def _tosa_pipeline(
         # Aten -> TOSA transformation passes
         self.add_passes(
             [
+                FuseConcatPass(),
                 RewriteUpsamplePass(),
                 RewriteConvPass(exported_program),
                 RewriteMatmulPass(),
backends/arm/_passes/fuse_concat_pass.py

Lines changed: 329 additions & 0 deletions
@@ -0,0 +1,329 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Set, Type

import torch.fx
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

logger = logging.getLogger(__name__)


def _int_arg(node: torch.fx.Node, index: int, default: int) -> int:
    """Get an integer argument from a node, with a default if missing."""
    val = node.args[index] if len(node.args) > index else default
    assert isinstance(val, int)
    return val


def _slice_params(node: torch.fx.Node, dim_size: int) -> tuple[int, int, int, int]:
    """Extract (dim, start, end, step) from a slice_copy node.

    ``dim`` is normalized to a positive index. ``end`` is clamped to
    ``dim_size`` (the size of the source tensor along the slice dimension).
    """
    rank = len(get_first_fake_tensor(node).shape)
    dim = _int_arg(node, 1, 0)
    dim = (dim + rank) % rank
    start = _int_arg(node, 2, 0)
    end = min(_int_arg(node, 3, dim_size), dim_size)
    step = _int_arg(node, 4, 1)
    return dim, start, end, step
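
# Illustrative example (not part of the commit): exported graphs often encode
# "slice to the end" with a large sentinel end such as 2**63 - 1. For
# slice_copy.Tensor(x, -1, 2, 9223372036854775807, 1) on a rank-2 source with
# dim_size=8, _slice_params returns (1, 2, 8, 1): the negative dim normalizes
# to 1 and the sentinel end clamps to 8.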


_SLICE_OP = exir_ops.edge.aten.slice_copy.Tensor


def _is_valid_slice(node: torch.fx.Node, cat_dim: int, dim_size: int) -> bool:
    """Check that node is a slice_copy on cat_dim with step=1."""
    if node.target != _SLICE_OP:
        return False
    s_dim, _, _, s_step = _slice_params(node, dim_size)
    return s_dim == cat_dim and s_step == 1


def _find_slice_replacement(
    slice_op: torch.fx.Node,
    cat_node: torch.fx.Node,
    cat_dim: int,
    s_start: int,
    s_end: int,
    offsets: list[tuple[int, int, torch.fx.Node]],
) -> torch.fx.Node | None:
    """Find a replacement for a slice that consumes a cat output.

    ``offsets`` maps each concat input to its range in the concatenated
    output: [(start, end, input_node), ...] along ``cat_dim``.

    Returns the replacement node (exact input match or adjusted sub-slice),
    or None if the slice crosses input boundaries.
    """
    for o_start, o_end, inp in offsets:
        if s_start == o_start and s_end == o_end:
            return inp
        if s_start >= o_start and s_end <= o_end:
            graph = cat_node.graph
            with graph.inserting_before(slice_op):
                new_slice = graph.call_function(
                    _SLICE_OP,
                    (inp, cat_dim, s_start - o_start, s_end - o_start),
                )
            new_slice.meta = slice_op.meta.copy()
            return new_slice
    return None
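
# Illustrative example (not part of the commit): for cat([a, b], dim=1) with
# a: (2, 3) and b: (2, 5), offsets is [(0, 3, a), (3, 8, b)]. A slice [0:3]
# returns a directly (Pattern 2); a slice [4:7] becomes slice_copy(b, 1, 1, 4)
# (Pattern 4); a slice [2:5] crosses the input boundary and yields None.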


def _find_common_slice_source(
    cat_inputs: list | tuple,
    cat_dim: int,
    dim_size: int,
) -> torch.fx.Node | None:
    """Check all inputs are valid slices of the same source. Returns the source."""
    source_node = None
    for inp in cat_inputs:
        if not isinstance(inp, torch.fx.Node):
            return None
        if not _is_valid_slice(inp, cat_dim, dim_size):
            return None
        slice_source = inp.args[0]
        if source_node is None:
            source_node = slice_source
        elif slice_source is not source_node:
            return None
    assert isinstance(source_node, torch.fx.Node)
    return source_node


def _check_contiguous_slices(
    cat_inputs: list | tuple,
    source_dim_size: int,
) -> tuple[int, int] | None:
    """Check slices are contiguous. Returns (first_start, last_end) or None."""
    _, first_start, _, _ = _slice_params(cat_inputs[0], source_dim_size)
    expected_start = first_start
    for inp in cat_inputs:
        _, s_start, s_end, _ = _slice_params(inp, source_dim_size)
        if s_start != expected_start:
            return None
        expected_start = s_end

    # expected_start is now the end of the last slice
    return first_start, expected_start
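
# Illustrative example (not part of the commit): slices [1:3], [3:6], [6:8]
# are contiguous, so this returns (1, 8); on a size-10 source dim that is
# partial coverage (Pattern 5). Slices [0:3], [4:8] leave a gap at index 3
# and return None.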


class FuseConcatPass(ArmPass):
    """Eliminate redundant concat (cat) operations via graph pattern matching.

    Inspired by Espresso's concat elimination techniques
    (bolt/nn/espresso/transforms/remove_nops.py), this pass recognizes and
    removes concat operations that can be proven to produce no useful data
    movement. Eliminating these at the FX/TOSA level prevents Vela from
    generating MemoryCopy operations on the Ethos-U NPU.

    Five patterns are handled:

    1. Single-input concat: cat([x], dim) is a no-op; replace with x.
    2. Concat-then-slice (exact): if a consumer of cat([a, b, ...], dim) is
       a slice_copy that extracts exactly one original input, replace it
       with the corresponding concat input directly.
    3. Slice-then-concat (full): if cat([slice(x, d, s0, e0),
       slice(x, d, s1, e1), ...], dim) reconstructs x exactly (contiguous
       slices covering the full source dimension), replace with x.
    4. Concat-then-sub-slice: if a consumer of cat([a, b, ...], dim) is a
       slice_copy whose range falls entirely within one original input,
       replace it with an adjusted slice on that input directly.
    5. Slice-then-concat (partial): if contiguous slices of the same tensor
       are concatenated but cover only a sub-range of the source dimension,
       replace with a single slice on the source.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    cat_ops = {
        exir_ops.edge.aten.cat.default,
    }
    slice_op = _SLICE_OP
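
    # Illustrative before/after (not part of the commit). Pattern 2 turns
    #
    #   cat = aten.cat.default([a, b], 1)          # a: (2, 3), b: (2, 5)
    #   s0 = aten.slice_copy.Tensor(cat, 1, 0, 3)  # == a
    #   s1 = aten.slice_copy.Tensor(cat, 1, 3, 8)  # == b
    #
    # into direct uses of a and b; dead-code elimination then removes the
    # cat node itself.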

    def call(self, graph_module: torch.fx.GraphModule):
        modified = False
        graph = graph_module.graph

        for node in list(graph.nodes):
            if node.op != "call_function" or node.target not in self.cat_ops:
                continue
            if node.graph is None:
                continue

            if self._eliminate_single_input_cat(node):
                modified = True
                continue

            if self._eliminate_cat_then_slice(node):
                modified = True
                continue

            if self._eliminate_slice_then_cat(node):
                modified = True
                continue

        if modified:
            graph.eliminate_dead_code()
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, modified)

    # ------------------------------------------------------------------
    # Pattern 1: single-input cat
    # ------------------------------------------------------------------
    @staticmethod
    def _eliminate_single_input_cat(cat_node: torch.fx.Node) -> bool:
        inputs = cat_node.args[0]
        if not isinstance(inputs, (list, tuple)) or len(inputs) != 1:
            return False
        sole_input = inputs[0]
        assert isinstance(sole_input, torch.fx.Node)
        cat_node.replace_all_uses_with(sole_input)
        logger.debug("Eliminated single-input cat: %s", cat_node.name)
        return True

    # ------------------------------------------------------------------
    # Patterns 2 & 4: cat -> slice (exact input or sub-range of input)
    # ------------------------------------------------------------------
    @staticmethod
    def _eliminate_cat_then_slice(
        cat_node: torch.fx.Node,
    ) -> bool:
        cat_inputs = cat_node.args[0]
        if not isinstance(cat_inputs, (list, tuple)) or len(cat_inputs) < 2:
            return False

        # If dim is not passed as an arg, it defaults to 0.
        cat_dim = _int_arg(cat_node, 1, 0)
        output_rank = len(get_first_fake_tensor(cat_node).shape)
        cat_dim = (cat_dim + output_rank) % output_rank

        users = list(cat_node.users.keys())
        if not users:
            return False

        # Build the offset map for each concat input along cat_dim.
        offsets = []
        offset = 0
        for inp in cat_inputs:
            inp_shape = get_first_fake_tensor(inp).shape
            size = inp_shape[cat_dim]
            offsets.append((offset, offset + size, inp))
            offset += size

        # Every user must be a slice_copy on the same dim with step=1.
        # Collect validated (node, start, end) for replacement below.
        validated_slices: list[tuple[torch.fx.Node, int, int]] = []
        for slice_op in users:
            if not _is_valid_slice(slice_op, cat_dim, offset):
                return False
            if slice_op.args[0] is not cat_node:
                return False
            _, s_start, s_end, _ = _slice_params(slice_op, offset)
            validated_slices.append((slice_op, s_start, s_end))

        # For each user, try exact match (Pattern 2) then sub-range (Pattern 4).
        # Users that cross input boundaries are skipped.
        replacements: list[tuple[torch.fx.Node, torch.fx.Node]] = []

        for slice_op, s_start, s_end in validated_slices:
            replacement = _find_slice_replacement(
                slice_op, cat_node, cat_dim, s_start, s_end, offsets
            )
            if replacement is not None:
                replacements.append((slice_op, replacement))

        if not replacements:
            return False

        for old_node, new_node in replacements:
            old_node.replace_all_uses_with(new_node)

        logger.debug(
            "Eliminated cat-then-slice pattern: %s (%d slices redirected)",
            cat_node.name,
            len(replacements),
        )
        return True

    # ------------------------------------------------------------------
    # Patterns 3 & 5: slice -> cat (contiguous slices, full or partial)
    # ------------------------------------------------------------------
    @staticmethod
    def _eliminate_slice_then_cat(
        cat_node: torch.fx.Node,
    ) -> bool:
        cat_inputs = cat_node.args[0]
        if not isinstance(cat_inputs, (list, tuple)) or len(cat_inputs) < 2:
            return False

        cat_dim = _int_arg(cat_node, 1, 0)
        output_rank = len(get_first_fake_tensor(cat_node).shape)
        cat_dim = (cat_dim + output_rank) % output_rank

        # All inputs must be slice_copy on the same source tensor and dim,
        # with step=1.
        source_node = _find_common_slice_source(cat_inputs, cat_dim, output_rank)
        if source_node is None:
            return False

        source_shape = get_first_fake_tensor(source_node).shape
        source_dim_size = source_shape[cat_dim]

        # Verify slices are contiguous (but not necessarily starting at 0).
        bounds = _check_contiguous_slices(cat_inputs, source_dim_size)
        if bounds is None:
            return False
        first_start, last_end = bounds

        # Verify the output shape matches expectations.
        cat_shape = get_first_fake_tensor(cat_node).shape

        if first_start == 0 and last_end == source_dim_size:
            # Pattern 3: full coverage; replace with the source tensor.
            if list(cat_shape) != list(source_shape):
                return False
            cat_node.replace_all_uses_with(source_node)
            logger.debug(
                "Eliminated slice-then-cat (full): %s -> %s",
                cat_node.name,
                source_node.name,
            )
        else:
            # Pattern 5: partial coverage; replace with a single slice.
            expected_dim_size = last_end - first_start
            if cat_shape[cat_dim] != expected_dim_size:
                return False
            for i, (cs, ss) in enumerate(zip(cat_shape, source_shape)):
                if i != cat_dim and cs != ss:  # dims must match except cat_dim
                    return False
            graph = cat_node.graph
            with graph.inserting_before(cat_node):
                new_slice = graph.call_function(
                    _SLICE_OP,
                    (source_node, cat_dim, first_start, last_end),
                )
            new_slice.meta = cat_node.meta.copy()
            cat_node.replace_all_uses_with(new_slice)
            logger.debug(
                "Eliminated slice-then-cat (partial): %s -> slice(%s, %d, %d:%d)",
                cat_node.name,
                source_node.name,
                cat_dim,
                first_start,
                last_end,
            )
        return True
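
To see the pass fire in isolation, a minimal sketch (illustrative, not part of the commit: the Demo module and shapes are invented, and it assumes an ExecuTorch development install where ExportPass instances can be invoked directly on an edge-dialect graph module):

    import torch
    from executorch.backends.arm._passes import FuseConcatPass
    from executorch.exir import to_edge


    class Demo(torch.nn.Module):
        def forward(self, x):
            # Slice-then-concat: reassembles x along dim 1 (Pattern 3).
            return torch.cat([x[:, :3], x[:, 3:]], dim=1)


    edge = to_edge(torch.export.export(Demo(), (torch.randn(2, 8),)))
    gm = edge.exported_program().graph_module
    result = FuseConcatPass()(gm)  # ExportPass instances are callable
    print(result.modified)         # expect True: the cat was eliminated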
