|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# Copyright 2026 Arm Limited and/or its affiliates. |
| 4 | +# |
| 5 | +# This source code is licensed under the BSD-style license found in the |
| 6 | +# LICENSE file in the root directory of this source tree. |
| 7 | + |
| 8 | +import logging |
| 9 | +from typing import Set, Type |
| 10 | + |
| 11 | +import torch.fx |
| 12 | +from executorch.backends.arm._passes import ArmPass |
| 13 | +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor |
| 14 | +from executorch.exir.dialects._ops import ops as exir_ops |
| 15 | +from executorch.exir.pass_base import ExportPass, PassResult |
| 16 | + |
| 17 | +logger = logging.getLogger(__name__) |
| 18 | + |
| 19 | + |
| 20 | +def _int_arg(node: torch.fx.Node, index: int, default: int) -> int: |
| 21 | + """Get an integer argument from a node, with a default if missing.""" |
| 22 | + val = node.args[index] if len(node.args) > index else default |
| 23 | + assert isinstance(val, int) |
| 24 | + return val |
| 25 | + |
| 26 | + |
| 27 | +def _slice_params(node: torch.fx.Node, dim_size: int) -> tuple[int, int, int, int]: |
| 28 | + """Extract (dim, start, end, step) from a slice_copy node. |
| 29 | +
|
| 30 | + ``dim`` is normalized to a positive index. ``end`` is clamped to |
| 31 | + ``dim_size`` (the size of the source tensor along the slice dimension). |
| 32 | +
|
| 33 | + """ |
| 34 | + rank = len(get_first_fake_tensor(node).shape) |
| 35 | + dim = _int_arg(node, 1, 0) |
| 36 | + dim = (dim + rank) % rank |
| 37 | + start = _int_arg(node, 2, 0) |
| 38 | + end = min(_int_arg(node, 3, dim_size), dim_size) |
| 39 | + step = _int_arg(node, 4, 1) |
| 40 | + return dim, start, end, step |
| 41 | + |
| 42 | + |
| 43 | +_SLICE_OP = exir_ops.edge.aten.slice_copy.Tensor |
| 44 | + |
| 45 | + |
| 46 | +def _is_valid_slice(node: torch.fx.Node, cat_dim: int, dim_size: int) -> bool: |
| 47 | + """Check that node is a slice_copy on cat_dim with step=1.""" |
| 48 | + if node.target != _SLICE_OP: |
| 49 | + return False |
| 50 | + s_dim, _, _, s_step = _slice_params(node, dim_size) |
| 51 | + return s_dim == cat_dim and s_step == 1 |
| 52 | + |
| 53 | + |
| 54 | +def _find_slice_replacement( |
| 55 | + slice_op: torch.fx.Node, |
| 56 | + cat_node: torch.fx.Node, |
| 57 | + cat_dim: int, |
| 58 | + s_start: int, |
| 59 | + s_end: int, |
| 60 | + offsets: list[tuple[int, int, torch.fx.Node]], |
| 61 | +) -> torch.fx.Node | None: |
| 62 | + """Find a replacement for a slice that consumes a cat output. |
| 63 | +
|
| 64 | + ``offsets`` maps each concat input to its range in the concatenated |
| 65 | + output: [(start, end, input_node), ...] along ``cat_dim``. |
| 66 | +
|
| 67 | + Returns the replacement node (exact input match or adjusted sub-slice), |
| 68 | + or None if the slice crosses input boundaries. |
| 69 | +
|
| 70 | + """ |
| 71 | + for o_start, o_end, inp in offsets: |
| 72 | + if s_start == o_start and s_end == o_end: |
| 73 | + return inp |
| 74 | + if s_start >= o_start and s_end <= o_end: |
| 75 | + graph = cat_node.graph |
| 76 | + with graph.inserting_before(slice_op): |
| 77 | + new_slice = graph.call_function( |
| 78 | + _SLICE_OP, |
| 79 | + (inp, cat_dim, s_start - o_start, s_end - o_start), |
| 80 | + ) |
| 81 | + new_slice.meta = slice_op.meta.copy() |
| 82 | + return new_slice |
| 83 | + return None |
| 84 | + |
| 85 | + |
| 86 | +def _find_common_slice_source( |
| 87 | + cat_inputs: list | tuple, |
| 88 | + cat_dim: int, |
| 89 | + dim_size: int, |
| 90 | +) -> torch.fx.Node | None: |
| 91 | + """Check all inputs are valid slices of the same source. |
| 92 | +
|
| 93 | + Returns the source. |
| 94 | +
|
| 95 | + """ |
| 96 | + source_node = None |
| 97 | + for inp in cat_inputs: |
| 98 | + if not isinstance(inp, torch.fx.Node): |
| 99 | + return None |
| 100 | + if not _is_valid_slice(inp, cat_dim, dim_size): |
| 101 | + return None |
| 102 | + slice_source = inp.args[0] |
| 103 | + if source_node is None: |
| 104 | + source_node = slice_source |
| 105 | + elif slice_source is not source_node: |
| 106 | + return None |
| 107 | + assert isinstance(source_node, torch.fx.Node) |
| 108 | + return source_node |
| 109 | + |
| 110 | + |
| 111 | +def _check_contiguous_slices( |
| 112 | + cat_inputs: list | tuple, |
| 113 | + source_dim_size: int, |
| 114 | +) -> tuple[int, int] | None: |
| 115 | + """Check slices are contiguous. |
| 116 | +
|
| 117 | + Returns (first_start, last_end) or None. |
| 118 | +
|
| 119 | + """ |
| 120 | + _, first_start, _, _ = _slice_params(cat_inputs[0], source_dim_size) |
| 121 | + expected_start = first_start |
| 122 | + for inp in cat_inputs: |
| 123 | + _, s_start, s_end, _ = _slice_params(inp, source_dim_size) |
| 124 | + if s_start != expected_start: |
| 125 | + return None |
| 126 | + expected_start = s_end |
| 127 | + |
| 128 | + # expected_start is now the end of the last slice |
| 129 | + return first_start, expected_start |
| 130 | + |
| 131 | + |
| 132 | +class FuseConcatPass(ArmPass): |
| 133 | + """Eliminate redundant concat (cat) operations via graph pattern matching. |
| 134 | +
|
| 135 | + This pass recognizes and removes concat operations that can be proven to |
| 136 | + produce no useful data movement. Eliminating these at the FX/TOSA level |
| 137 | + prevents Vela from generating MemoryCopy operations on the Ethos-U NPU. |
| 138 | +
|
| 139 | + Five patterns are handled: |
| 140 | +
|
| 141 | + 1. Single-input concat: cat([x], dim) is a no-op; replace with x. |
| 142 | + 2. Concat-then-slice (exact): if a consumer of cat([a, b, ...], dim) is |
| 143 | + a slice_copy that extracts exactly one original input, replace it |
| 144 | + with the corresponding concat input directly. |
| 145 | + 3. Slice-then-concat (full): if cat([slice(x, d, s0, e0), |
| 146 | + slice(x, d, s1, e1), ...], dim) reconstructs x exactly (contiguous |
| 147 | + slices covering the full source dimension), replace with x. |
| 148 | + 4. Concat-then-sub-slice: if a consumer of cat([a, b, ...], dim) is a |
| 149 | + slice_copy whose range falls entirely within one original input, |
| 150 | + replace it with an adjusted slice on that input directly. |
| 151 | + 5. Slice-then-concat (partial): if contiguous slices of the same tensor |
| 152 | + are concatenated but cover only a sub-range of the source dimension, |
| 153 | + replace with a single slice on the source. |
| 154 | +
|
| 155 | + """ |
| 156 | + |
| 157 | + _passes_required_after: Set[Type[ExportPass]] = set() |
| 158 | + |
| 159 | + cat_ops = { |
| 160 | + exir_ops.edge.aten.cat.default, |
| 161 | + } |
| 162 | + slice_op = _SLICE_OP |
| 163 | + |
| 164 | + def call(self, graph_module: torch.fx.GraphModule): |
| 165 | + modified = False |
| 166 | + graph = graph_module.graph |
| 167 | + |
| 168 | + for node in list(graph.nodes): |
| 169 | + if node.op != "call_function" or node.target not in self.cat_ops: |
| 170 | + continue |
| 171 | + if node.graph is None: |
| 172 | + continue |
| 173 | + |
| 174 | + if self._eliminate_single_input_cat(node): |
| 175 | + modified = True |
| 176 | + continue |
| 177 | + |
| 178 | + if self._eliminate_cat_then_slice(node): |
| 179 | + modified = True |
| 180 | + continue |
| 181 | + |
| 182 | + if self._eliminate_slice_then_cat(node): |
| 183 | + modified = True |
| 184 | + continue |
| 185 | + |
| 186 | + if modified: |
| 187 | + graph.eliminate_dead_code() |
| 188 | + graph_module.recompile() |
| 189 | + graph_module = super().call(graph_module).graph_module |
| 190 | + |
| 191 | + return PassResult(graph_module, modified) |
| 192 | + |
| 193 | + # ------------------------------------------------------------------ |
| 194 | + # Pattern 1: single-input cat |
| 195 | + # ------------------------------------------------------------------ |
| 196 | + @staticmethod |
| 197 | + def _eliminate_single_input_cat(cat_node: torch.fx.Node) -> bool: |
| 198 | + inputs = cat_node.args[0] |
| 199 | + if not isinstance(inputs, (list, tuple)) or len(inputs) != 1: |
| 200 | + return False |
| 201 | + sole_input = inputs[0] |
| 202 | + assert isinstance(sole_input, torch.fx.Node) |
| 203 | + cat_node.replace_all_uses_with(sole_input) |
| 204 | + logger.debug("Eliminated single-input cat: %s", cat_node.name) |
| 205 | + return True |
| 206 | + |
| 207 | + # ------------------------------------------------------------------ |
| 208 | + # Patterns 2 & 4: cat -> slice (exact input or sub-range of input) |
| 209 | + # ------------------------------------------------------------------ |
| 210 | + @staticmethod |
| 211 | + def _eliminate_cat_then_slice( |
| 212 | + cat_node: torch.fx.Node, |
| 213 | + ) -> bool: |
| 214 | + cat_inputs = cat_node.args[0] |
| 215 | + if not isinstance(cat_inputs, (list, tuple)) or len(cat_inputs) < 2: |
| 216 | + return False |
| 217 | + |
| 218 | + # if the dim does not exist as an arg, it defaults to '0' |
| 219 | + cat_dim = _int_arg(cat_node, 1, 0) |
| 220 | + output_rank = len(get_first_fake_tensor(cat_node).shape) |
| 221 | + cat_dim = (cat_dim + output_rank) % output_rank |
| 222 | + |
| 223 | + users = list(cat_node.users.keys()) |
| 224 | + if not users: |
| 225 | + return False |
| 226 | + |
| 227 | + # Build the offset map for each concat input along cat_dim. |
| 228 | + offsets = [] |
| 229 | + offset = 0 |
| 230 | + for inp in cat_inputs: |
| 231 | + assert isinstance(inp, torch.fx.Node) |
| 232 | + inp_shape = get_first_fake_tensor(inp).shape |
| 233 | + size = inp_shape[cat_dim] |
| 234 | + offsets.append((offset, offset + size, inp)) |
| 235 | + offset += size |
| 236 | + |
| 237 | + # Every user must be a slice_copy on the same dim with step=1. |
| 238 | + # Collect validated (node, start, end) for replacement below. |
| 239 | + validated_slices: list[tuple[torch.fx.Node, int, int]] = [] |
| 240 | + for slice_op in users: |
| 241 | + if not _is_valid_slice(slice_op, cat_dim, offset): |
| 242 | + return False |
| 243 | + if slice_op.args[0] is not cat_node: |
| 244 | + return False |
| 245 | + _, s_start, s_end, _ = _slice_params(slice_op, offset) |
| 246 | + validated_slices.append((slice_op, s_start, s_end)) |
| 247 | + |
| 248 | + # For each user, try exact match (Pattern 2) then sub-range (Pattern 4). |
| 249 | + # Users that cross input boundaries are skipped. |
| 250 | + replacements: list[tuple[torch.fx.Node, torch.fx.Node]] = [] |
| 251 | + |
| 252 | + for slice_op, s_start, s_end in validated_slices: |
| 253 | + replacement = _find_slice_replacement( |
| 254 | + slice_op, cat_node, cat_dim, s_start, s_end, offsets |
| 255 | + ) |
| 256 | + if replacement is not None: |
| 257 | + replacements.append((slice_op, replacement)) |
| 258 | + |
| 259 | + if not replacements: |
| 260 | + return False |
| 261 | + |
| 262 | + for old_node, new_node in replacements: |
| 263 | + old_node.replace_all_uses_with(new_node) |
| 264 | + |
| 265 | + logger.debug( |
| 266 | + "Eliminated cat-then-slice pattern: %s (%d slices redirected)", |
| 267 | + cat_node.name, |
| 268 | + len(replacements), |
| 269 | + ) |
| 270 | + return True |
| 271 | + |
| 272 | + # ------------------------------------------------------------------ |
| 273 | + # Patterns 3 & 5: slice -> cat (contiguous slices, full or partial) |
| 274 | + # ------------------------------------------------------------------ |
| 275 | + @staticmethod |
| 276 | + def _eliminate_slice_then_cat( |
| 277 | + cat_node: torch.fx.Node, |
| 278 | + ) -> bool: |
| 279 | + cat_inputs = cat_node.args[0] |
| 280 | + if not isinstance(cat_inputs, (list, tuple)) or len(cat_inputs) < 2: |
| 281 | + return False |
| 282 | + |
| 283 | + cat_dim = _int_arg(cat_node, 1, 0) |
| 284 | + output_rank = len(get_first_fake_tensor(cat_node).shape) |
| 285 | + cat_dim = (cat_dim + output_rank) % output_rank |
| 286 | + |
| 287 | + # All inputs must be slice_copy on the same source tensor and dim, |
| 288 | + # with step=1. |
| 289 | + source_node = _find_common_slice_source(cat_inputs, cat_dim, output_rank) |
| 290 | + if source_node is None: |
| 291 | + return False |
| 292 | + |
| 293 | + source_shape = get_first_fake_tensor(source_node).shape |
| 294 | + source_dim_size = source_shape[cat_dim] |
| 295 | + |
| 296 | + # Verify slices are contiguous (but not necessarily starting at 0). |
| 297 | + bounds = _check_contiguous_slices(cat_inputs, source_dim_size) |
| 298 | + if bounds is None: |
| 299 | + return False |
| 300 | + first_start, last_end = bounds |
| 301 | + |
| 302 | + # Verify output shape matches expectations. |
| 303 | + cat_shape = get_first_fake_tensor(cat_node).shape |
| 304 | + |
| 305 | + if first_start == 0 and last_end == source_dim_size: |
| 306 | + # Pattern 3: full coverage — replace with source tensor. |
| 307 | + if list(cat_shape) != list(source_shape): |
| 308 | + return False |
| 309 | + cat_node.replace_all_uses_with(source_node) |
| 310 | + logger.debug( |
| 311 | + "Eliminated slice-then-cat (full): %s -> %s", |
| 312 | + cat_node.name, |
| 313 | + source_node.name, |
| 314 | + ) |
| 315 | + else: |
| 316 | + # Pattern 5: partial coverage — replace with single slice. |
| 317 | + expected_dim_size = last_end - first_start |
| 318 | + if cat_shape[cat_dim] != expected_dim_size: |
| 319 | + return False |
| 320 | + for i, (cs, ss) in enumerate(zip(cat_shape, source_shape)): |
| 321 | + if i != cat_dim and cs != ss: # dims must match except for cat_dim |
| 322 | + return False |
| 323 | + graph = cat_node.graph |
| 324 | + with graph.inserting_before(cat_node): |
| 325 | + new_slice = graph.call_function( |
| 326 | + _SLICE_OP, |
| 327 | + (source_node, cat_dim, first_start, last_end), |
| 328 | + ) |
| 329 | + new_slice.meta = cat_node.meta.copy() |
| 330 | + cat_node.replace_all_uses_with(new_slice) |
| 331 | + logger.debug( |
| 332 | + "Eliminated slice-then-cat (partial): %s -> slice(%s, %d, %d:%d)", |
| 333 | + cat_node.name, |
| 334 | + source_node.name, |
| 335 | + cat_dim, |
| 336 | + first_start, |
| 337 | + last_end, |
| 338 | + ) |
| 339 | + return True |
0 commit comments