Arm backend: Remove use of tosa_dim_order #18948
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18948.
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
❌ 2 New Failures, 4 Unrelated Failures as of commit bac92e5 with merge base 32a6cec:
- NEW FAILURES: the following jobs have failed.
- FLAKY: the following job failed but was likely due to flakiness present on trunk.
- BROKEN TRUNK: the following jobs failed but were present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This patch removes the use of the ToTosaMemoryFormatPass and tosa_dim_order in favor of serializing directly to the contiguous stride shape.

To allow this, all channels-last TOSA dialect ops are modified to be explicitly channels last and are lowered with permutes, e.g.

(1,2,3,3) -> aten conv -> (1,4,3,3)

lowers to

(1,2,3,3) -> (1,3,3,2) -> tosa conv -> (1,3,3,4) -> (1,4,3,3)

To handle channels-last input/output, NormalizeDelegateIOLayoutPass permutes the shape of such inputs and outputs and inserts a permute to force them to be contiguous. This permute will then typically cancel out the top convolution permute, enabling zero-transpose graphs in many cases.

Additionally,
- Conv1D input is unsqueezed (N,C,L) -> (N,C,1,L) instead of (N,C,L,1) to better match avg_pool1d and max_pool2d shapes.
- Biases of int16 convolutions are handled by lowering the bias to an int48 tensor rather than using a separate add operator.
- Fix remove_permutes_around_elementwise_ops for permutes with negative indices.

Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
Change-Id: I7584e519cd876d63f1cafe98710f1c7fb0378581
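For intuition, here is a minimal plain-torch sketch of the permute sandwich described above; the shapes match the example, and the stand-in tensor replaces the actual channels-last TOSA conv:

```python
import torch

# Illustrative only: the NCHW <-> NHWC permute sandwich the lowering inserts
# around a channels-last TOSA CONV2D.
x = torch.randn(1, 2, 3, 3)                 # (N, C, H, W) contiguous input
x_nhwc = x.permute(0, 2, 3, 1)              # (1, 3, 3, 2): pre-permute to NHWC
# ... the channels-last TOSA conv runs here, yielding NHWC output ...
y_nhwc = torch.randn(1, 3, 3, 4)            # stand-in for the conv's output
y = y_nhwc.permute(0, 3, 1, 2)              # (1, 4, 3, 3): post-permute to NCHW
assert x_nhwc.shape == (1, 3, 3, 2) and y.shape == (1, 4, 3, 3)
```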
Force-pushed from ec8b798 to 9a64c73.
Pull request overview
This PR removes reliance on ToTosaMemoryFormatPass / tosa_dim_order for Arm TOSA lowering, moving to an explicit-layout strategy where channels-last lowering is expressed via permute_copy nodes and TOSA fake ops operate in NHWC/NDHWC shapes.
Changes:
- Stop propagating/consuming tosa_dim_order metadata and instead lower layout changes via explicit pre/post permutes.
- Update multiple TOSA dialect fake ops (conv/pool/resize) to compute shapes in NHWC/NDHWC and adjust conv weight formats accordingly.
- Add/adjust passes and tests to normalize delegate I/O layouts, rewrite pooling/upsample with explicit permutes, and validate int16/int48 bias handling.
Reviewed changes
Copilot reviewed 32 out of 32 changed files in this pull request and generated 6 comments.
Summary per file:
| File | Description |
|---|---|
| backends/transforms/remove_permutes_around_elementwise_ops.py | Makes permute parsing more robust (normalization/validation). |
| backends/arm/tosa/utils.py | tosa_shape now resolves SymInts without reordering (keeps API arg). |
| backends/arm/tosa/mapping.py | Stops using meta["tosa_dim_order"], defaults to identity dim order. |
| backends/arm/tosa/dialect/ops/transpose_conv2d.py | Updates transpose-conv output shape logic to NHWC + new weight format. |
| backends/arm/tosa/dialect/ops/resize.py | Updates resize fake shape inference to NHWC. |
| backends/arm/tosa/dialect/ops/max_pool2d.py | Updates max-pool fake shape inference to NHWC. |
| backends/arm/tosa/dialect/ops/depthwise_conv2d.py | Updates depthwise-conv fake shape inference + weight format assumptions. |
| backends/arm/tosa/dialect/ops/conv3d.py | Updates conv3d fake shape inference to NDHWC + new weight layout. |
| backends/arm/tosa/dialect/ops/conv2d.py | Updates conv2d fake shape inference to NHWC + new weight layout. |
| backends/arm/tosa/dialect/ops/avg_pool2d.py | Updates avg-pool fake shape inference to NHWC and fixes pad variable usage. |
| backends/arm/tosa/backend.py | Removes fake-tensor memory_format remapping based on tosa_dim_order. |
| backends/arm/test/passes/test_rewrite_conv_pass.py | Adds int16/a16w8 bias lowering tests and INT48 bias assertions. |
| backends/arm/test/passes/test_rewrite_avg_pool2d_pass.py | New tests validating avg_pool2d rewrite to TOSA + permutes. |
| backends/arm/test/ops/test_unfold_copy.py | Removes xfails for negative-dim unfold_copy cases. |
| backends/arm/test/ops/test_pow.py | Adjusts xfail sets used for FP pow tests. |
| backends/arm/test/ops/test_conv2d.py | Tweaks U85 a16w8 conv2d test pipeline configuration. |
| backends/arm/test/misc/test_transpose_counts.py | Updates expected transpose counts and related model naming. |
| backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py | Updates expected shapes for DW conv2d dialect tests. |
| backends/arm/test/misc/test_tosa_dialect_conv2d.py | Updates expected shapes for conv2d dialect tests. |
| backends/arm/test/misc/test_process_node.py | New regression test for INT48 const normalization during serialization. |
| backends/arm/test/misc/test_high_rank_permute_view_invariants.py | Updates expected transpose counts in randomized permute/view chains. |
| backends/arm/test/misc/test_const_shape.py | Removes tests tied to ToTosaMemoryFormatPass / tosa_dim_order propagation. |
| backends/arm/process_node.py | Normalizes INT48 const storage to np.int64 before serializer addConst. |
| backends/arm/operators/op_tosa_shapes.py | Uses updated tosa_shape behavior (no dim-order reordering). |
| backends/arm/_passes/rewrite_upsample.py | Lowers upsample via NHWC pre/post permute_copy, with optional RESCALE. |
| backends/arm/_passes/rewrite_max_pool2d_pass.py | Inserts NHWC pre/post permute_copy around TOSA MAX_POOL2D. |
| backends/arm/_passes/rewrite_conv_pass.py | Rewrites convs using explicit permutes + per-conv weight placeholder rewrites. |
| backends/arm/_passes/rewrite_avg_pool2d_pass.py | Inserts NHWC pre/post permute_copy around TOSA AVG_POOL2D. |
| backends/arm/_passes/normalize_delegate_io_layout_pass.py | New pass to adjust delegated I/O boundary shapes and insert permutes. |
| backends/arm/_passes/conv1d_unsqueeze_pass.py | Adjusts conv1d→conv2d unsqueeze/squeeze shapes for new layout strategy. |
| backends/arm/_passes/arm_pass_manager.py | Removes ToTosaMemoryFormatPass, adds IO layout normalization + permute optimizations. |
| backends/arm/_passes/__init__.py | Exposes NormalizeDelegateIOLayoutPass from the pass package. |
```diff
  x = args[0]
- pad_h, pad_w = args[3]
+ pad_h, pad_w = args[3] if len(args) > 3 else (0, 0)
  # Make sure pad corresponds to TOSA
  pad = [pad_h, pad_w, pad_h, pad_w]
```

```diff
  _, _, h, w = x.data.shape
  kernel_h, kernel_w = args[1]
- stride_h, stride_w = args[2]
+ stride_h, stride_w = args[2] if len(args) > 2 else (1, 1)
```

```diff
  kernel = _to_2tuple(args[1])
```

```diff
- if len(args) > 2 and args[2] is not None and len(args[2]) > 0:
+ if len(args) > 2 and args[2] is not None:
```
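Putting the pieces together, a minimal sketch of NHWC output-shape inference for a pooling fake op; the helper name and signature are assumptions for this sketch, not taken from the PR:

```python
# Illustrative only: NHWC output-shape inference for a 2D pooling fake op
# with symmetric padding.
def nhwc_pool_out_shape(shape, kernel, stride, pad):
    n, h, w, c = shape                  # NHWC input shape
    kernel_h, kernel_w = kernel
    stride_h, stride_w = stride
    pad_h, pad_w = pad                  # symmetric (height, width) padding
    out_h = (h + 2 * pad_h - kernel_h) // stride_h + 1
    out_w = (w + 2 * pad_w - kernel_w) // stride_w + 1
    return (n, out_h, out_w, c)

# Example: 1x8x8x3 input, 2x2 kernel, stride 2, no padding -> 1x4x4x3.
assert nhwc_pool_out_shape((1, 8, 8, 3), (2, 2), (2, 2), (0, 0)) == (1, 4, 4, 3)
```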
```python
with graph_module.graph.inserting_after(tosa_resize_node):
    rescale_node = create_node(
        graph_module.graph,
        exir_ops.backend.tosa.RESCALE.default,
    )
    tosa_resize_node.replace_all_uses_with(rescale_node)
    if input_dtype == torch.int16:
        tosa_resize_node.meta[TosaSpecialDtype.meta_key()] = (
            TosaSpecialDtype.INT48
        )

    rescale_node.args = (
        tosa_resize_node,
        output_dtype,
        [output_scale],
        0,  # zero point
        0,  # zero point
    )
    node_replacement = rescale_node
    node_replacement_fake = exir_ops.backend.tosa.RESCALE.default(
        tosa_resize_node.meta["val"], output_dtype, [output_scale], 0, 0
    )
```
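Some background on the int16/INT48 handling above (general TOSA behavior as I understand it, not code from this PR): a bilinear RESIZE on int16 data produces an int48 accumulator, and the inserted RESCALE brings it back to the output dtype. Numerically, with made-up values:

```python
# Rough numeric intuition with hypothetical values: the int48 accumulator
# produced by an int16 bilinear RESIZE is scaled back into int16 range by
# the RESCALE node inserted after it.
acc_int48 = 12_345_678          # hypothetical int48 accumulator value
output_scale = 1.0 / 1024.0     # hypothetical; the pass computes the real one
out_int16 = round(acc_int48 * output_scale)
assert -(2**15) <= out_int16 < 2**15
```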
@claude review this PR.

Claude finished @digantdesai's task in 2m 19s.

Review Summary: This PR replaces the implicit … Overall this is a well-structured change. Below are specific issues I found, roughly sorted by severity.

Issues:
1. Dead code in …
```diff
-         dim_order = meta["tosa_dim_order"]
-     else:
-         dim_order = tuple(range(len(shape)))
+     dim_order = tuple(range(len(shape)))
```
Assert it's indeed NCHW?

IMO the right way to do this is:

```python
val = node.meta.get("val", None)
if val is not None and isinstance(val, torch.Tensor):
    dim_order = val.dim_order()
```

The expectation after this PR is that the graph will always be fully contiguous after the NormalizeDelegateIOLayout pass, except perhaps at the boundary nodes. I intend to clean this up and remove all use of the dim-order in the TOSA serialization in a follow-up patch anyway, so I would say this quick-fix is fine for now.
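For reference, assuming a recent PyTorch (Tensor.dim_order() has been available since roughly 2.1), the suggested call reports the physical layout:

```python
import torch

# dim_order() reflects the physical layout: contiguous NCHW reports the
# identity order, while a channels_last tensor reports (0, 2, 3, 1).
x = torch.empty(1, 2, 3, 3)
assert x.dim_order() == (0, 1, 2, 3)
y = x.to(memory_format=torch.channels_last)
assert y.dim_order() == (0, 2, 3, 1)
```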
```python
) -> tuple[Any, bool]:
    if isinstance(arg, torch.fx.Node):
        output_fake = get_first_fake_tensor(arg)
        dim_order = output_fake.dim_order()
```
can't this change during the lowering?
It can, so it is important that this is the first pass run during the lowering!
```python
    groups=groups[i],
    bias=bias[i],
    padding_mode=padding_mode[i],
).to(dtype),
```
What if I do this?

```python
).to(dtype).to(memory_format=torch.channels_last),
```

and similarly for the tensors returned by get_inputs?
That would be handled by NormalizeDelegateIOLayout, which would insert two extra transposes (these would cancel out and leave no transposes in the final graph). This is tested in test_transpose_counts, for example.
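As a quick plain-torch illustration of why those two transposes cancel:

```python
import torch

# NCHW -> NHWC followed by NHWC -> NCHW is the identity permutation, so
# back-to-back boundary/conv permutes can be optimized away entirely.
x = torch.randn(1, 2, 3, 3)
roundtrip = x.permute(0, 2, 3, 1).permute(0, 3, 1, 2)
assert torch.equal(roundtrip, x)
```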
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2026 Arm Limited and/or its affiliates.
```
I guess both of those are wrong? cc @digantdesai and @mergennachin
```python
rank = len(raw_permute)
normalized_permute = [d + rank if d < 0 else d for d in raw_permute]

if not all(0 <= d < rank for d in normalized_permute):
    return None
if sorted(normalized_permute) != list(range(rank)):
    return None
return normalized_permute
```
is this to handle -1 dims? What would be the issue?
Yes it's for negative indices, which may appear in user-inserted permutes.
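A small plain-torch illustration of the corner case: a permute written with negative dims is the same permutation as its normalized form, so the pass normalizes before matching and cancelling permutes.

```python
import torch

# A user-written permute with negative dims is equivalent to its normalized
# form, so [0, 2, 3, -3] must be treated the same as [0, 2, 3, 1].
x = torch.randn(1, 2, 3, 4)
assert torch.equal(x.permute(0, 2, 3, -3), x.permute(0, 2, 3, 1))

# Normalization as in the snippet above:
raw_permute = [0, 2, 3, -3]
rank = len(raw_permute)
assert [d + rank if d < 0 else d for d in raw_permute] == [0, 2, 3, 1]
```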
```python
# Node transformation passes (pre q/dq folding)
self.add_passes(
    [
        NormalizeDelegateIOLayoutPass(exported_program),
```
nit: Maybe add a comment noting that this needs to run first(ish), and why?
- Fix [] stride in avg/max_pool2d + add tests
- Fix meta-data of rescale in rewrite_upsample
- Merge conv weight permutes into single helper function
- Nits: Remove dead code, stale comments, TEMP path

Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
Change-Id: I6aa9221467a575e1c42a40cc5ca7237a810f782d
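On the merged conv weight permutes: for illustration only (the helper name is assumed, not from the PR), aten conv2d weights are laid out OIHW while TOSA CONV2D expects OHWI, so such a helper would permute the weight placeholder like this:

```python
import torch

# Hypothetical sketch of a merged weight-permute helper: rewrite an aten
# (O, I, H, W) conv2d weight into the (O, H, W, I) layout TOSA expects.
def to_tosa_conv2d_weight(w: torch.Tensor) -> torch.Tensor:
    assert w.dim() == 4  # (O, I, H, W)
    return w.permute(0, 2, 3, 1).contiguous()  # -> (O, H, W, I)
```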
Looking good on performance on our side, I think this should get in. We can iterate on anything else later, but this is a huge step. Thanks @AdrianLundell!
Great! I found a minor bug internally but should be good to go in on Monday |
Change-Id: Iea85c7761246e28a869ce4d832522612206da31b
@mcremon-meta There looks to be a required Meta-internal test blocking me from merging this after the rebase; could you take a look and see if you can unblock it?
@digantdesai maybe it just needs that re-sync you sometimes seem to need to do if we touch a PR that was tested "inside"?
I just imported it, @digantdesai let me know if there is anything else that needs to be done.
@mcremon-meta has imported this pull request. If you are a Meta employee, you can view this in D101990455.