Arm backend: Remove use of tosa_dim_order #18948

Merged

AdrianLundell merged 4 commits into pytorch:main from AdrianLundell:change-1241864 on Apr 27, 2026

Conversation

AdrianLundell (Collaborator) commented Apr 16, 2026

This patch removes the use of the ToTosaMemoryFormatPass and
tosa_dim_order in favor of serializing directly to the
contiguous stride shape.

To allow this, all channels-last TOSA dialect ops are modified to be
explicitly channels last and lowered with permutes, e.g.
(1,2,3,3) -> aten conv -> (1,4,3,3)
lowers to
(1,2,3,3) -> (1,3,3,2) -> tosa conv -> (1,3,3,4) -> (1,4,3,3)

To handle channels-last input/output, NormalizeDelegateIOLayoutPass
permutes the shape of such inputs and outputs and inserts a permute
to force them to be contiguous. This permute will then typically
cancel out the top convolution permute, enabling zero-transpose
graphs in many cases.
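
As a rough eager-mode illustration of this lowering (a sketch only: the NHWC
TOSA conv is emulated here with permutes around a regular NCHW conv, since
eager PyTorch convolutions are NCHW):

    import torch

    def lowered_conv_nchw(x, weight):
        # Permute sandwich: the TOSA conv consumes/produces NHWC, so the
        # NCHW aten conv becomes permute -> conv -> permute.
        x_nhwc = x.permute(0, 2, 3, 1)             # (N,C,H,W) -> (N,H,W,C)
        # Stand-in for the NHWC TOSA CONV2D.
        y_nhwc = torch.nn.functional.conv2d(
            x_nhwc.permute(0, 3, 1, 2), weight
        ).permute(0, 2, 3, 1)
        return y_nhwc.permute(0, 3, 1, 2)          # (N,H,W,C) -> (N,C,H,W)

    x = torch.randn(1, 2, 3, 3)
    w = torch.randn(4, 2, 1, 1)                    # 1x1 kernel keeps H and W
    assert lowered_conv_nchw(x, w).shape == (1, 4, 3, 3)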

Additionally,

  • Conv1D input is unsqueezed (N,C,L)->(N,C,1,L) instead of (N,C,L,1)
    to better match avg_pool1d and max_pool2d shapes (see the shape sketch
    below the list)
  • Biases for int16 convolutions are handled by lowering the bias to an
    int48 tensor rather than using a separate add operator.
  • Fix remove_permutes_around_elementwise_ops for permutes with negative
    indices.
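
A quick shape check of the new unsqueeze placement (shapes only; the actual
pass rewrites FX nodes):

    import torch

    x = torch.randn(2, 8, 16)                      # (N, C, L)
    assert x.unsqueeze(2).shape == (2, 8, 1, 16)   # new: (N, C, 1, L)
    assert x.unsqueeze(3).shape == (2, 8, 16, 1)   # old: (N, C, L, 1)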

Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>


pytorch-bot Bot commented Apr 16, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18948

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 2 New Failures, 4 Unrelated Failures

As of commit bac92e5 with merge base 32a6cec:

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@AdrianLundell AdrianLundell added the partner: arm For backend delegation, kernels, demo, etc. from the 3rd-party partner, Arm label Apr 16, 2026
github-actions Bot commented:

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Comment thread backends/arm/_passes/to_tosa_memory_format_pass.py Outdated
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 16, 2026
Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
Change-Id: I7584e519cd876d63f1cafe98710f1c7fb0378581
@github-actions github-actions Bot added ciflow/trunk module: arm Issues related to arm backend labels Apr 22, 2026
@AdrianLundell AdrianLundell marked this pull request as ready for review April 22, 2026 14:36
Copilot AI review requested due to automatic review settings April 22, 2026 14:36
@AdrianLundell AdrianLundell changed the title [WIP] Arm backend: Remove use of tosa_dim_order Arm backend: Remove use of tosa_dim_order Apr 22, 2026
Copilot AI (Contributor) left a comment
Pull request overview

This PR removes reliance on ToTosaMemoryFormatPass / tosa_dim_order for Arm TOSA lowering, moving to an explicit-layout strategy where channels-last lowering is expressed via permute_copy nodes and TOSA fake ops operate in NHWC/NDHWC shapes.

Changes:

  • Stop propagating/consuming tosa_dim_order metadata and instead lower layout changes via explicit pre/post permutes.
  • Update multiple TOSA dialect fake ops (conv/pool/resize) to compute shapes in NHWC/NDHWC and adjust conv weight formats accordingly.
  • Add/adjust passes and tests to normalize delegate I/O layouts, rewrite pooling/upsample with explicit permutes, and validate int16/int48 bias handling.

Reviewed changes

Copilot reviewed 32 out of 32 changed files in this pull request and generated 6 comments.

Summary per file:

backends/transforms/remove_permutes_around_elementwise_ops.py - Makes permute parsing more robust (normalization/validation).
backends/arm/tosa/utils.py - tosa_shape now resolves SymInts without reordering (keeps API arg).
backends/arm/tosa/mapping.py - Stops using meta["tosa_dim_order"], defaults to identity dim order.
backends/arm/tosa/dialect/ops/transpose_conv2d.py - Updates transpose-conv output shape logic to NHWC + new weight format.
backends/arm/tosa/dialect/ops/resize.py - Updates resize fake shape inference to NHWC.
backends/arm/tosa/dialect/ops/max_pool2d.py - Updates max-pool fake shape inference to NHWC.
backends/arm/tosa/dialect/ops/depthwise_conv2d.py - Updates depthwise-conv fake shape inference + weight format assumptions.
backends/arm/tosa/dialect/ops/conv3d.py - Updates conv3d fake shape inference to NDHWC + new weight layout.
backends/arm/tosa/dialect/ops/conv2d.py - Updates conv2d fake shape inference to NHWC + new weight layout.
backends/arm/tosa/dialect/ops/avg_pool2d.py - Updates avg-pool fake shape inference to NHWC and fixes pad variable usage.
backends/arm/tosa/backend.py - Removes fake-tensor memory_format remapping based on tosa_dim_order.
backends/arm/test/passes/test_rewrite_conv_pass.py - Adds int16/a16w8 bias lowering tests and INT48 bias assertions.
backends/arm/test/passes/test_rewrite_avg_pool2d_pass.py - New tests validating avg_pool2d rewrite to TOSA + permutes.
backends/arm/test/ops/test_unfold_copy.py - Removes xfails for negative-dim unfold_copy cases.
backends/arm/test/ops/test_pow.py - Adjusts xfail sets used for FP pow tests.
backends/arm/test/ops/test_conv2d.py - Tweaks U85 a16w8 conv2d test pipeline configuration.
backends/arm/test/misc/test_transpose_counts.py - Updates expected transpose counts and related model naming.
backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py - Updates expected shapes for DW conv2d dialect tests.
backends/arm/test/misc/test_tosa_dialect_conv2d.py - Updates expected shapes for conv2d dialect tests.
backends/arm/test/misc/test_process_node.py - New regression test for INT48 const normalization during serialization.
backends/arm/test/misc/test_high_rank_permute_view_invariants.py - Updates expected transpose counts in randomized permute/view chains.
backends/arm/test/misc/test_const_shape.py - Removes tests tied to ToTosaMemoryFormatPass / tosa_dim_order propagation.
backends/arm/process_node.py - Normalizes INT48 const storage to np.int64 before serializer addConst.
backends/arm/operators/op_tosa_shapes.py - Uses updated tosa_shape behavior (no dim-order reordering).
backends/arm/_passes/rewrite_upsample.py - Lowers upsample via NHWC pre/post permute_copy, with optional RESCALE.
backends/arm/_passes/rewrite_max_pool2d_pass.py - Inserts NHWC pre/post permute_copy around TOSA MAX_POOL2D.
backends/arm/_passes/rewrite_conv_pass.py - Rewrites convs using explicit permutes + per-conv weight placeholder rewrites.
backends/arm/_passes/rewrite_avg_pool2d_pass.py - Inserts NHWC pre/post permute_copy around TOSA AVG_POOL2D.
backends/arm/_passes/normalize_delegate_io_layout_pass.py - New pass to adjust delegated I/O boundary shapes and insert permutes.
backends/arm/_passes/conv1d_unsqueeze_pass.py - Adjusts conv1d→conv2d unsqueeze/squeeze shapes for new layout strategy.
backends/arm/_passes/arm_pass_manager.py - Removes ToTosaMemoryFormatPass, adds IO layout normalization + permute optimizations.
backends/arm/_passes/__init__.py - Exposes NormalizeDelegateIOLayoutPass from the pass package.


Comment thread backends/arm/test/ops/test_conv2d.py Outdated
Comment thread backends/arm/test/misc/test_transpose_counts.py
Comment thread backends/arm/test/misc/test_transpose_counts.py Outdated
Comment on lines 35 to 43
 x = args[0]
-pad_h, pad_w = args[3]
+pad_h, pad_w = args[3] if len(args) > 3 else (0, 0)
 # Make sure pad corresponds to TOSA
 pad = [pad_h, pad_w, pad_h, pad_w]

 _, _, h, w = x.data.shape
 kernel_h, kernel_w = args[1]
-stride_h, stride_w = args[2]
+stride_h, stride_w = args[2] if len(args) > 2 else (1, 1)

 kernel = _to_2tuple(args[1])

-if len(args) > 2 and args[2] is not None and len(args[2]) > 0:
+if len(args) > 2 and args[2] is not None:
Comment on lines 226 to +246
with graph_module.graph.inserting_after(tosa_resize_node):
    rescale_node = create_node(
        graph_module.graph,
        exir_ops.backend.tosa.RESCALE.default,
    )
    tosa_resize_node.replace_all_uses_with(rescale_node)
    if input_dtype == torch.int16:
        tosa_resize_node.meta[TosaSpecialDtype.meta_key()] = (
            TosaSpecialDtype.INT48
        )

    rescale_node.args = (
        tosa_resize_node,
        output_dtype,
        [output_scale],
        0,  # zero point
        0,  # zero point
    )
node_replacement = rescale_node
node_replacement_fake = exir_ops.backend.tosa.RESCALE.default(
    tosa_resize_node.meta["val"], output_dtype, [output_scale], 0, 0
)
@digantdesai (Contributor) commented:

@claude review this PR.

claude Bot commented Apr 22, 2026

Claude finished @digantdesai's task in 2m 19s (View job)


  • Read and understand core architectural changes
  • Review NormalizeDelegateIOLayoutPass (new pass)
  • Review rewrite passes (conv, avg_pool2d, max_pool2d, upsample)
  • Review TOSA dialect op shape changes
  • Review remove_permutes_around_elementwise_ops fix
  • Review INT48 bias handling in process_node.py
  • Review conv1d unsqueeze dimension changes
  • Review test changes
  • Post review summary

Review Summary

This PR replaces the implicit tosa_dim_order metadata approach with explicit permute_copy nodes around channels-last TOSA ops. The approach is sound — explicit permutes are more transparent, compose better with graph optimizations (e.g. RemovePermutesAroundElementwiseOps can cancel them), and NormalizeDelegateIOLayoutPass handles the boundary correctly so channels-last I/O produces zero-transpose graphs in common cases.

Overall this is a well-structured change. Below are specific issues I found, roughly sorted by severity.


Issues

1. Dead code in RewriteConvPass — unreachable else branch for int16

rewrite_with_explicit_permute is now set to True in every code path (transposed, conv3d, depthwise, conv2d). The else branch at line 715 is unreachable:

# rewrite_conv_pass.py:700-727
if rewrite_with_explicit_permute:        # always True
    ...
else:                                    # <-- dead code
    has_bias = len(node.meta["input_qparams"]) > 2
    ...

This dead code still references the old bias-as-separate-add pattern for int16 convolutions. It should be removed for clarity, since the bias-to-INT48-tensor approach is now the only path.

2. Dead null checks — pre_permute_dims is None after assignment

In multiple branches of call(), pre_permute_dims is checked for None immediately after being set to a tuple literal:

# rewrite_conv_pass.py:532-535
pre_permute_dims = NHWC_ORDER
post_permute_dims = NHWC_INVERSE_ORDER
if pre_permute_dims is None:            # <-- always False
    raise RuntimeError(...)

This pattern repeats at lines 583-586, 609-612, and 637-640. These were likely useful when the assignment was conditional, but now they are unreachable assertions.

3. RewriteAvgPool2dPass — fragile stride/padding defaults

rewrite_avg_pool2d_pass.py:42:

stride_h, stride_w = args[2] if len(args) > 2 else (1, 1)

Per the Edge schema, avg_pool2d can pass an empty list [] as the stride argument (meaning "use kernel_size"). Unpacking [] will raise ValueError. The padding arg at line 36 is similarly fragile — a scalar 0 would fail unpacking. Consider normalizing like:

stride = args[2] if len(args) > 2 else []
if not stride:
    stride = args[1]  # default to kernel_size
stride_h, stride_w = stride

The test_rewrite_avg_pool2d_pass.py test for AvgPool2dWithoutStride does not exercise the stride=[] edge case because avg_pool2d(x, kernel_size=3) passes stride as [3, 3] in the Edge IR after decomposition. However, if a caller provides stride=[] directly, this will fail at runtime.

4. RewriteMaxPool2dPass — same empty-stride concern

rewrite_max_pool2d_pass.py:39-42:

if len(args) > 2 and args[2] is not None:
    stride = _to_2tuple(args[2])
else:
    stride = kernel

This handles None but not []. If args[2] is [], _to_2tuple([]) returns (), and subsequent indexing stride[0] will raise IndexError. The previous implementation (based on len(args[2]) > 0) correctly guarded against this. Copilot's review flagged this as well.
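
A guard covering both cases, reusing the names from the snippet above (a
sketch, not the pass's actual code):

    if len(args) > 2 and args[2] is not None and len(args[2]) > 0:
        stride = _to_2tuple(args[2])
    else:
        stride = kernel  # empty or missing stride falls back to kernel size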

5. RewriteUpsamplePassrescale_node missing metadata

rewrite_upsample.py:227-230: The rescale_node is created via create_node(graph, op_target) without from_node or meta['val']:

rescale_node = create_node(
    graph_module.graph,
    exir_ops.backend.tosa.RESCALE.default,
)

Unlike the conv pass which uses from_node=node to propagate debug metadata and tracing info, this node is created bare. While node_replacement_fake is computed, it's never assigned to rescale_node.meta["val"] — only used to compute post_permute.meta["val"]. This means subsequent passes that inspect rescale_node.meta["val"] (e.g. for shape inference) would find stale or missing data. Consider:

rescale_node = create_node(
    graph_module.graph,
    exir_ops.backend.tosa.RESCALE.default,
    from_node=node,
)
rescale_node.meta["val"] = node_replacement_fake

6. Stale comments in test_transpose_counts.py

Lines 489, 509, 529: Several test cases retain inline comments like "The test crashes before reaching the transpose count" but the cases now have concrete expected counts and are not in xfails:

"matmul_rank4_channels_last": TransposeCountCase(
    ...
    2,  # The test crashes before reaching the transpose count   <-- stale
),

These comments should be removed or updated to reflect the current behavior.

7. test_conv2d.pycustom_path="TEMP" in CI

# test_conv2d.py:761
custom_path="TEMP"

Setting custom_path="TEMP" creates artifacts in a fixed on-disk location, which can cause test pollution and flaky failures in parallel CI runs. Prefer the default collate path or a per-test temporary directory.
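
One common alternative is pytest's tmp_path fixture (the test name and
pipeline wiring below are illustrative, not the real test's code):

    def test_conv2d_a16w8_u85(tmp_path):
        # Per-test temporary directory avoids cross-run artifact pollution.
        artifact_dir = tmp_path / "artifacts"
        artifact_dir.mkdir()
        # ... construct the pipeline with custom_path=str(artifact_dir) ...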


Observations (non-blocking)

Weight rewrite strategy: The switch from mutating state_dict in-place (_reshape_weights) to creating new constant placeholders per conv node (_rewrite_conv2d_weight_to_ohwi, etc.) is cleaner and avoids aliasing issues when the same weight is shared across multiple conv ops. The TosaSpecialDtype metadata is properly propagated to the new weight nodes.
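For reference, the OIHW-to-OHWI weight rewrite amounts to a single dimension
permute (a standalone illustration, not the pass's code):

    import torch

    w_oihw = torch.randn(8, 4, 3, 3)                   # (O, I, H, W)
    w_ohwi = w_oihw.permute(0, 2, 3, 1).contiguous()   # (O, H, W, I)
    assert w_ohwi.shape == (8, 3, 3, 4)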

NormalizeDelegateIOLayoutPass: The new pass is clean and handles both inputs and outputs correctly. The recursive handling of tuple/list output structures via _rewrite_output_arg is thorough. Good use of _inverse_permutation for computing the transposes.
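A minimal sketch of what such an inverse-permutation helper computes (the
actual _inverse_permutation signature is assumed):

    def inverse_permutation(perm):
        # Output axis i reads input axis p, so the inverse maps p back to i.
        inverse = [0] * len(perm)
        for i, p in enumerate(perm):
            inverse[p] = i
        return inverse

    assert inverse_permutation([0, 2, 3, 1]) == [0, 3, 1, 2]  # NHWC <-> NCHW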

RemovePermutesAroundElementwiseOps: The fix for negative permute indices (normalizing via d + rank if d < 0) is correct and the validation checks (sorted(normalized_permute) == list(range(rank))) are appropriate. The output node boundary guard is a good addition.

Conv1d unsqueeze change: The switch from (N,C,L,1) to (N,C,1,L) is the right call for consistency with avg_pool1d/max_pool2d shapes.

INT48 const normalization: The _prepare_const_values_for_tosa_dtype helper in process_node.py correctly ensures int32 values are cast to int64 before INT48 serialization. The regression test (test_process_node.py) validates the full round-trip.
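A sketch of the widening step (the helper name comes from the review; its
exact signature is assumed): INT48 has no native numpy dtype, so int32-backed
constants are widened to int64 before being handed to the serializer.

    import numpy as np

    def _prepare_const_values_for_tosa_dtype(values: np.ndarray) -> np.ndarray:
        # int64 is wide enough to carry any INT48 value losslessly.
        if values.dtype == np.int32:
            return values.astype(np.int64)
        return values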

tosa_shape simplification: Removing the dim-order reordering from tosa_shape() is consistent with the overall approach of this PR. Since shapes are now always in contiguous order (with explicit permute nodes handling layout), dim-order in tosa_shape was redundant.


-    dim_order = meta["tosa_dim_order"]
-else:
-    dim_order = tuple(range(len(shape)))
+dim_order = tuple(range(len(shape)))
Contributor commented:

Assert it's indeed NCHW?

Contributor commented:

IMO the right way to do this is:

    val = node.meta.get("val", None)
    if val is not None and isinstance(val, torch.Tensor):
        dim_order = val.dim_order()

AdrianLundell (Collaborator, Author) replied Apr 23, 2026:

The expectation after this PR is that the graph will always be fully contiguous after the NormalizeDelegateIOLayout pass, except perhaps at the boundary nodes. I intend to clean this up and remove all use of the dim-order in the TOSA serialization in a follow-up patch anyway, so I would say this quick fix is fine for now.

) -> tuple[Any, bool]:
    if isinstance(arg, torch.fx.Node):
        output_fake = get_first_fake_tensor(arg)
        dim_order = output_fake.dim_order()
Contributor commented:

can't this change during the lowering?

AdrianLundell (Collaborator, Author) replied:

It can, so it is important that this is the first pass run during the lowering!

        groups=groups[i],
        bias=bias[i],
        padding_mode=padding_mode[i],
    ).to(dtype),
Contributor commented:

What if I do this?

Suggested change:

    ).to(dtype).to(memory_format=torch.channels_last),

and similarly for the tensors returned by get_inputs?

AdrianLundell (Collaborator, Author) replied:

That would be handled by NormalizeDelegateIOLayoutPass inserting two extra transposes (which would cancel out and leave no transposes in the final graph). This is tested in test_transpose_counts, for example.

Comment on lines 1 to +3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2026 Arm Limited and/or its affiliates.
Contributor commented:
I guess both of those are wrong? cc @digantdesai and @mergennachin

Comment on lines +281 to +288
rank = len(raw_permute)
normalized_permute = [d + rank if d < 0 else d for d in raw_permute]

if not all(0 <= d < rank for d in normalized_permute):
    return None
if sorted(normalized_permute) != list(range(rank)):
    return None
return normalized_permute
Contributor commented:
is this to handle -1 dims? What would be the issue?

AdrianLundell (Collaborator, Author) replied:

Yes, it's for negative indices, which may appear in user-inserted permutes.
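
For example, a user-written permute with a negative index normalizes as
follows:

    raw_permute = [0, -1, 1, 2]
    rank = len(raw_permute)
    normalized = [d + rank if d < 0 else d for d in raw_permute]
    assert normalized == [0, 3, 1, 2]  # a valid permutation of range(4)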

# Node transformation passes (pre q/dq folding)
self.add_passes(
    [
        NormalizeDelegateIOLayoutPass(exported_program),
Collaborator commented:

nit: Maybe add a comment noting that this needs to run first(ish), and why?

- Fix [] stride in avg/max_pool2d + add tests
- Fix metadata of rescale in rewrite_upsample
- Merge conv weight permutes into single helper function
- Nits: Remove dead code, stale comments, TEMP path

Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
Change-Id: I6aa9221467a575e1c42a40cc5ca7237a810f782d
Copilot AI review requested due to automatic review settings April 23, 2026 14:28
Copilot AI (Contributor) left a comment

Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

@mcremon-meta (Contributor) commented:

Looking good on performance on our side; I think this should get in. We can iterate on anything else later, but this is a huge step. Thanks @AdrianLundell!

@AdrianLundell (Collaborator, Author) commented:

Great! I found a minor bug internally, but it should be good to go in on Monday.

Change-Id: Iea85c7761246e28a869ce4d832522612206da31b
@AdrianLundell (Collaborator, Author) commented:

@mcremon-meta There looks to be some required Meta-internal test blocking me from merging this after the rebase; could you take a look and see if you can unblock it?

zingo (Collaborator) commented Apr 27, 2026

@digantdesai Maybe it just needs that re-sync you seem to need to do sometimes when we touch a PR that was tested "inside"?

@mcremon-meta (Contributor) commented:

I just imported it. @digantdesai, let me know if there is anything else that needs to be done.


meta-codesync Bot commented Apr 27, 2026

@mcremon-meta has imported this pull request. If you are a Meta employee, you can view this in D101990455.

@AdrianLundell AdrianLundell merged commit 1bb039f into pytorch:main Apr 27, 2026
433 of 441 checks passed

Labels

ciflow/trunk
CLA Signed (This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.)
module: arm (Issues related to arm backend)
partner: arm (For backend delegation, kernels, demo, etc. from the 3rd-party partner, Arm)

5 participants