Commit 5cc8286
Add FuseConcatPass to eliminate redundant concat ops (#18827)
Summary:
Adds an FX-level pass that eliminates concat ops which can be proven structurally redundant before the TOSA backend / Vela compiler ever sees them. In the Gen2 Executorch ARM / Ethos-U stack, `torch.cat` lowers to TOSA `CONCAT`, which Vela converts to N `MemoryCopy` ops — real DMA on the NPU. Catching the obvious cases up front keeps the TOSA flatbuffer fed to Vela smaller, keeps debug graphs honest, and provides defensive coverage on TOSA targets where Vela's own scheduler doesn't run (e.g., the VGF backend).
Five rewrite patterns are handled (inspired by Espresso's `bolt/nn/espresso/transforms/remove_nops.py`):
1. **Single-input concat**: `cat([x], dim) ≡ x` — replace cat with x.
2. **Concat-then-slice (exact)**: `cat([a, b, ...], dim)` feeding a `slice_copy` that extracts exactly one original input — replace the slice with the corresponding cat input directly.
3. **Slice-then-concat (full)**: `cat([slice(x, d, s0, e0), slice(x, d, s1, e1), ...], dim)` reconstructing x exactly (contiguous slices covering the full source dimension) — replace cat with x.
4. **Concat-then-sub-slice**: a `slice_copy` whose range falls entirely within one cat input — replace with an adjusted slice on that input directly.
5. **Slice-then-concat (partial)**: contiguous slices of the same tensor concatenated back but covering only a sub-range of the source — replace with a single slice on the source.
## Empirical impact across the production EMG model fleet
Measured by running every `frl/ctrl/torchstream/torchstream/pt2/tests/test_emg_lowering_*` quantize+lower test with FuseConcatPass instrumented to log per-call counters, then comparing against the same target with the pass commented out in `arm_pass_manager.py`. All 8 model targets pass under both configurations.
| Model | Cats scanned | Eliminated | Pattern fired |
| --- | --- | --- | --- |
| cascade_classifier | 5 | **3 (60%)** | single-input |
| mux_fusion | 8 | **3 (38%)** | single-input |
| combined_control | 11 | **3 (27%)** | single-input |
| cascade_detector | 14 | 0 | — |
| cascade_hw_classifier | 12 | 0 | — |
| handwriting | 106 | 0 | — |
| wake | 11 | 0 | — |
| auth | 6 | 0 | — |
| **Total** | **173** | **9** | all single-input |
## Two findings worth highlighting
**Patterns 2–5 (the slice-related rewrites) never matched on any production EMG model.** PyTorch's Aten lowering on this fleet doesn't produce the cat↔slice algebra these patterns target. They remain useful for non-EMG TOSA workloads — and for the VGF backend where Vela's own optimizer doesn't run — but on the current EMG production set they are unexercised.
**Vela already folds single-input cats during compilation.** A before/after measurement on cascade_classifier (the model with the highest hit rate, 3/5 cats eliminated) shows Vela emits the same 9 `MemoryCopy` ops and consumes the same 481,339 NPU cycles either way. The eliminated cats reappear in the Vela operator table as `Reshape → MemoryCopy` instead of `Concat → MemoryCopy`. Total NPU runtime is unchanged. Pre-Vela artifacts do shrink (TOSA flatbuffer −16 KB / −0.68%, peak staging −1.5 KB / −0.45%), but post-Vela on-device performance is identical.
## Net effect
This pass is value-additive even where it doesn't move NPU cycles:
- Cleaner TOSA fed into Vela (~16 KB smaller per cascade_classifier instance).
- Slightly tighter peak staging during Vela scheduling (~1.5 KB).
- Defensive coverage for TOSA-only targets without a Vela-grade scheduler (notably the VGF / Vulkan path).
- More truthful FX / EXIR debug graphs — concats that were genuinely no-ops no longer show up in `model-explorer`, `delegation_metadata.json`, or the lowered graph dumps.
It does **not** produce measurable NPU cycle savings on the current EMG production fleet. The patterns that would have produced real Vela savings (cat↔slice algebra) don't appear in these models.
Authored with Claude.
Differential Revision: D976670691 parent ada8e35 commit 5cc8286
4 files changed
Lines changed: 872 additions & 0 deletions
File tree
- backends/arm
- _passes
- test/passes
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
103 | 103 | | |
104 | 104 | | |
105 | 105 | | |
| 106 | + | |
106 | 107 | | |
107 | 108 | | |
108 | 109 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
100 | 100 | | |
101 | 101 | | |
102 | 102 | | |
| 103 | + | |
103 | 104 | | |
104 | 105 | | |
105 | 106 | | |
| |||
532 | 533 | | |
533 | 534 | | |
534 | 535 | | |
| 536 | + | |
535 | 537 | | |
536 | 538 | | |
537 | 539 | | |
| |||
0 commit comments