
Add FuseTosaTransposesPass with elementwise propagation #18947

Open
Ninja91 wants to merge 1 commit into pytorch:main from Ninja91:export-D92901685

Conversation

Contributor

@Ninja91 Ninja91 commented Apr 16, 2026

Summary:
Part of the Ethos-U55/U85 optimization stack targeting a 22.6% NPU cycle reduction on Wake EMG.

Adds FuseTosaTransposesPass that eliminates redundant TOSA TRANSPOSE operations through four optimizations:

  1. Identity elimination — remove TRANSPOSE with identity permutation [0,1,2,3]
  2. Inverse-pair cancellation — remove TRANSPOSE→TRANSPOSE pairs that compose to identity
  3. Composition — fuse consecutive non-inverse TRANSPOSEs into a single TRANSPOSE
  4. Propagation — move TRANSPOSE through layout-agnostic ops (RESCALE, elementwise) to enable more cancellations
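The three permutation-level rewrites above reduce to simple list arithmetic. A minimal sketch (helper names are hypothetical illustrations, not the PR's actual API):

```python
def is_identity(perm: list[int]) -> bool:
    # A TRANSPOSE with permutation [0, 1, 2, 3] is a no-op and can be removed.
    return perm == list(range(len(perm)))

def invert(perm: list[int]) -> list[int]:
    # inv[perm[i]] = i; TRANSPOSE(p) followed by TRANSPOSE(invert(p)) cancels.
    inv = [0] * len(perm)
    for i, p in enumerate(perm):
        inv[p] = i
    return inv

def compose(first: list[int], second: list[int]) -> list[int]:
    # Applying TRANSPOSE(first) then TRANSPOSE(second) equals a single
    # TRANSPOSE with the composed permutation (torch.permute semantics).
    return [first[s] for s in second]

# NHWC <-> NCHW round trips compose to identity, so the pair cancels.
nchw_to_nhwc = [0, 2, 3, 1]
assert is_identity(compose(nchw_to_nhwc, invert(nchw_to_nhwc)))
```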

The propagation pattern handles the common case where ToTosaMemoryFormatPass inserts TRANSPOSE pairs around view_copy rank boundaries, with RESCALE and elementwise ops in between:
TRANSPOSE(p) → RESCALE → relu → RESCALE → TRANSPOSE(inv(p)) → RESCALE → relu → RESCALE

For binary elementwise ops (ADD, MUL, SUB), propagation is safe only when the non-primary operand is broadcast-safe (scalar or 1-element tensor).
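The broadcast-safety condition can be sketched as a single predicate over an operand's shape (a simplified illustration; the real pass inspects FX node metadata rather than raw shapes):

```python
import math

def is_broadcast_safe(shape: tuple[int, ...]) -> bool:
    # An operand is broadcast-safe if it is a scalar (empty shape) or a
    # 1-element tensor: its value is invariant under any permutation of the
    # other operand's layout, so moving the TRANSPOSE past the binary op
    # does not change the result.
    return math.prod(shape) == 1

assert is_broadcast_safe(())            # scalar
assert is_broadcast_safe((1, 1, 1, 1))  # 1-element tensor
assert not is_broadcast_safe((1, 16, 8, 8))
```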

Impact

Combined with FuseConsecutiveRescalesPass (next diff), reduces total NPU cycles on Wake/U55 by 22.6%. TRANSPOSE elimination directly reduces TRANSPOSE HW ops and enables the RESCALE fusion pass to find more fusible pairs.

Reviewed By: davidxili

Differential Revision: D92901685

@Ninja91 Ninja91 requested a review from digantdesai as a code owner April 16, 2026 16:37
Copilot AI review requested due to automatic review settings April 16, 2026 16:37
@pytorch-bot

pytorch-bot Bot commented Apr 16, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18947

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 4 New Failures, 2 Unrelated Failures

As of commit aff53bf with merge base a489707:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 16, 2026
@meta-codesync
Contributor

meta-codesync Bot commented Apr 16, 2026

@Ninja91 has exported this pull request. If you are a Meta employee, you can view the originating Diff in D92901685.

@github-actions

This PR needs a `release notes:` label

If your change should be included in the release notes (i.e., would users of this library care about this change?), please use a label starting with `release notes:`. This helps us keep track of and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Contributor

Copilot AI left a comment


Pull request overview

Introduces a new Arm backend optimization pass to eliminate redundant TOSA TRANSPOSE operations (including propagation through elementwise ops) as part of the Ethos-U55/U85 optimization stack, and adds targeted tests/scripts to validate transpose reduction behavior.

Changes:

  • Add FuseTosaTransposesPass implementing identity elimination, inverse-pair cancellation, composition, and elementwise propagation.
  • Integrate the new pass into the Arm TOSA pipeline immediately after ToTosaMemoryFormatPass.
  • Add unit tests for common patterns (conv chains, pooling, fan-out) plus propagation-through-elementwise cases, and a standalone comparison script.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| `backends/arm/_passes/fuse_tosa_transposes_pass.py` | New optimization pass to remove/fuse/cancel TOSA transposes, including propagation through elementwise ops. |
| `backends/arm/_passes/arm_pass_manager.py` | Wires FuseTosaTransposesPass into the standard Arm TOSA pipeline after memory-format transposes are inserted. |
| `backends/arm/_passes/__init__.py` | Exposes FuseTosaTransposesPass via the Arm passes package exports. |
| `backends/arm/test/passes/test_fuse_tosa_transposes.py` | Adds unit tests validating transpose counts and functional correctness, including propagation cases. |
| `backends/arm/test/passes/fuse_tosa_transposes_comparison.py` | Adds a runnable script to compare transpose counts pre/post optimization on representative models. |


```python
            f"(eliminated {before_count - after_count}), iterations={iteration}"
        )

        return PassResult(graph_module, modified_overall)
```

Copilot AI Apr 16, 2026


FuseTosaTransposesPass.call() rewires/removes TRANSPOSE nodes in a way that can change intermediate tensor shapes (especially for the propagation pattern), but it never calls super().call(graph_module) to retrace and refresh meta['val'] / fake-tensor metadata. Several downstream Arm passes rely on get_first_fake_tensor(node) and meta['val'] being accurate; leaving stale metadata here can break later passes or lead to incorrect lowering. After graph mutations + recompile, invoke super().call(graph_module) (like ToTosaMemoryFormatPass/BroadcastArgsPass do) or otherwise re-run the metadata propagation.

Suggested change:

```diff
-        return PassResult(graph_module, modified_overall)
+        refreshed_result = super().call(graph_module)
+        return PassResult(refreshed_result.graph_module, modified_overall)
```

Comment on lines +236 to +261
```python
def test_identity_transpose_elimination() -> None:
    """
    Test that identity transposes are eliminated.
    Uses a simple pass-through module.
    """

    class IdentityModule(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x

        def get_inputs(self) -> input_t:
            return (torch.rand(1, 16, 8, 8),)

    module = IdentityModule()
    pipeline = PassPipeline[input_t](
        module,
        module.get_inputs(),
        pass_list=[RemoveGetItemPass, AnnotateOutputDimOrderPass],
        passes_with_exported_program=[
            ToTosaMemoryFormatPass,
            FuseTosaTransposesPass,
        ],
    )
    pipeline.pop_stage("run_method_and_compare_outputs")
    pipeline.run()
```


Copilot AI Apr 16, 2026


test_identity_transpose_elimination() claims to verify identity transpose removal, but the pipeline is created without ops_before/ops_after assertions (and no other explicit checks). As written, this test only verifies the pass pipeline runs, not that identity transposes were eliminated. Consider adding an explicit check (e.g., expected TRANSPOSE count or verifying no identity-permutation TRANSPOSE nodes remain) so the test fails if the optimization regresses.
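A regression check along these lines could count TRANSPOSE nodes in the resulting graph and assert on the count (a sketch; `count_transposes` is an illustrative helper, duck-typed over FX-like graph modules, not the repo's exact utility):

```python
def count_transposes(graph_module) -> int:
    # Count call_function nodes whose target names a TOSA TRANSPOSE.
    # Duck-typed: works on anything exposing graph_module.graph.nodes
    # where each node has .op and .target, like torch.fx.GraphModule.
    return sum(
        1
        for node in graph_module.graph.nodes
        if node.op == "call_function" and "TRANSPOSE" in str(node.target)
    )
```

The test could then assert, for example, `count_transposes(result.graph_module) == 0` after the pass runs, so a regression in identity elimination fails loudly.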

Comment on lines +95 to +153
```python
def run_comparison(
    model: nn.Module,
    inputs: Tuple[torch.Tensor, ...],
    model_name: str,
) -> Dict[str, int]:
    """
    Run comparison of TRANSPOSE counts with and without FuseTosaTransposesPass.
    """
    print(f"\n{'='*60}")
    print(f"Testing: {model_name}")
    print(f"{'='*60}")

    # Run pipeline WITHOUT FuseTosaTransposesPass (baseline)
    print("\n[1] Running WITHOUT FuseTosaTransposesPass...")
    pipeline_baseline = PassPipeline(
        model,
        inputs,
        pass_list=[RemoveGetItemPass, AnnotateOutputDimOrderPass],
        passes_with_exported_program=[ToTosaMemoryFormatPass],
    )
    pipeline_baseline.pop_stage("run_method_and_compare_outputs")
    result_baseline = pipeline_baseline.run()

    baseline_count = count_transposes(result_baseline.graph_module)
    print(f"    TRANSPOSE ops (baseline): {baseline_count}")

    # Run pipeline WITH FuseTosaTransposesPass (optimized)
    print("\n[2] Running WITH FuseTosaTransposesPass...")
    pipeline_optimized = PassPipeline(
        model,
        inputs,
        pass_list=[RemoveGetItemPass, AnnotateOutputDimOrderPass],
        passes_with_exported_program=[
            ToTosaMemoryFormatPass,
            FuseTosaTransposesPass,
        ],
    )
    pipeline_optimized.pop_stage("run_method_and_compare_outputs")
    result_optimized = pipeline_optimized.run()

    optimized_count = count_transposes(result_optimized.graph_module)
    print(f"    TRANSPOSE ops (optimized): {optimized_count}")

    # Calculate reduction
    reduction = baseline_count - optimized_count
    reduction_pct = (reduction / baseline_count * 100) if baseline_count > 0 else 0

    print("\n[3] Results Summary:")
    print(f"    Baseline:  {baseline_count} TRANSPOSE ops")
    print(f"    Optimized: {optimized_count} TRANSPOSE ops")
    print(f"    Reduction: {reduction} ops ({reduction_pct:.1f}%)")

    return {
        "model": model_name,
        "baseline": baseline_count,
        "optimized": optimized_count,
        "reduction": reduction,
        "reduction_pct": reduction_pct,
    }
```

Copilot AI Apr 16, 2026


run_comparison() is annotated to return Dict[str, int], but the returned dict includes a string value ('model') and a float ('reduction_pct'). With pyre-strict enabled, this should be a type error. Use a TypedDict/dataclass for the result, or widen the return type (e.g., Dict[str, object] or dict[str, int | float | str]).
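One way to satisfy pyre-strict here is a TypedDict, as the comment suggests (a sketch; the field names mirror the script's dict keys, and `make_result` is a hypothetical helper):

```python
from typing import TypedDict

class ComparisonResult(TypedDict):
    model: str
    baseline: int
    optimized: int
    reduction: int
    reduction_pct: float

def make_result(model_name: str, baseline: int, optimized: int) -> ComparisonResult:
    # Each field now has a precise type instead of forcing Dict[str, int]
    # to hold str and float values.
    reduction = baseline - optimized
    pct = (reduction / baseline * 100) if baseline > 0 else 0.0
    return ComparisonResult(
        model=model_name,
        baseline=baseline,
        optimized=optimized,
        reduction=reduction,
        reduction_pct=pct,
    )
```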

Comment on lines +52 to +64
```python
def _target_name(target: object) -> str:
    """Extract a recognizable name from a node target for string matching."""
    name = str(target)
    # Handle exir_ops.backend.tosa.RESCALE.default → "RESCALE"
    # Handle exir_ops.edge.aten.add.Tensor → "add.Tensor"
    parts = name.rsplit(".", 2)
    if len(parts) >= 2:
        # For "backend__ops_tosa_RESCALE_default" patterns
        if "RESCALE" in name:
            return "RESCALE"
        # Return the last two parts for ATen ops: "add.Tensor", "clamp.default", etc.
        return ".".join(parts[-2:])
    return name
```

Copilot AI Apr 16, 2026


_target_name() uses str(target), but EdgeOpOverload/BackendOpOverload str includes schema text (see exir/dialects/edge/_ops.py), so the rsplit/join logic returns strings like "add.Tensor>: schema = ..." that will never match entries in _UNARY_ELEMENTWISE_TARGET_NAMES/_BINARY_ELEMENTWISE_TARGET_NAMES. This effectively disables elementwise propagation for edge ops (and likely makes the new propagation tests fail). Prefer using target.name (when present) and parsing that, or otherwise strip the schema suffix before matching.

@AdrianLundell
Collaborator

Thanks for the PR! However, tosa.TRANSPOSE is in the process of being deprecated in favor of using regular permutes consistently; see #18948. That rework should solve the same problems you are fixing here.


Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported meta-exported


3 participants