
Commit 5cc8286

ryan-monroe authored; facebook-github-bot committed
Add FuseConcatPass to eliminate redundant concat ops (#18827)
Summary: Adds an FX-level pass that eliminates concat ops which can be proven structurally redundant before the TOSA backend / Vela compiler ever sees them. In the Gen2 Executorch ARM / Ethos-U stack, `torch.cat` lowers to TOSA `CONCAT`, which Vela converts to N `MemoryCopy` ops — real DMA on the NPU. Catching the obvious cases up front keeps the TOSA flatbuffer fed to Vela smaller, keeps debug graphs honest, and provides defensive coverage on TOSA targets where Vela's own scheduler doesn't run (e.g., the VGF backend).

Five rewrite patterns are handled (inspired by Espresso's `bolt/nn/espresso/transforms/remove_nops.py`):

1. **Single-input concat**: `cat([x], dim) ≡ x` — replace the cat with x.
2. **Concat-then-slice (exact)**: `cat([a, b, ...], dim)` feeding a `slice_copy` that extracts exactly one original input — replace the slice with the corresponding cat input directly.
3. **Slice-then-concat (full)**: `cat([slice(x, d, s0, e0), slice(x, d, s1, e1), ...], dim)` reconstructing x exactly (contiguous slices covering the full source dimension) — replace the cat with x.
4. **Concat-then-sub-slice**: a `slice_copy` whose range falls entirely within one cat input — replace with an adjusted slice on that input directly.
5. **Slice-then-concat (partial)**: contiguous slices of the same tensor concatenated back but covering only a sub-range of the source — replace with a single slice on the source.

## Empirical impact across the production EMG model fleet

Measured by running every `frl/ctrl/torchstream/torchstream/pt2/tests/test_emg_lowering_*` quantize+lower test with FuseConcatPass instrumented to log per-call counters, then comparing against the same target with the pass commented out in `arm_pass_manager.py`. All 8 model targets pass under both configurations.
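The concat-then-slice rewrites (patterns 2 and 4) reduce to offset arithmetic along the concat dimension. A minimal sketch of that arithmetic, with a hypothetical helper name (`redirect_slice`) that is not from the actual pass:

```python
def redirect_slice(input_sizes, start, end):
    """Given the cat inputs' sizes along the concat dim and a slice
    range [start, end) on the cat output, return
    (input_index, local_start, local_end) if the range falls entirely
    within one input, else None (no rewrite is safe)."""
    offset = 0
    for i, size in enumerate(input_sizes):
        if offset <= start and end <= offset + size:
            return (i, start - offset, end - offset)
        offset += size
    return None  # the slice straddles an input boundary; leave the graph alone


# Pattern 2 (exact): the slice extracts precisely one cat input.
assert redirect_slice([4, 3, 5], 4, 7) == (1, 0, 3)
# Pattern 4 (sub-slice): the range sits strictly inside one input.
assert redirect_slice([4, 3, 5], 8, 10) == (2, 1, 3)
# Straddling two inputs: no rewrite fires.
assert redirect_slice([4, 3, 5], 3, 6) is None
```

In the `None` case the cat genuinely materializes data from multiple inputs, which is why the pass can only claim the structurally provable cases.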
| Model | Cats scanned | Eliminated | Pattern fired |
| --- | --- | --- | --- |
| cascade_classifier | 5 | **3 (60%)** | single-input |
| mux_fusion | 8 | **3 (38%)** | single-input |
| combined_control | 11 | **3 (27%)** | single-input |
| cascade_detector | 14 | 0 | — |
| cascade_hw_classifier | 12 | 0 | — |
| handwriting | 106 | 0 | — |
| wake | 11 | 0 | — |
| auth | 6 | 0 | — |
| **Total** | **173** | **9** | all single-input |

## Two findings worth highlighting

**Patterns 2–5 (the slice-related rewrites) never matched on any production EMG model.** PyTorch's ATen lowering on this fleet doesn't produce the cat↔slice algebra these patterns target. They remain useful for non-EMG TOSA workloads — and for the VGF backend where Vela's own optimizer doesn't run — but on the current EMG production set they are unexercised.

**Vela already folds single-input cats during compilation.** A before/after measurement on cascade_classifier (the model with the highest hit rate, 3/5 cats eliminated) shows Vela emits the same 9 `MemoryCopy` ops and consumes the same 481,339 NPU cycles either way. The eliminated cats reappear in the Vela operator table as `Reshape → MemoryCopy` instead of `Concat → MemoryCopy`. Total NPU runtime is unchanged. Pre-Vela artifacts do shrink (TOSA flatbuffer −16 KB / −0.68%, peak staging −1.5 KB / −0.45%), but post-Vela on-device performance is identical.

## Net effect

This pass is value-additive even where it doesn't move NPU cycles:

- Cleaner TOSA fed into Vela (~16 KB smaller per cascade_classifier instance).
- Slightly tighter peak staging during Vela scheduling (~1.5 KB).
- Defensive coverage for TOSA-only targets without a Vela-grade scheduler (notably the VGF / Vulkan path).
- More truthful FX / EXIR debug graphs — concats that were genuinely no-ops no longer show up in `model-explorer`, `delegation_metadata.json`, or the lowered graph dumps.

It does **not** produce measurable NPU cycle savings on the current EMG production fleet.
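The slice-then-concat rewrites (patterns 3 and 5) hinge on one check: do the slices of the same source, in cat order, tile a contiguous range of the source dimension? A minimal sketch under that assumption, with an illustrative helper name (`fold_slice_concat`) not taken from the actual pass:

```python
def fold_slice_concat(dim_size, ranges):
    """ranges: list of (start, end) slices of the same source tensor,
    in cat order along the sliced dim. Returns 'full' if they
    reconstruct the source exactly (pattern 3), a merged (start, end)
    if they cover one contiguous sub-range (pattern 5), else None."""
    # Each slice must begin exactly where the previous one ended.
    for (_, e0), (s1, _) in zip(ranges, ranges[1:]):
        if e0 != s1:
            return None
    start, end = ranges[0][0], ranges[-1][1]
    if start == 0 and end == dim_size:
        return "full"       # pattern 3: replace the cat with the source itself
    return (start, end)     # pattern 5: replace the cat with a single slice


# Full reconstruction of a dim of size 10: the cat is the source.
assert fold_slice_concat(10, [(0, 4), (4, 10)]) == "full"
# Contiguous sub-range: collapse to one slice(source, 2, 8).
assert fold_slice_concat(10, [(2, 5), (5, 8)]) == (2, 8)
# Gap between slices: the cat is doing real work, no rewrite.
assert fold_slice_concat(10, [(0, 4), (5, 10)]) is None
```

A real pass would additionally verify that every slice node reads the same source tensor and the same dim as the cat, and that no slice output has other users.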
The patterns that would have produced real Vela savings (cat↔slice algebra) don't appear in these models. Authored with Claude. Differential Revision: D97667069
1 parent ada8e35

4 files changed: 872 additions & 0 deletions

backends/arm/_passes/__init__.py (1 addition & 0 deletions)

```diff
@@ -103,6 +103,7 @@
     QuantizeClampArgumentsPass,
 )
 from .fuse_batch_norm2d_pass import FuseBatchNorm2dPass  # noqa
+from .fuse_concat_pass import FuseConcatPass  # noqa
 from .fuse_consecutive_concat_shapes import FuseConsecutiveConcatShapesPass  # noqa
 from .fuse_consecutive_rescales_pass import FuseConsecutiveRescalesPass  # noqa
 from .fuse_constant_ops_pass import (  # noqa
```

backends/arm/_passes/arm_pass_manager.py (2 additions & 0 deletions)

```diff
@@ -100,6 +100,7 @@
     EnsureUniqueOutputNodesPass,
     FoldAndAnnotateQParamsPass,
     FuseBatchNorm2dPass,
+    FuseConcatPass,
     FuseConsecutiveConcatShapesPass,
     FuseConsecutiveRescalesPass,
     FuseConstantArgsPass,
@@ -532,6 +533,7 @@ def _tosa_pipeline(
         # Aten -> TOSA transformation passes
         self.add_passes(
             [
+                FuseConcatPass(),
                 RewriteUpsamplePass(),
                 RewriteMaxPool2dPass(),
                 RewriteConvPass(exported_program),
```
