
Commit 40bde3c

mcremon-metameta-codesync[bot] authored and committed
Replace tosa_dim_order with explicit NCHW↔NHWC permutes (#19015)
Summary:
Pull Request resolved: #19015

Replace implicit `tosa_dim_order`-based layout handling with explicit `permute_copy` ops around TOSA operators that require NHWC layout.

### Rewrite passes insert explicit NCHW↔NHWC permutes

`RewriteConvPass`, `RewriteAvgPool2dPass`, and `RewriteMaxPool2dPass` now insert `aten.permute_copy` nodes (NCHW→NHWC before the TOSA op, NHWC→NCHW after) instead of relying on `ToTosaMemoryFormatPass` for layout conversion. This makes layout transitions visible in the graph.

### Grouped conv decomposition in NHWC

`RewriteConvPass` decomposes grouped convolutions (non-depthwise) into per-group `TOSA.CONV2D` ops operating entirely in NHWC, with a single input/output permute pair wrapping the whole group. It supports the INT8 and INT16 (with and without bias) quantisation paths, including the full INT16+bias chain: CONV2D → RESCALE(INT48→INT32) → ADD(bias) → RESCALE(INT32→INT16).

### `ToTosaMemoryFormatPass` scoped down

The pass now assigns non-identity dim_order only to parameter/buffer placeholders (for weight serialisation) and to graph I/O. It inserts `permute_copy` instead of `tosa.TRANSPOSE`, and skips users that already carry a matching permute (inserted by the rewrite passes).

### TOSA dialect op metas expect NHWC

All TOSA op meta functions (`CONV2D`, `CONV3D`, `DEPTHWISE_CONV2D`, `AVG_POOL2D`, `MAX_POOL2D`, `TRANSPOSE_CONV2D`) now assume NHWC input layout and produce NHWC output shapes.

### Removed `tosa_dim_order` shape remapping

`tosa_shape()` no longer reorders dimensions; it only resolves symints. `_get_matching_fake_tensor()` returns `node.meta["val"]` directly. The serialisation mapping always uses the identity dim_order.

### Operator serialisation simplified

`op_amax`, `op_amin`, `op_any`, `op_cat`, `op_sum`, and `op_permute` no longer remap reduction/concat axes through `dim_order`, since tensors are already in the layout expected by TOSA.

### Permute optimisation passes added

Six shared passes from `executorch/backends/transforms/` are now run after TOSA lowering to fuse, cancel, and simplify the permutes introduced above:

- `RemovePermutesAroundElementwiseOps` (extended for `RESCALE`)
- `FuseTransposeOrPermuteOpPairsPass` (extended for `RESCALE`)
- `ReplaceNopTransposeOrPermuteWithViewPass`
- `PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView`
- `FuseCascadedTransposeOrPermuteOps`
- `FuseCascadedViewOps`

### Removed passes

`DecomposeConvWithInt16ActivationPass` and `DecomposeGroupedConvPass` are removed; their logic is now handled inline by `RewriteConvPass`. `RewriteSlicePass` is repositioned after the permute optimisations.

### Ethos-U55 partitioner simplified

The dual NCHW/NHWC permute constraint check is removed, since tensors are always in the expected layout at partition time.

Differential Revision: D100712787
1 parent cb8489e commit 40bde3c

33 files changed

Lines changed: 960 additions & 1214 deletions
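To make the pattern described in the summary concrete, here is a minimal PyTorch sketch of the layout contract the rewrite passes now encode explicitly: the NHWC-only TOSA operator is wrapped in a `permute_copy` pair while the surrounding graph stays NCHW. This is an illustrative stand-in, not the backend's pass code; `tosa_conv2d_nhwc` is a hypothetical placeholder for TOSA `CONV2D`.

```python
import torch
import torch.nn.functional as F

NCHW_TO_NHWC = [0, 2, 3, 1]
NHWC_TO_NCHW = [0, 3, 1, 2]


def tosa_conv2d_nhwc(x_nhwc, weight_oihw):
    # Hypothetical stand-in for TOSA CONV2D: consumes and produces NHWC.
    y_nchw = F.conv2d(x_nhwc.permute(NHWC_TO_NCHW), weight_oihw, padding=1)
    return y_nchw.permute(NCHW_TO_NHWC)


x = torch.randn(1, 8, 16, 16)   # NCHW activation, as the rest of the graph sees it
w = torch.randn(4, 8, 3, 3)     # OIHW weights

# Explicit layout transitions around the NHWC op, mirroring the inserted
# aten.permute_copy nodes: NCHW -> NHWC before, NHWC -> NCHW after.
y = tosa_conv2d_nhwc(x.permute(NCHW_TO_NHWC), w).permute(NHWC_TO_NCHW)

assert torch.allclose(y, F.conv2d(x, w, padding=1))
```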

backends/arm/_passes/TARGETS

Lines changed: 6 additions & 0 deletions
@@ -25,6 +25,12 @@ runtime.python_library(
         "//executorch/backends/arm:common",
         "//executorch/backends/arm/tosa:utils",
         "//executorch/backends/arm/tosa/dialect:lib",
+        "//executorch/backends/transforms:fuse_cascaded_transpose_or_permute_ops",
+        "//executorch/backends/transforms:fuse_cascaded_view_ops",
+        "//executorch/backends/transforms:fuse_transpose_or_permute_op_pairs_pass",
+        "//executorch/backends/transforms:remove_permutes_around_elementwise_ops",
+        "//executorch/backends/transforms:postpone_permute_below_squeeze_view",
+        "//executorch/backends/transforms:replace_nop_transpose_or_permute_with_view",
         "//executorch/exir:lib",
     ],
 )

backends/arm/_passes/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -7,7 +7,6 @@
 from . import arm_pass_utils # noqa
 from .arm_pass import ArmPass # noqa # usort: skip
 from .accumulate_index_put_pass import AccumulateIndexPutPass # noqa
-from .annotate_output_dim_order_pass import AnnotateOutputDimOrderPass # noqa
 from .broadcast_args_pass import BroadcastArgsPass # noqa
 from .canonicalize_gather_pass import CanonicalizeGatherPass # noqa
 from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa
@@ -165,7 +164,6 @@
 from .rewrite_upsample import RewriteUpsamplePass # noqa
 from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa
 from .size_adjust_input_pass import SizeAdjustInputPass # noqa
-from .to_tosa_memory_format_pass import ToTosaMemoryFormatPass # noqa
 from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass # noqa
 from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass # noqa
 from .replace_inf_and_limit_values_pass import ( # noqa # usort: skip

backends/arm/_passes/annotate_output_dim_order_pass.py

Lines changed: 0 additions & 28 deletions
This file was deleted.

backends/arm/_passes/arm_pass_manager.py

Lines changed: 46 additions & 11 deletions
@@ -12,7 +12,6 @@
 
 from executorch.backends.arm._passes import (
     AccumulateIndexPutPass,
-    AnnotateOutputDimOrderPass,
     BroadcastArgsPass,
     CanonicalizeGatherPass,
     CastInt64BuffersToInt32Pass,
@@ -44,7 +43,6 @@
     DecomposeAtanPass,
     DecomposeAvgPool2dPass,
     DecomposeBatchNormNoStatsPass,
-    DecomposeConvWithInt16ActivationPass,
     DecomposeCoshPass,
     DecomposeCosineSimilarityPass,
     DecomposeCumsumPass,
@@ -58,7 +56,6 @@
     DecomposeFloorDividePass,
     DecomposeGeluPass,
     DecomposeGluPass,
-    DecomposeGroupedConvPass,
     DecomposeGroupNormPass,
     DecomposeGruPass,
     DecomposeIndexCopyPass,
@@ -141,7 +138,6 @@
     RewriteUpsamplePass,
     ScalarsToAttributePass,
     SizeAdjustInputPass,
-    ToTosaMemoryFormatPass,
     UnsqueezeBeforeRepeatPass,
     UnsqueezeScalarPlaceholdersPass,
 )
@@ -157,7 +153,26 @@
     TosaLoweringContext,
     TosaSpecification,
 )
+from executorch.backends.transforms.fuse_cascaded_transpose_or_permute_ops import (
+    FuseCascadedTransposeOrPermuteOps,
+)
+from executorch.backends.transforms.fuse_cascaded_view_ops import (
+    FuseCascadedViewOps,
+)
+from executorch.backends.transforms.fuse_transpose_or_permute_op_pairs_pass import (
+    FuseTransposeOrPermuteOpPairsPass,
+)
+from executorch.backends.transforms.postpone_permute_below_squeeze_view import (
+    PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView,
+)
+from executorch.backends.transforms.remove_permutes_around_elementwise_ops import (
+    RemovePermutesAroundElementwiseOps,
+)
+from executorch.backends.transforms.replace_nop_transpose_or_permute_with_view import (
+    ReplaceNopTransposeOrPermuteWithViewPass,
+)
 from executorch.exir import ExportedProgram
+from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass
 from executorch.exir.pass_manager import PassManager
 from torch._export.utils import _get_shape_env_from_gm
@@ -385,9 +400,6 @@ def _tosa_pipeline(
         # Allow subclasses to configure pass insertions before building pipeline
         self._configure_pass_insertions(exported_program)
 
-        # Preprocessing passes
-        self.add_pass(AnnotateOutputDimOrderPass())
-
         # Node transformation passes (pre q/dq folding)
         self.add_passes(
             [
@@ -455,7 +467,6 @@
                 DecomposeFloorDividePass(),
                 DecomposeGeluPass(),
                 DecomposeAddSubAlphaPass(),
-                DecomposeGroupedConvPass(),
                 DecomposeUnfoldToGatherPass(),
                 DecomposeEmbeddingPass(),
                 DecomposeIndexSelectToGatherPass(),
@@ -518,7 +529,6 @@
                 ConvertPermuteSingletonToViewPass(),
                 RewriteHighRankSingletonPermutePass(),
                 FuseViewCopyTransformPass(),
-                DecomposeConvWithInt16ActivationPass(),
                 DecomposeSumPass(),
                 InsertTableOpsPass(exported_program),
             ]
@@ -532,7 +542,6 @@
                 RewriteConvPass(exported_program),
                 RewriteMatmulPass(),
                 RewritePadPass(),
-                RewriteSlicePass(),
                 InsertConstShapesPass(),
             ]
         )
@@ -542,14 +551,40 @@
             [
                 CastInt64BuffersToInt32Pass(exported_program),
                 FuseEqualPlaceholdersPass(exported_program),
+                FuseConstantArgsPass(exported_program),
                 FuseConsecutiveConcatShapesPass(),
-                ToTosaMemoryFormatPass(exported_program),
                 RemoveNoopPass(),
                 InsertRescalePass(),
                 InsertDataLayoutCastsPass(),
             ]
         )
 
+        # Additional optimization passes for permutes
+        # Fuse identity permute pairs across RESCALE ops
+        fuse_pairs = FuseTransposeOrPermuteOpPairsPass()
+        fuse_pairs.bypass_ops = fuse_pairs.bypass_ops | {
+            exir_ops.backend.tosa.RESCALE.default,
+        }
+
+        # Remove permutes around elementwise ops including RESCALE
+        remove_around = RemovePermutesAroundElementwiseOps()
+        remove_around.permutable_ops = remove_around.permutable_ops | {
+            exir_ops.backend.tosa.RESCALE.default,
+        }
+
+        self.add_passes(
+            [
+                remove_around,
+                RewriteSlicePass(),
+                fuse_pairs,
+                ReplaceNopTransposeOrPermuteWithViewPass(),
+                PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(),
+                FuseCascadedTransposeOrPermuteOps(),
+                FuseCascadedViewOps(),
+                InsertConstShapesPass(),
+            ]
+        )
+
         # Apply all pass insertions once after all passes are collected
         self._apply_pass_insertions()
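The permute optimisation passes scheduled above rely on the fact that the NCHW→NHWC and NHWC→NCHW permutations inserted by the rewrite passes are inverses: a back-to-back pair composes to the identity permutation, which `ReplaceNopTransposeOrPermuteWithViewPass` and the fuse passes can then eliminate. Below is a plain-Python sketch of that composition rule, not code from the passes themselves.

```python
NCHW_TO_NHWC = [0, 2, 3, 1]
NHWC_TO_NCHW = [0, 3, 1, 2]


def compose(outer, inner):
    # permute(permute(x, inner), outer) == permute(x, compose(outer, inner))
    return [inner[i] for i in outer]


# Either ordering of the pair collapses to the identity permutation.
assert compose(NHWC_TO_NCHW, NCHW_TO_NHWC) == [0, 1, 2, 3]
assert compose(NCHW_TO_NHWC, NHWC_TO_NCHW) == [0, 1, 2, 3]
```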

backends/arm/_passes/arm_pass_utils.py

Lines changed: 0 additions & 5 deletions
@@ -352,11 +352,6 @@ def set_node_arg(node: torch.fx.Node, i: int | str, value):
         raise RuntimeError("Invalid type")
 
 
-def get_output_dim_orders(graph_module):
-    output_node = graph_module.graph.output_node()
-    return [get_first_fake_tensor(node).dim_order() for node in output_node.args[0]]
-
-
 def is_nested_control_flow_graph(graph_module: GraphModule) -> bool:
     """Returns True if graph_module is a nested control-flow graph."""

backends/arm/_passes/rewrite_avg_pool2d_pass.py

Lines changed: 120 additions & 50 deletions
@@ -7,69 +7,139 @@
 
 import torch
 from executorch.backends.arm._passes import ArmPass
+from executorch.backends.arm._passes.arm_pass_utils import (
+    create_node,
+    get_first_fake_tensor,
+)
 from executorch.backends.arm.operators.operator_validation_utils import (
     adjust_pooling_pad_if_needed,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass
+from executorch.exir.pass_base import ExportPass, PassResult
 
 from .fuse_constant_ops_pass import ComputeConstantOpsAOTPass
 
+_NCHW_TO_NHWC = [0, 2, 3, 1]
+_NHWC_TO_NCHW = [0, 3, 1, 2]
+
 
 class RewriteAvgPool2dPass(ArmPass):
-    """Rewrite aten.avg_pool2d calls to TOSA AVG_POOL2D op."""
+    """Rewrite aten.avg_pool2d calls to TOSA AVG_POOL2D op with NHWC layout."""
 
-    # Target the original avg_pool2d operator
     targeted_ops = {exir_ops.edge.aten.avg_pool2d.default}
     _passes_required_after: Set[Type[ExportPass]] = {
        ComputeConstantOpsAOTPass,
     }
 
-    def call_operator(self, op, args, kwargs, meta, updated=False):
-
-        # Only rewrite avg_pool2d
-        if op not in self.targeted_ops:
-            return super().call_operator(op, args, kwargs, meta, updated)
-
-        x = args[0]
-        pad_h, pad_w = args[3]
-        # Make sure pad corresponds to TOSA
-        pad = [pad_h, pad_w, pad_h, pad_w]
-
-        _, _, h, w = x.data.shape
-        kernel_h, kernel_w = args[1]
-        stride_h, stride_w = args[2]
-
-        ceil_mode = args[4] if len(args) > 4 else False
-
-        # Adjust padding if necessary
-        pad[1] = adjust_pooling_pad_if_needed(h, kernel_h, stride_h, pad[1], ceil_mode)
-        pad[3] = adjust_pooling_pad_if_needed(w, kernel_w, stride_w, pad[3], ceil_mode)
-
-        # Materialize zero-point constants
-        in_qparams = meta.data.get("input_qparams", {})
-        in_zp_val = in_qparams[0].get_zp_per_tensor() if 0 in in_qparams else 0
-        # Materialize input zero-point as a scalar tensor
-        input_zp = super().call_scalar(in_zp_val, meta)
-
-        out_qparams = meta.data.get("output_qparams", {})
-        out_zp_val = out_qparams[0].get_zp_per_tensor() if 0 in out_qparams else 0
-        # Materialize output zero-point as a scalar tensor
-        output_zp = super().call_scalar(out_zp_val, meta)
-
-        # Determine accumulator dtype for AVG_POOL2D: INT32 for integer inputs, FP32 otherwise
-        if x.data.dtype in (torch.int8, torch.int16):
-            acc_type = torch.int32
-        else:
-            acc_type = torch.float32
-
-        tosa_args = (args[0], input_zp, output_zp, *args[1:3], pad, acc_type)
-
-        # Emit TOSA AVG_POOL2D with normalized args
-        return super().call_operator(
-            exir_ops.backend.tosa.AVG_POOL2D.default,
-            tosa_args,
-            {},
-            meta,
-            True,
+    @staticmethod
+    def _insert_permute(graph_module, anchor_node, input_node, perm, before=True):
+        ctx = (
+            graph_module.graph.inserting_before(anchor_node)
+            if before
+            else graph_module.graph.inserting_after(anchor_node)
         )
+        with ctx:
+            return create_node(
+                graph=graph_module.graph,
+                op_target=exir_ops.edge.aten.permute_copy.default,
+                args=(input_node, perm),
+                from_node=input_node,
+            )
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        modified = False
+
+        for node in list(graph_module.graph.nodes):
+            if node.op != "call_function" or node.target not in self.targeted_ops:
+                continue
+
+            modified = True
+            x = node.args[0]
+
+            pad_h, pad_w = node.args[3]
+            pad = [pad_h, pad_w, pad_h, pad_w]
+
+            input_fake = get_first_fake_tensor(x)
+            _, _, h, w = input_fake.shape
+            kernel_h, kernel_w = node.args[1]
+            stride_h, stride_w = node.args[2]
+
+            ceil_mode = node.args[4] if len(node.args) > 4 else False
+
+            pad[1] = adjust_pooling_pad_if_needed(h, kernel_h, stride_h, pad[1], ceil_mode)
+            pad[3] = adjust_pooling_pad_if_needed(w, kernel_w, stride_w, pad[3], ceil_mode)
+
+            # Determine zero-points and accumulator type
+            in_qparams = node.meta.get("input_qparams", {})
+            in_zp_val = in_qparams[0].get_zp_per_tensor() if 0 in in_qparams else 0
+
+            out_qparams = node.meta.get("output_qparams", {})
+            out_zp_val = out_qparams[0].get_zp_per_tensor() if 0 in out_qparams else 0
+
+            if input_fake.dtype in (torch.int8, torch.int16):
+                acc_type = torch.int32
+            else:
+                acc_type = torch.float32
+
+            # Insert NCHW → NHWC permute on input
+            x_permuted = self._insert_permute(
+                graph_module, node, x, _NCHW_TO_NHWC, before=True
+            )
+
+            # Materialize zp scalars as graph constants using aten.full with
+            # explicit dtype matching the input tensor. This ensures the
+            # pre-computed buffer placeholders carry the correct type for
+            # INT-only TOSA profiles (avoids defaulting to float32).
+            zp_kwargs = {"dtype": input_fake.dtype, "device": input_fake.device}
+            with graph_module.graph.inserting_before(node):
+                input_zp_node = create_node(
+                    graph=graph_module.graph,
+                    op_target=exir_ops.edge.aten.full.default,
+                    args=((1,), in_zp_val),
+                    kwargs=zp_kwargs,
+                    from_node=node,
+                )
+                output_zp_node = create_node(
+                    graph=graph_module.graph,
+                    op_target=exir_ops.edge.aten.full.default,
+                    args=((1,), out_zp_val),
+                    kwargs=zp_kwargs,
+                    from_node=node,
+                )
+
+            kernel = list(node.args[1])
+            stride = list(node.args[2])
+
+            tosa_args = (x_permuted, input_zp_node, output_zp_node, kernel, stride, pad, acc_type)
+
+            # Create TOSA AVG_POOL2D node
+            with graph_module.graph.inserting_after(node):
+                tosa_op = create_node(
+                    graph=graph_module.graph,
+                    op_target=exir_ops.backend.tosa.AVG_POOL2D.default,
+                    args=tosa_args,
+                    from_node=node,
+                    inherit_qparams=True,
+                )
+
+            # Compute correct NHWC FakeTensor
+            input_fake_nhwc = input_fake.permute(_NCHW_TO_NHWC)
+            input_zp_fake = torch.tensor(in_zp_val, dtype=input_fake.dtype)
+            output_zp_fake = torch.tensor(out_zp_val, dtype=input_fake.dtype)
+            tosa_node_fake = exir_ops.backend.tosa.AVG_POOL2D.default(
+                input_fake_nhwc, input_zp_fake, output_zp_fake, kernel, stride, pad, acc_type
+            )
+            tosa_op.meta["val"] = tosa_node_fake
+
+            # Insert NHWC → NCHW permute on output
+            output_permute = self._insert_permute(
+                graph_module, tosa_op, tosa_op, _NHWC_TO_NCHW, before=False
+            )
+
+            node.replace_all_uses_with(output_permute)
+            graph_module.graph.erase_node(node)
+
+        if modified:
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+        return PassResult(graph_module, modified)
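For reference, the NHWC fake-tensor shape attached to the new `AVG_POOL2D` node follows the usual pooling output-size arithmetic applied to the permuted input. The helper below is a hedged sketch of that calculation, assuming the standard formula and a `[top, bottom, left, right]` pad layout; it is not the dialect's actual meta function.

```python
def avg_pool2d_nhwc_out_shape(nhwc_shape, kernel, stride, pad):
    # nhwc_shape: [N, H, W, C]; pad assumed to be [top, bottom, left, right].
    n, h, w, c = nhwc_shape
    kernel_h, kernel_w = kernel
    stride_h, stride_w = stride
    pad_top, pad_bottom, pad_left, pad_right = pad
    out_h = (h + pad_top + pad_bottom - kernel_h) // stride_h + 1
    out_w = (w + pad_left + pad_right - kernel_w) // stride_w + 1
    return [n, out_h, out_w, c]


# Example: a 1x16x16x8 NHWC tensor pooled with a 2x2 kernel and stride 2.
assert avg_pool2d_nhwc_out_shape([1, 16, 16, 8], (2, 2), (2, 2), [0, 0, 0, 0]) == [1, 8, 8, 8]
```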
