diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 3700c11e3da..3f2cde5adef 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -131,6 +131,7 @@ from .match_arg_dtype_pass import MatchArgDtypePass # noqa from .match_arg_ranks_pass import MatchArgRanksPass # noqa from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa +from .normalize_delegate_io_layout_pass import NormalizeDelegateIOLayoutPass # noqa from .normalize_index_put_bool_index_tensor_pass import ( # noqa NormalizeIndexPutBoolIndexTensorPass, ) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 205e6dce7cc..e39d8d605f4 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -12,7 +12,6 @@ from executorch.backends.arm._passes import ( AccumulateIndexPutPass, - AnnotateOutputDimOrderPass, BroadcastArgsPass, CanonicalizeGatherPass, CastInt64BuffersToInt32Pass, @@ -44,7 +43,6 @@ DecomposeAtanPass, DecomposeAvgPool2dPass, DecomposeBatchNormNoStatsPass, - DecomposeConvWithInt16ActivationPass, DecomposeCoshPass, DecomposeCosineSimilarityPass, DecomposeCumsumPass, @@ -117,6 +115,7 @@ InsertTableOpsPass, MatchArgDtypePass, MatchArgRanksPass, + NormalizeDelegateIOLayoutPass, NormalizeIndexPutBoolIndexTensorPass, NormalizeIndexPutNoneIndicesPass, NormalizeWhileInitialArgsPass, @@ -142,7 +141,6 @@ RewriteUpsamplePass, ScalarsToAttributePass, SizeAdjustInputPass, - ToTosaMemoryFormatPass, UnsqueezeBeforeRepeatPass, UnsqueezeScalarPlaceholdersPass, ) @@ -158,6 +156,16 @@ TosaLoweringContext, TosaSpecification, ) +from executorch.backends.transforms.fuse_cascaded_transpose_or_permute_ops import ( + FuseCascadedTransposeOrPermuteOps, +) +from executorch.backends.transforms.postpone_permute_below_squeeze_view import ( + PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView, +) + +from executorch.backends.transforms.remove_permutes_around_elementwise_ops import ( + RemovePermutesAroundElementwiseOps, +) from executorch.exir import ExportedProgram from executorch.exir.pass_base import ExportPass from executorch.exir.pass_manager import PassManager @@ -386,12 +394,10 @@ def _tosa_pipeline( # Allow subclasses to configure pass insertions before building pipeline self._configure_pass_insertions(exported_program) - # Preprocessing passes - self.add_pass(AnnotateOutputDimOrderPass()) - # Node transformation passes (pre q/dq folding) self.add_passes( [ + NormalizeDelegateIOLayoutPass(exported_program), FuseQuantizedActivationPass(), RewriteBoolToFp32CastViaInt8Pass(), CanonicalizeGatherPass(), @@ -516,12 +522,9 @@ def _tosa_pipeline( ConvertSqueezesToViewPass(), CastToInt32Pass(), BroadcastArgsPass(), - ConvertPermuteSingletonToViewPass(), - RewriteHighRankSingletonPermutePass(), - FuseViewCopyTransformPass(), - DecomposeConvWithInt16ActivationPass(), DecomposeSumPass(), InsertTableOpsPass(exported_program), + RemoveNoopPass(), ] ) @@ -534,6 +537,12 @@ def _tosa_pipeline( RewriteMatmulPass(), RewritePadPass(), RewriteSlicePass(), + FuseViewCopyTransformPass(), + RemovePermutesAroundElementwiseOps(), + PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(), + FuseCascadedTransposeOrPermuteOps(), + ConvertPermuteSingletonToViewPass(), + RewriteHighRankSingletonPermutePass(), InsertConstShapesPass(), ] ) @@ -544,7 +553,6 @@ def _tosa_pipeline( CastInt64BuffersToInt32Pass(exported_program), FuseEqualPlaceholdersPass(exported_program), FuseConsecutiveConcatShapesPass(), - ToTosaMemoryFormatPass(exported_program), EnsureUniqueOutputNodesPass(), RemoveNoopPass(), InsertRescalePass(), diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index 7f9b47d3e01..3b1b2894f4b 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -397,3 +397,12 @@ def get_cond_while_submodules_nested( } # collect cond/while submodules (using mapping indices) return _get_control_flow_submodules(graph_module, mapping) + + +def to_2tuple(value): + """Normalizes scalars, and 1-element sequences to a tuple of length 2.""" + if isinstance(value, int): + return (value, value) + if len(value) == 1: + return (value[0], value[0]) + return tuple(value) diff --git a/backends/arm/_passes/conv1d_unsqueeze_pass.py b/backends/arm/_passes/conv1d_unsqueeze_pass.py index 58c3c0c35a2..cf1e884e05b 100644 --- a/backends/arm/_passes/conv1d_unsqueeze_pass.py +++ b/backends/arm/_passes/conv1d_unsqueeze_pass.py @@ -47,7 +47,7 @@ def call_operator(self, op, args, kwargs, meta): x_meta.data["output_qparams"] = {} x = args[0] - x_unsqueezed_shape = list(x.data.shape) + [1] + x_unsqueezed_shape = list(x.data.shape[:-1]) + [1] + [x.data.shape[-1]] x = super().call_operator( exir_ops.edge.aten.view_copy.default, (x, x_unsqueezed_shape), @@ -61,7 +61,7 @@ def call_operator(self, op, args, kwargs, meta): w_meta.data["output_qparams"] = {} w = args[1] - w_unsqueezed_shape = list(w.data.shape) + [1] + w_unsqueezed_shape = list(w.data.shape[:-1]) + [1] + [w.data.shape[-1]] w = super().call_operator( exir_ops.edge.aten.view_copy.default, (w, w_unsqueezed_shape), @@ -74,11 +74,11 @@ def call_operator(self, op, args, kwargs, meta): x, w, args[2], - args[3] + [1], # stride - args[4] + [0], # padding - args[5] + [1], # dilation + [1] + args[3], # stride + [0] + args[4], # padding + [1] + args[5], # dilation args[6], - args[7] + [0], + [0] + args[7], args[8], ) x = super().call_operator( @@ -88,7 +88,7 @@ def call_operator(self, op, args, kwargs, meta): x_squeezed_meta = meta.copy() x_squeezed_meta.data["input_qparams"] = {} x_squeezed_meta.data["output_qparams"] = {} - x_squeezed_shape = list(x.data.shape)[:-1] + x_squeezed_shape = list(x.data.shape[:-2]) + [x.data.shape[-1]] x = super().call_operator( exir_ops.edge.aten.view_copy.default, (x, x_squeezed_shape), diff --git a/backends/arm/_passes/normalize_delegate_io_layout_pass.py b/backends/arm/_passes/normalize_delegate_io_layout_pass.py new file mode 100644 index 00000000000..d1b1d964b87 --- /dev/null +++ b/backends/arm/_passes/normalize_delegate_io_layout_pass.py @@ -0,0 +1,137 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Set, Type + +import torch +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import ( + create_node, + get_first_fake_tensor, + is_param_node, +) +from executorch.exir import ExportedProgram +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +class NormalizeDelegateIOLayoutPass(ArmPass): + """Adjust delegated boundary tensor shapes and insert permutes at I/O.""" + + _passes_required_after: Set[Type[ExportPass]] = set() + + def __init__(self, exported_program: ExportedProgram, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.exported_program = exported_program + + @staticmethod + def _inverse_permutation(perm: tuple[int, ...]) -> tuple[int, ...]: + inverse = [0] * len(perm) + for idx, axis in enumerate(perm): + inverse[axis] = idx + return tuple(inverse) + + @staticmethod + def _permute_shape(shape: torch.Size, perm: tuple[int, ...]) -> tuple[int, ...]: + return tuple(shape[axis] for axis in perm) + + @staticmethod + def _is_identity_dim_order(dim_order: tuple[int, ...]) -> bool: + return dim_order == tuple(range(len(dim_order))) + + def _normalize_input_layout(self, graph_module: torch.fx.GraphModule) -> bool: + modified = False + for node in graph_module.graph.nodes: + if node.op != "placeholder" or is_param_node(self.exported_program, node): + continue + + input_fake = get_first_fake_tensor(node) + dim_order = input_fake.dim_order() + if self._is_identity_dim_order(dim_order): + continue + + boundary_shape = self._permute_shape(input_fake.shape, dim_order) + node.meta["val"] = input_fake.reshape(boundary_shape) + + transpose_perm = self._inverse_permutation(dim_order) + with graph_module.graph.inserting_after(node): + permute_node = create_node( + graph_module.graph, + exir_ops.edge.aten.permute_copy.default, + args=(node, list(transpose_perm)), + from_node=node, + ) + permute_node.meta["val"] = exir_ops.edge.aten.permute_copy.default( + node.meta["val"], list(transpose_perm) + ) + + users = [user for user in node.users if user != permute_node] + for user in users: + user.replace_input_with(node, permute_node) + + modified = True + + return modified + + def _rewrite_output_arg( + self, arg: Any, graph_module: torch.fx.GraphModule + ) -> tuple[Any, bool]: + if isinstance(arg, torch.fx.Node): + output_fake = get_first_fake_tensor(arg) + dim_order = output_fake.dim_order() + if self._is_identity_dim_order(dim_order): + return arg, False + + with graph_module.graph.inserting_after(arg): + permute_node = create_node( + graph_module.graph, + exir_ops.edge.aten.permute_copy.default, + args=(arg, list(dim_order)), + from_node=arg, + ) + permute_node.meta["val"] = exir_ops.edge.aten.permute_copy.default( + output_fake, list(dim_order) + ) + + return permute_node, True + + if isinstance(arg, tuple): + modified = False + rewritten = [] + for item in arg: + new_item, item_modified = self._rewrite_output_arg(item, graph_module) + rewritten.append(new_item) + modified = modified or item_modified + return tuple(rewritten), modified + + if isinstance(arg, list): + modified = False + rewritten = [] + for item in arg: + new_item, item_modified = self._rewrite_output_arg(item, graph_module) + rewritten.append(new_item) + modified = modified or item_modified + return rewritten, modified + + return arg, False + + def _normalize_output_layout(self, graph_module: torch.fx.GraphModule) -> bool: + output_node = graph_module.graph.output_node() + rewritten_outputs, modified = self._rewrite_output_arg( + output_node.args[0], graph_module + ) + if modified: + output_node.args = (rewritten_outputs,) + return modified + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = self._normalize_input_layout(graph_module) + modified = self._normalize_output_layout(graph_module) or modified + + if modified: + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/rewrite_avg_pool2d_pass.py b/backends/arm/_passes/rewrite_avg_pool2d_pass.py index 2f71bdda4a2..36a58505c3a 100644 --- a/backends/arm/_passes/rewrite_avg_pool2d_pass.py +++ b/backends/arm/_passes/rewrite_avg_pool2d_pass.py @@ -7,6 +7,8 @@ import torch from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import to_2tuple +from executorch.backends.arm.constants import NHWC_INVERSE_ORDER, NHWC_ORDER from executorch.backends.arm.operators.operator_validation_utils import ( adjust_pooling_pad_if_needed, ) @@ -32,19 +34,25 @@ def call_operator(self, op, args, kwargs, meta, updated=False): return super().call_operator(op, args, kwargs, meta, updated) x = args[0] - pad_h, pad_w = args[3] + kernel = to_2tuple(args[1]) + + stride = to_2tuple(args[2]) if len(args) > 2 else () + if not stride: + stride = kernel # default to kernel_size + + pad_h, pad_w = to_2tuple(args[3]) if len(args) > 3 else (0, 0) # Make sure pad corresponds to TOSA pad = [pad_h, pad_w, pad_h, pad_w] - _, _, h, w = x.data.shape - kernel_h, kernel_w = args[1] - stride_h, stride_w = args[2] - ceil_mode = args[4] if len(args) > 4 else False # Adjust padding if necessary - pad[1] = adjust_pooling_pad_if_needed(h, kernel_h, stride_h, pad[1], ceil_mode) - pad[3] = adjust_pooling_pad_if_needed(w, kernel_w, stride_w, pad[3], ceil_mode) + pad[1] = adjust_pooling_pad_if_needed( + x.data.shape[2], kernel[0], stride[0], pad[1], ceil_mode + ) + pad[3] = adjust_pooling_pad_if_needed( + x.data.shape[3], kernel[1], stride[1], pad[3], ceil_mode + ) # Materialize zero-point constants in_qparams = meta.data.get("input_qparams", {}) @@ -63,13 +71,36 @@ def call_operator(self, op, args, kwargs, meta, updated=False): else: acc_type = torch.float32 - tosa_args = (args[0], input_zp, output_zp, *args[1:3], pad, acc_type) + pre_permute = super().call_operator( + exir_ops.edge.aten.permute_copy.default, + (x, list(NHWC_ORDER)), + {}, + meta, + updated=True, + ) + + tosa_args = ( + pre_permute, + input_zp, + output_zp, + list(kernel), + list(stride), + pad, + acc_type, + ) # Emit TOSA AVG_POOL2D with normalized args - return super().call_operator( + tosa_avg_pool = super().call_operator( exir_ops.backend.tosa.AVG_POOL2D.default, tosa_args, {}, meta, True, ) + return super().call_operator( + exir_ops.edge.aten.permute_copy.default, + (tosa_avg_pool, list(NHWC_INVERSE_ORDER)), + {}, + meta, + updated=True, + ) diff --git a/backends/arm/_passes/rewrite_conv_pass.py b/backends/arm/_passes/rewrite_conv_pass.py index 8244dc2558b..27565e93452 100644 --- a/backends/arm/_passes/rewrite_conv_pass.py +++ b/backends/arm/_passes/rewrite_conv_pass.py @@ -12,10 +12,10 @@ from executorch.backends.arm._passes.arm_pass_utils import ( create_node, expand_around_channel, + get_constant_placeholder_kind, get_first_fake_tensor, get_param_tensor, - is_buffer, - is_param, + is_persistent_buffer, ) from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, @@ -24,7 +24,14 @@ from executorch.backends.arm._passes.symbolic_value_range import ( evaluate_symbolic_expr_values, ) -from executorch.backends.arm.constants import HWCM_ORDER, NHWC_INVERSE_ORDER +from executorch.backends.arm.constants import ( + HWCM_ORDER, + NHWC_INVERSE_ORDER, + NHWC_ORDER, + ODHWI_INVERSE_ORDER, + ODHWI_ORDER, + OHWI_ORDER, +) from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.backends.arm.tosa.specification import get_context_shape_env from executorch.backends.transforms.utils import create_constant_placeholder @@ -132,54 +139,6 @@ def _is_conv3d(self, rank, groups) -> bool: return True return False - def _reshape_weights(self, weight_node: torch.fx.Node, in_channels: int) -> None: - """Reshape the weights for depthwise convolution such that when - serialized to TOSA, the weights are in the format [H, W, in_channels, - m_length] where m_length is the number of output channels per input - channel. - """ - weight_tensor = get_param_tensor(self.exported_program, weight_node) # type: ignore[arg-type] - if weight_tensor is None: - raise RuntimeError( - f"Weight node {weight_node.name} is not a parameter or buffer" - ) - - reshaped_weight_tensor = ( - weight_tensor.permute(HWCM_ORDER) - .reshape( - weight_tensor.shape[2], - weight_tensor.shape[3], - in_channels, - weight_tensor.shape[0] // in_channels, - ) - .permute(NHWC_INVERSE_ORDER) - ) - - if is_buffer(self.exported_program, weight_node): - param_name = self.exported_program.graph_signature.inputs_to_buffers[ - weight_node.name - ] - reshaped_weight_tensor = torch.nn.Buffer(reshaped_weight_tensor) - elif is_param(self.exported_program, weight_node): - param_name = self.exported_program.graph_signature.inputs_to_parameters[ - weight_node.name - ] - reshaped_weight_tensor = torch.nn.Parameter( - reshaped_weight_tensor, requires_grad=False - ) - else: - raise RuntimeError( - f"Weight node {weight_node.name} is neither a parameter nor a buffer" - ) - - self.exported_program.state_dict[param_name] = reshaped_weight_tensor - weight_node.meta["val"] = weight_node.meta["val"].reshape( - weight_tensor.shape[2], - weight_tensor.shape[0] // in_channels, - weight_tensor.shape[3], - in_channels, - ) - def _add_bias( self, graph_module: torch.fx.GraphModule, @@ -203,14 +162,61 @@ def _add_bias( persistent_buffer=True, name=f"{node.name}_bias", ) - if node.all_input_nodes[0].meta["val"].dtype == torch.int16: - bias_node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48 + self._mark_bias_as_int48_if_needed(node, bias_node) node.update_arg(2, bias_node) return bias_node + def _rewrite_weight( + self, + graph_module: torch.fx.GraphModule, + weight_node: torch.fx.Node, + conv_node: torch.fx.Node, + permute_dims: tuple[int, ...], + name_suffix: str, + reshape_dims: tuple[int, ...] | None = None, + ) -> torch.fx.Node: + """Create a convolution-local rewritten weight placeholder.""" + weight_tensor = get_param_tensor(self.exported_program, weight_node) # type: ignore[arg-type] + if weight_tensor is None: + raise RuntimeError( + f"Weight node {weight_node.name} is not a parameter or buffer" + ) + + rewritten_weight = weight_tensor.permute(permute_dims) + if reshape_dims is not None: + rewritten_weight = rewritten_weight.reshape(*reshape_dims) + rewritten_weight = rewritten_weight.contiguous() + kind = get_constant_placeholder_kind(self.exported_program, weight_node) + persistent_buffer = is_persistent_buffer(self.exported_program, weight_node) + + with graph_module.graph.inserting_after(weight_node): + rewritten_weight_node = create_constant_placeholder( + self.exported_program, + graph=graph_module.graph, + name=f"{conv_node.name}_weight_{name_suffix}", + kind=kind, + data=rewritten_weight, + persistent_buffer=persistent_buffer, + ) + if special_dtype := weight_node.meta.get(TosaSpecialDtype.meta_key()): + rewritten_weight_node.meta[TosaSpecialDtype.meta_key()] = special_dtype + return rewritten_weight_node + def _is_quantized_conv(self, node: torch.fx.Node) -> bool: return bool(node.meta.get("input_qparams", {})) + def _is_int16_activation_conv(self, node: torch.fx.Node) -> bool: + input_qparams = node.meta.get("input_qparams", {}) + if 0 in input_qparams: + return input_qparams[0].dtype == torch.int16 + return get_first_fake_tensor(node.all_input_nodes[0]).dtype == torch.int16 + + def _mark_bias_as_int48_if_needed( + self, node: torch.fx.Node, bias_node: torch.fx.Node + ) -> None: + if self._is_int16_activation_conv(node): + bias_node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48 + def _get_effective_output_qparams(self, node: torch.fx.Node): """Return the quantized output domain for a conv node. @@ -239,7 +245,13 @@ def _get_effective_output_qparams(self, node: torch.fx.Node): return get_output_qparams(node) - def insert_output_rescale(self, graph_module, source_node, conv_node): + def insert_output_rescale( + self, + graph_module, + source_node, + conv_node, + conv_fake_tensor: torch.Tensor, + ): input_qparams = get_input_qparams(source_node) output_qparams = self._get_effective_output_qparams(source_node)[0] weight_qparams = input_qparams[1] @@ -271,7 +283,65 @@ def insert_output_rescale(self, graph_module, source_node, conv_node): ), from_node=source_node, ) - return rescale_node + rescale_fake_tensor = exir_ops.backend.tosa.RESCALE.default( + conv_fake_tensor, + output_qparams.dtype, + post_conv2d_scale, + 0, + output_qparams.get_zp_per_tensor(), + ) + return rescale_node, rescale_fake_tensor + + def insert_identity_int32_rescale( + self, + graph_module, + source_node, + conv_node, + conv_fake_tensor: torch.Tensor, + ): + with graph_module.graph.inserting_after(conv_node): + rescale_node = create_node( + graph=graph_module.graph, + op_target=exir_ops.backend.tosa.RESCALE.default, + args=( + conv_node, + torch.int32, + [1.0], + 0, + 0, + ), + from_node=source_node, + ) + rescale_fake_tensor = exir_ops.backend.tosa.RESCALE.default( + conv_fake_tensor, + torch.int32, + [1.0], + 0, + 0, + ) + return rescale_node, rescale_fake_tensor + + def _has_int32_rescale_user(self, node: torch.fx.Node) -> bool: + for user in node.users: + if ( + user.op == "call_function" + and user.target == exir_ops.backend.tosa.RESCALE.default + and len(user.args) > 1 + and user.args[1] == torch.int32 + ): + return True + if ( + user.op == "call_function" + and user.target == exir_ops.edge.aten.permute_copy.default + ): + for inner_user in user.users: + if ( + inner_user.op == "call_function" + and inner_user.target == exir_ops.backend.tosa.RESCALE.default + and inner_user.args[1] == torch.int32 + ): + return True + return False def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 modified = False @@ -310,8 +380,13 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 has_bias = bias is not None if not has_bias: bias = self._add_bias(graph_module, node, weight) + elif isinstance(bias, torch.fx.Node): + self._mark_bias_as_int48_if_needed(node, bias) conv_args: tuple[Any, ...] + input_tensor_for_tosa_fake: torch.Tensor = input_fake_tensor + pre_permute_dims: tuple[int, ...] + post_permute_dims: tuple[int, ...] if transposed: if spatial_rank != 2: raise RuntimeError( @@ -331,6 +406,27 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 -pad_list[1] + output_padding_list[1], ] target_op = exir_ops.backend.tosa.TRANSPOSE_CONV2D.default + pre_permute_dims = NHWC_ORDER + post_permute_dims = NHWC_INVERSE_ORDER + with graph_module.graph.inserting_before(node): + x = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.permute_copy.default, + args=(x, list(pre_permute_dims)), + from_node=node, + ) + x.meta["val"] = exir_ops.edge.aten.permute_copy.default( + input_fake_tensor, list(pre_permute_dims) + ) + weight = self._rewrite_weight( + graph_module, + weight, + node, + permute_dims=OHWI_ORDER, + name_suffix="ohwi", + ) + input_tensor_for_tosa_fake = input_fake_tensor.permute(pre_permute_dims) + weight_fake_tensor = get_first_fake_tensor(weight) conv_args = ( x, weight, @@ -360,14 +456,83 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 if self._is_conv3d(len(input_shape), group): target_op = exir_ops.backend.tosa.CONV3D.default + pre_permute_dims = ODHWI_ORDER + post_permute_dims = ODHWI_INVERSE_ORDER + with graph_module.graph.inserting_before(node): + x = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.permute_copy.default, + args=(x, list(pre_permute_dims)), + from_node=node, + ) + x.meta["val"] = exir_ops.edge.aten.permute_copy.default( + input_fake_tensor, list(pre_permute_dims) + ) + weight = self._rewrite_weight( + graph_module, + weight, + node, + permute_dims=ODHWI_ORDER, + name_suffix="odhwi", + ) + input_tensor_for_tosa_fake = input_fake_tensor.permute( + pre_permute_dims + ) + weight_fake_tensor = get_first_fake_tensor(weight) elif self._is_depthwise_conv2d(node): target_op = exir_ops.backend.tosa.DEPTHWISE_CONV2D.default - # If there are any TOSA.DEPTHWISE_CONV2D nodes using the weights, we've already reshaped them. - if all(user.target != target_op for user in weight.users): - self._reshape_weights(weight, input_fake_tensor.shape[1]) + pre_permute_dims = NHWC_ORDER + post_permute_dims = NHWC_INVERSE_ORDER + with graph_module.graph.inserting_before(node): + x = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.permute_copy.default, + args=(x, list(pre_permute_dims)), + from_node=node, + ) + x.meta["val"] = exir_ops.edge.aten.permute_copy.default( + input_fake_tensor, list(pre_permute_dims) + ) + kh, kw = weight_shape[2], weight_shape[3] + in_channels = input_fake_tensor.shape[1] + m_length = weight_shape[0] // in_channels + weight = self._rewrite_weight( + graph_module, + weight, + node, + permute_dims=HWCM_ORDER, + name_suffix="hwicm", + reshape_dims=(kh, kw, in_channels, m_length), + ) + input_tensor_for_tosa_fake = input_fake_tensor.permute( + pre_permute_dims + ) weight_fake_tensor = get_first_fake_tensor(weight) else: target_op = exir_ops.backend.tosa.CONV2D.default + pre_permute_dims = NHWC_ORDER + post_permute_dims = NHWC_INVERSE_ORDER + with graph_module.graph.inserting_before(node): + x = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.permute_copy.default, + args=(x, list(pre_permute_dims)), + from_node=node, + ) + x.meta["val"] = exir_ops.edge.aten.permute_copy.default( + input_fake_tensor, list(pre_permute_dims) + ) + weight = self._rewrite_weight( + graph_module, + weight, + node, + permute_dims=NHWC_ORDER, + name_suffix="ohwi", + ) + input_tensor_for_tosa_fake = input_fake_tensor.permute( + pre_permute_dims + ) + weight_fake_tensor = get_first_fake_tensor(weight) conv_args = ( x, @@ -388,33 +553,64 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901 ) bias_fake_tensor = get_first_fake_tensor(bias) if bias else None tosa_node_fake_tensor = target_op( - input_fake_tensor, + input_tensor_for_tosa_fake, weight_fake_tensor, bias_fake_tensor, *conv_args[3:], ) + tosa_op.meta["val"] = tosa_node_fake_tensor + node_replacement: torch.fx.Node = tosa_op + node_replacement_fake_tensor = tosa_node_fake_tensor if ( tosa_node_fake_tensor.dtype == torch.int32 and input_fake_tensor.dtype == torch.int8 ): - output_rescale = self.insert_output_rescale(graph_module, node, tosa_op) - node.replace_all_uses_with(output_rescale) + output_rescale, output_rescale_fake = self.insert_output_rescale( + graph_module, node, tosa_op, tosa_node_fake_tensor + ) + node_replacement = output_rescale + node_replacement_fake_tensor = output_rescale_fake elif ( tosa_node_fake_tensor.dtype == torch.int32 and input_fake_tensor.dtype == torch.int16 ): - has_bias = len(node.meta["input_qparams"]) > 2 - if not has_bias: - output_rescale = self.insert_output_rescale( - graph_module, node, tosa_op + # Explicit layout paths require a post-conv permute, which does + # not support INT48. Always rescale before post-permute. + if self._has_int32_rescale_user(node): + output_rescale, output_rescale_fake = ( + self.insert_identity_int32_rescale( + graph_module, node, tosa_op, tosa_node_fake_tensor + ) ) - node.replace_all_uses_with(output_rescale) else: - node.replace_all_uses_with(tosa_op) + output_rescale, output_rescale_fake = self.insert_output_rescale( + graph_module, node, tosa_op, tosa_node_fake_tensor + ) + node_replacement = output_rescale + node_replacement_fake_tensor = output_rescale_fake + tosa_op.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48 - else: - node.replace_all_uses_with(tosa_op) + + if post_permute_dims is None: + raise RuntimeError("Expected post permute dims for explicit layout") + post_permute_input = node_replacement + with graph_module.graph.inserting_after(node_replacement): + node_replacement = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.permute_copy.default, + args=(node_replacement, list(post_permute_dims)), + from_node=node, + ) + if special_dtype := post_permute_input.meta.get( + TosaSpecialDtype.meta_key() + ): + node_replacement.meta[TosaSpecialDtype.meta_key()] = special_dtype + node_replacement.meta["val"] = exir_ops.edge.aten.permute_copy.default( + node_replacement_fake_tensor, list(post_permute_dims) + ) + + node.replace_all_uses_with(node_replacement) graph_module.graph.erase_node(node) diff --git a/backends/arm/_passes/rewrite_max_pool2d_pass.py b/backends/arm/_passes/rewrite_max_pool2d_pass.py index 123d21eda1f..8a59f2bd4ac 100644 --- a/backends/arm/_passes/rewrite_max_pool2d_pass.py +++ b/backends/arm/_passes/rewrite_max_pool2d_pass.py @@ -6,6 +6,8 @@ from typing import Set, Type from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import to_2tuple +from executorch.backends.arm.constants import NHWC_INVERSE_ORDER, NHWC_ORDER from executorch.backends.arm.operators.operator_validation_utils import ( adjust_pooling_pad_if_needed, ) @@ -15,14 +17,6 @@ edge_max_pool2d_ops = (exir_ops.edge.aten.max_pool2d.default,) -def _to_2tuple(value): - if isinstance(value, int): - return (value, value) - if len(value) == 1: - return (value[0], value[0]) - return tuple(value) - - class RewriteMaxPool2dPass(ArmPass): """Rewrite max_pool2d ops to TOSA MAX_POOL2D.""" @@ -33,19 +27,23 @@ def call_operator(self, op, args, kwargs, meta): return super().call_operator(op, args, kwargs, meta) x = args[0] - kernel = _to_2tuple(args[1]) + kernel = args[1] + stride = to_2tuple(args[2]) if len(args) > 2 else () + if not stride: + stride = kernel # default to kernel_size - if len(args) > 2 and args[2] is not None and len(args[2]) > 0: - stride = _to_2tuple(args[2]) - else: - stride = kernel - - padding = _to_2tuple(args[3]) if len(args) > 3 else (0, 0) - dilation = _to_2tuple(args[4]) if len(args) > 4 else (1, 1) + padding = to_2tuple(args[3]) if len(args) > 3 else (0, 0) + dilation = to_2tuple(args[4]) if len(args) > 4 else (1, 1) ceil_mode = args[5] if len(args) > 5 else False - if dilation != (1, 1): - return super().call_operator(op, args, kwargs, meta) + if not dilation == (1, 1): + from executorch.backends.arm._passes.decompose_maxpool2d_with_dilation_pass import ( + DecomposeMaxPool2dPass, + ) + + raise RuntimeError( + f"Dilation > 1 is not supported for tosa.MAX_POOL2D, has {DecomposeMaxPool2dPass.__name__} run?" + ) # TOSA MAX_POOL2D pad order is [top, bottom, left, right] pad = [padding[0], padding[0], padding[1], padding[1]] @@ -56,9 +54,28 @@ def call_operator(self, op, args, kwargs, meta): x.data.shape[3], kernel[1], stride[1], pad[3], ceil_mode ) - return super().call_operator( + pre_permute = super().call_operator( + exir_ops.edge.aten.permute_copy.default, + (x, list(NHWC_ORDER)), + {}, + meta, + updated=True, + ) + tosa_pool = super().call_operator( exir_ops.backend.tosa.MAX_POOL2D.default, - (x, list(kernel), list(stride), pad), + ( + pre_permute, + list(kernel), + list(stride), + pad, + ), + {}, + meta, + updated=True, + ) + return super().call_operator( + exir_ops.edge.aten.permute_copy.default, + (tosa_pool, list(NHWC_INVERSE_ORDER)), {}, meta, updated=True, diff --git a/backends/arm/_passes/rewrite_upsample.py b/backends/arm/_passes/rewrite_upsample.py index 4a864125faf..da336d0dde3 100644 --- a/backends/arm/_passes/rewrite_upsample.py +++ b/backends/arm/_passes/rewrite_upsample.py @@ -37,6 +37,8 @@ class RewriteUpsamplePass(ArmPass): ) _passes_required_after: Set[Type[ExportPass]] = set() + _NHWC_ORDER = (0, 2, 3, 1) + _NHWC_INVERSE_ORDER = (0, 3, 1, 2) @staticmethod def get_resize_parameters_1d( @@ -188,17 +190,34 @@ def call(self, graph_module): from_node=node, ) + pre_permute = create_node( + graph_module.graph, + op_target=exir_ops.edge.aten.permute_copy.default, + args=(x, list(self._NHWC_ORDER)), + from_node=node, + ) + pre_permute.meta["val"] = exir_ops.edge.aten.permute_copy.default( + get_first_fake_tensor(x), list(self._NHWC_ORDER) + ) + tosa_resize_node = create_node( graph_module.graph, op_target=exir_ops.backend.tosa.RESIZE.default, - args=(x, scale, offset, border), + args=(pre_permute, scale, offset, border), kwargs={"resize_mode": resize_mode}, from_node=node, inherit_qparams=True, ) - node.replace_all_uses_with(tosa_resize_node) - graph_module.graph.erase_node(node) + tosa_resize_node.meta["val"] = exir_ops.backend.tosa.RESIZE.default( + pre_permute.meta["val"], + scale if isinstance(scale, list) else scale.args[0], + offset if isinstance(offset, list) else offset.args[0], + border if isinstance(border, list) else border.args[0], + resize_mode=resize_mode, + ) input_dtype = get_first_fake_tensor(x).dtype + node_replacement = tosa_resize_node + node_replacement_fake = tosa_resize_node.meta["val"] if ( input_dtype == torch.int8 or input_dtype == torch.int16 ) and resize_mode == "bilinear": @@ -208,8 +227,10 @@ def call(self, graph_module): rescale_node = create_node( graph_module.graph, exir_ops.backend.tosa.RESCALE.default, + from_node=node, ) - tosa_resize_node.replace_all_uses_with(rescale_node) + rescale_node.meta["val"] = node_replacement_fake + if input_dtype == torch.int16: tosa_resize_node.meta[TosaSpecialDtype.meta_key()] = ( TosaSpecialDtype.INT48 @@ -222,6 +243,24 @@ def call(self, graph_module): 0, # zero point 0, # zero point ) + node_replacement = rescale_node + node_replacement_fake = exir_ops.backend.tosa.RESCALE.default( + tosa_resize_node.meta["val"], output_dtype, [output_scale], 0, 0 + ) + + with graph_module.graph.inserting_after(node_replacement): + post_permute = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.permute_copy.default, + args=(node_replacement, list(self._NHWC_INVERSE_ORDER)), + from_node=node, + ) + post_permute.meta["val"] = exir_ops.edge.aten.permute_copy.default( + node_replacement_fake, + list(self._NHWC_INVERSE_ORDER), + ) + node.replace_all_uses_with(post_permute) + graph_module.graph.erase_node(node) if modified: graph_module = super().call(graph_module).graph_module diff --git a/backends/arm/constants.py b/backends/arm/constants.py index c0d9c4504f7..c588d684964 100644 --- a/backends/arm/constants.py +++ b/backends/arm/constants.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -41,6 +41,10 @@ NNCHW_ORDER: Final = (0, 1, 2, 3, 4) NNNCHW_ORDER: Final = (0, 1, 2, 3, 4, 5) +OHWI_ORDER: Final = (1, 2, 3, 0) +ODHWI_ORDER: Final = (0, 2, 3, 4, 1) +ODHWI_INVERSE_ORDER: Final = (0, 4, 1, 2, 3) + HWCM_ORDER: Final = (2, 3, 0, 1) MAX_RANK: Final = 6 diff --git a/backends/arm/operators/op_tosa_shapes.py b/backends/arm/operators/op_tosa_shapes.py index 7e426a1da4e..a3087108d2a 100644 --- a/backends/arm/operators/op_tosa_shapes.py +++ b/backends/arm/operators/op_tosa_shapes.py @@ -34,8 +34,7 @@ def define_node( ) -> None: shape_input = inputs[0].special rank = len(shape_input) - tosa_dim_order = output.dim_order - vals = tosa_shape(node.meta["val"], tosa_dim_order) + vals = tosa_shape(node.meta["val"], output.dim_order) tosa_graph = cast(ts.TosaSerializer, tosa_graph) tosa_graph.addConst( [ diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index e5f522d6e2e..3ae1c233c2b 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -47,6 +47,15 @@ def _tensor_to_numpy_with_dim_order( return np.transpose(np_tensor, dim_order) +def _prepare_const_values_for_tosa_dtype( + values: np.ndarray, tosa_dtype: ts.DType +) -> np.ndarray: + """Normalize constant storage to the expected TOSA serializer dtype.""" + if tosa_dtype == ts.DType.INT48 and values.dtype != np.int64: + return values.astype(np.int64) + return values + + def process_call_function( node: torch.fx.Node, tosa_graph: Any, @@ -140,6 +149,9 @@ def process_inputs_to_parameters( parameter_values = _tensor_to_numpy_with_dim_order( parameter_data, tosa_arg.dim_order # type: ignore[arg-type] ) + parameter_values = _prepare_const_values_for_tosa_dtype( + parameter_values, tosa_arg.dtype + ) tosa_graph.addConst( parameter_values.shape, tosa_arg.dtype, parameter_values, name=tosa_arg.name @@ -168,6 +180,7 @@ def process_inputs_to_buffers( f"{type(buffer_data).__name__}" ) buffer_values = _tensor_to_numpy_with_dim_order(buffer_data, tosa_arg.dim_order) # type: ignore[arg-type] + buffer_values = _prepare_const_values_for_tosa_dtype(buffer_values, tosa_arg.dtype) tosa_graph.addConst( buffer_values.shape, tosa_arg.dtype, buffer_values, name=tosa_arg.name @@ -192,6 +205,7 @@ def process_inputs_to_lifted_tensor_constants( tensor, # type: ignore[arg-type] tosa_arg.dim_order, # type: ignore[arg-type] ) + tensor_values = _prepare_const_values_for_tosa_dtype(tensor_values, tosa_arg.dtype) tosa_graph.addConst( tensor_values.shape, tosa_arg.dtype, tensor_values, name=tosa_arg.name diff --git a/backends/arm/test/misc/test_const_shape.py b/backends/arm/test/misc/test_const_shape.py index 2694dc6ea97..b976f1997c9 100644 --- a/backends/arm/test/misc/test_const_shape.py +++ b/backends/arm/test/misc/test_const_shape.py @@ -6,21 +6,10 @@ from typing import Set, Type import executorch.backends.arm.tosa.dialect # noqa: F401 -import pytest import torch -import tosa_serializer as ts from executorch.backends.arm._passes.arm_pass import ArmPass -from executorch.backends.arm._passes.to_tosa_memory_format_pass import ( - ToTosaMemoryFormatPass, -) -from executorch.backends.arm.operators.node_visitor import get_node_visitors -from executorch.backends.arm.process_node import process_call_function from executorch.backends.arm.tosa.mapping import TosaSpecialDtype -from executorch.backends.arm.tosa.specification import ( - TosaLoweringContext, - TosaSpecification, -) -from executorch.backends.test.graph_builder import GraphBuilder + from executorch.exir import to_edge from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -67,76 +56,3 @@ def forward(self, x): assert const_shape_nodes for n in const_shape_nodes: assert n.meta[TosaSpecialDtype.meta_key()] == TosaSpecialDtype.SHAPE - - -def _graph_module_with_unused_const_shape(): - with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.1+FP+shape")): - builder = GraphBuilder() - builder.call_operator(exir_ops.backend.tosa.CONST_SHAPE.default, ([1],)) - live_const = builder.call_operator( - exir_ops.backend.tosa.CONST_SHAPE.default, ([3],) - ) - builder.output([live_const]) - graph_module = ExportPass().call(builder.get_graph_module()).graph_module - for node in graph_module.graph.nodes: - if node.op == "call_function": - node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE - return graph_module - - -def _propagate_shape_dim_orders_from_users(graph_module: torch.fx.GraphModule) -> None: - output_node = next(node for node in graph_module.graph.nodes if node.op == "output") - output_node.meta["tosa_dim_order"] = (0,) - dummy_exported = torch.export.export(torch.nn.Identity(), (torch.randn(1),)) - tosa_memory_format_pass = ToTosaMemoryFormatPass(dummy_exported) - tosa_memory_format_pass._propagate_dim_order_to_shape_args(output_node) - - -def _serialize_graph_module_to_tosa(graph_module: torch.fx.GraphModule): - tosa_spec = TosaSpecification.create_from_string("TOSA-1.1+FP+shape") - node_visitors = get_node_visitors(None, tosa_spec) - tosa_graph = ts.TosaSerializer( - "", - targetMajor=tosa_spec.version.major, - targetMinor=tosa_spec.version.minor, - targetPatch=tosa_spec.version.micro, - targetDraft=True, - ) - - for node in graph_module.graph.nodes: - if node.op == "call_function": - process_call_function(node, tosa_graph, node_visitors, tosa_spec) - - return tosa_graph - - -def test_unused_shape_ops_miss_tosa_dim_order_and_must_be_removed_before_tosa_serialization(): - graph_module = _graph_module_with_unused_const_shape() - _propagate_shape_dim_orders_from_users(graph_module) - - const_shape_nodes = [ - node - for node in graph_module.graph.nodes - if node.op == "call_function" - and node.target == exir_ops.backend.tosa.CONST_SHAPE.default - ] - dead_const_shape, live_const_shape = const_shape_nodes - - assert dead_const_shape.users == {} - assert "tosa_dim_order" not in dead_const_shape.meta - assert live_const_shape.meta["tosa_dim_order"] == (0,) - - with pytest.raises(KeyError, match="tosa_dim_order"): - _serialize_graph_module_to_tosa(graph_module) - - graph_module.graph.eliminate_dead_code() - graph_module.recompile() - - remaining_const_shape = next( - node - for node in graph_module.graph.nodes - if node.op == "call_function" - and node.target == exir_ops.backend.tosa.CONST_SHAPE.default - ) - assert remaining_const_shape.meta["tosa_dim_order"] == (0,) - assert _serialize_graph_module_to_tosa(graph_module) diff --git a/backends/arm/test/misc/test_high_rank_permute_view_invariants.py b/backends/arm/test/misc/test_high_rank_permute_view_invariants.py index 79e50d337c3..6004553141a 100644 --- a/backends/arm/test/misc/test_high_rank_permute_view_invariants.py +++ b/backends/arm/test/misc/test_high_rank_permute_view_invariants.py @@ -159,7 +159,7 @@ def _build_high_rank_permute_cases() -> dict[str, TransposeInvariantCase]: 20260225 ) # nosec B311: deterministic RNG for test case generation start_shape = [1, 16, 16, 64] - expected_transpose_counts = [6, 11, 10, 10, 7, 7, 10, 10, 8, 10] + expected_transpose_counts = [4, 3, 3, 3, 2, 3, 3, 3, 3, 2] cases: dict[str, TransposeInvariantCase] = {} for idx in range(10): ops = _generate_chain(rng, start_shape, steps=8) diff --git a/backends/arm/test/misc/test_process_node.py b/backends/arm/test/misc/test_process_node.py new file mode 100644 index 00000000000..1ef348abdbf --- /dev/null +++ b/backends/arm/test/misc/test_process_node.py @@ -0,0 +1,96 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +import tosa_serializer as ts +from executorch.backends.arm.process_node import process_placeholder +from executorch.backends.arm.tosa.mapping import TosaSpecialDtype +from executorch.backends.arm.tosa.specification import TosaSpecification +from executorch.exir import to_edge +from torch._export.utils import is_param + + +class Int32BiasModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.bias = torch.nn.Parameter( + torch.tensor([1, -2, 0x12345678], dtype=torch.int32), + requires_grad=False, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Keep the int32 parameter live in the exported graph. + return x + self.bias[0].to(torch.float32) + + +class CapturingTosaGraph: + def __init__(self) -> None: + self.shape = None + self.dtype = None + self.values = None + self.name = None + self.serialized_bytes = None + + def addConst(self, shape, dtype, values, name): + self.shape = shape + self.dtype = dtype + self.values = np.asarray(values) + self.name = name + if dtype == ts.DType.INT48: + self.serialized_bytes = self._serialize_int48(self.values) + + @staticmethod + def _serialize_int48(values: np.ndarray) -> list[int]: + # Simulate a consumer that expects each element pre-normalized to int64 + # before narrowing to little-endian signed 48-bit storage. + if values.dtype == np.int64: + packed: list[int] = [] + for value in values.reshape(-1): + masked = int(value) & ((1 << 48) - 1) + packed.extend([(masked >> (8 * i)) & 0xFF for i in range(6)]) + return packed + + # Existing buggy path: treat raw bytes as already packed int48 stream. + raw = values.view(np.uint8).reshape(-1).tolist() + remainder = len(raw) % 6 + if remainder: + raw.extend([0] * (6 - remainder)) + return raw + + +def _expected_int48_bytes(values: torch.Tensor) -> list[int]: + packed: list[int] = [] + for value in values.tolist(): + masked = int(value) & ((1 << 48) - 1) + packed.extend([(masked >> (8 * i)) & 0xFF for i in range(6)]) + return packed + + +def test_process_placeholder_int48_normalizes_int32_const_values() -> None: + module = Int32BiasModule().eval() + exported_program = torch.export.export(module, (torch.randn(1),)) + edge_program = to_edge(exported_program).exported_program() + + param_node = next( + node + for node in edge_program.graph.nodes + if node.op == "placeholder" and is_param(edge_program, node) + ) + param_node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48 + + tosa_graph = CapturingTosaGraph() + process_placeholder( + param_node, + tosa_graph, + edge_program, + containing_graph_module=None, + tosa_spec=TosaSpecification.create_from_string("TOSA-1.0+INT+int16"), + ) + + assert tosa_graph.dtype == ts.DType.INT48 + assert tosa_graph.values is not None + assert tosa_graph.values.dtype == np.int64 + assert tosa_graph.serialized_bytes == _expected_int48_bytes(module.bias) diff --git a/backends/arm/test/misc/test_tosa_dialect_conv2d.py b/backends/arm/test/misc/test_tosa_dialect_conv2d.py index 7cee50385c7..6f481d15ffd 100644 --- a/backends/arm/test/misc/test_tosa_dialect_conv2d.py +++ b/backends/arm/test/misc/test_tosa_dialect_conv2d.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -27,7 +27,7 @@ def test_conv2d_tosa_INT(): [2, 2, 2, 2], [1, 1], ), - (1, 8, 20, 20), + (1, 11, 20, 8), torch.int32, ), ( @@ -39,7 +39,7 @@ def test_conv2d_tosa_INT(): [2, 2, 2, 2], [1, 1], ), - (1, 4, 10, 10), + (1, 6, 10, 4), torch.int32, ), ] @@ -129,7 +129,7 @@ def test_conv2d_tosa_FP(): [2, 2, 2, 2], [1, 1], ), - (1, 8, 20, 20), + (1, 11, 20, 8), torch.float32, ), ( @@ -141,7 +141,7 @@ def test_conv2d_tosa_FP(): [2, 2, 2, 2], [1, 1], ), - (1, 4, 10, 10), + (1, 6, 10, 4), torch.float32, ), ] diff --git a/backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py b/backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py index 3a6b5cef0fb..c680f1bd7e3 100644 --- a/backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py +++ b/backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -28,7 +28,7 @@ def test_depthwise_conv2d_tosa_INT(): [2, 2, 2, 2], [1, 1], ), - (1, 16, 20, 20), + (1, 8, 23, 40), torch.int32, ), ( @@ -41,7 +41,7 @@ def test_depthwise_conv2d_tosa_INT(): [2, 2, 2, 2], [1, 1], ), - (1, 32, 10, 10), + (1, 4, 11, 40), torch.int32, ), ] @@ -134,7 +134,7 @@ def test_depthwise_conv2d_tosa_FP(): [2, 2, 2, 2], [1, 1], ), - (1, 16, 20, 20), + (1, 8, 23, 40), torch.float32, ), ( @@ -147,7 +147,7 @@ def test_depthwise_conv2d_tosa_FP(): [2, 2, 2, 2], [1, 1], ), - (1, 32, 10, 10), + (1, 4, 11, 40), torch.float32, ), ] diff --git a/backends/arm/test/misc/test_transpose_counts.py b/backends/arm/test/misc/test_transpose_counts.py index 478e5c03eaf..55496c8a9b5 100644 --- a/backends/arm/test/misc/test_transpose_counts.py +++ b/backends/arm/test/misc/test_transpose_counts.py @@ -335,13 +335,13 @@ def forward(self, x): "conv1d_rank3": TransposeCountCase(Conv1dModule(), (torch.randn(1, 2, 8),), 2), "conv2d_rank3": TransposeCountCase(Conv2dModule(), (torch.randn(2, 8, 8),), 2), "conv2d_rank4": TransposeCountCase(Conv2dModule(), (torch.randn(1, 2, 8, 8),), 2), - "conv3d_rank4": TransposeCountCase(Conv3dModule(), (torch.randn(2, 6, 6, 6),), 5), + "conv3d_rank4": TransposeCountCase(Conv3dModule(), (torch.randn(2, 6, 6, 6),), 2), "conv3d_rank5": TransposeCountCase( Conv3dModule(), (torch.randn(1, 2, 6, 6, 6),), 2 ), "linear_rank2": TransposeCountCase(LinearModule(), (torch.randn(2, 8),), 0), "linear_rank3": TransposeCountCase(LinearModule(), (torch.randn(2, 2, 8),), 0), - "linear_rank4": TransposeCountCase(LinearModule(), (torch.randn(1, 2, 2, 8),), 3), + "linear_rank4": TransposeCountCase(LinearModule(), (torch.randn(1, 2, 2, 8),), 0), "matmul_rank2": TransposeCountCase( MatmulModule(), (torch.randn(2, 3), torch.randn(3, 4)), @@ -350,7 +350,7 @@ def forward(self, x): "matmul_rank4": TransposeCountCase( MatmulModule(), (torch.randn(2, 2, 2, 3), torch.randn(2, 2, 3, 4)), - 5, + 0, ), "index_put": TransposeCountCase( IndexPutModule(), @@ -368,7 +368,7 @@ def forward(self, x): "pixel_shuffle": TransposeCountCase( PixelShuffleModule(), (torch.randn(1, 8, 2, 2),), - 7, + 1, ), "index_select": TransposeCountCase( IndexSelectModule(), @@ -378,23 +378,23 @@ def forward(self, x): "grouped_conv": TransposeCountCase( GroupedConvModule(), (torch.randn(1, 4, 8, 8),), - 2, + 4, ), "transpose_conv": TransposeCountCase( TransposeConvModule(), (torch.randn(1, 2, 8, 8),), 2, ), - "views": TransposeCountCase(ViewsModule(), (torch.rand(1, 2, 2, 4),), 6), + "views": TransposeCountCase(ViewsModule(), (torch.rand(1, 2, 2, 4),), 4), "transposes": TransposeCountCase( TransposesModule(), (torch.randn(1, 2, 3, 4),), - 4, + 1, ), "maxpool2d_dilation": TransposeCountCase( MaxPool2dDilatedModule(), (torch.randn(1, 2, 8, 8),), - 8, + 4, ), "lstm": TransposeCountCase( LstmModule(), @@ -404,17 +404,17 @@ def forward(self, x): "groupnorm": TransposeCountCase( GroupNormModule(), (torch.randn(1, 4, 4, 4),), - 5, + 1, ), "multihead_attention_rank2": TransposeCountCase( MultiheadAttentionModule(), (torch.randn(4, 8),), - 14, + 4, ), "multihead_attention_rank3": TransposeCountCase( MultiheadAttentionModule(), (torch.randn(2, 4, 8),), - 22, + 8, ), "cumsum_rank3_dim0": TransposeCountCase( CumsumModule(), @@ -424,22 +424,22 @@ def forward(self, x): "cumsum_rank4_dim3": TransposeCountCase( CumsumModule(), (torch.randn(1, 2, 3, 4), 3), - 3, + 0, ), "model_1_conv_maxpool_residual_linear": TransposeCountCase( Model1ConvMaxPoolResidualLinear(), (torch.randn(2, 8, 64),), 5 ), "model_2_conv_mha_linear_layernorm": TransposeCountCase( - Model2ConvMhaLinearLayerNorm(), (torch.randn(2, 8, 32),), 27 + Model2ConvMhaLinearLayerNorm(), (torch.randn(2, 8, 32),), 11 ), "model_3_lstm_linear": TransposeCountCase( Model3LstmLinear(), (torch.randn(2, 16, 8),), 2 ), "model_4_conv_lstm_linear_layernorm": TransposeCountCase( - Model4ConvLstmLinearLayerNorm(), (torch.randn(2, 8, 32),), 7 + Model4ConvLstmLinearLayerNorm(), (torch.randn(2, 8, 32),), 5 ), "model_5_dwconv_gelu_layernorm_avgpool": TransposeCountCase( - Model5DwConvGeluLayerNormAvgPool(), (torch.randn(1, 8, 16, 16),), 4 + Model5DwConvGeluLayerNormAvgPool(), (torch.randn(1, 8, 16, 16),), 6 ), "model_6_gru_linear": TransposeCountCase( Model6GruLinear(), (torch.randn(2, 16, 8),), 2 @@ -448,10 +448,10 @@ def forward(self, x): Model7DwConvBatchNormLinear(), (torch.randn(2, 8, 64),), 3 ), "model_8_conv_batchnorm_maxpool_residual": TransposeCountCase( - Model8ConvBatchNormMaxPoolResidual(), (torch.randn(1, 8, 16, 16),), 2 + Model8ConvBatchNormMaxPoolResidual(), (torch.randn(1, 8, 16, 16),), 6 ), "model_9_dilated_conv_batchnorm_avgpool_residual": TransposeCountCase( - Model9DilatedConvBatchNormAvgPoolResidual(), (torch.randn(1, 8, 16, 16),), 2 + Model9DilatedConvBatchNormAvgPoolResidual(), (torch.randn(1, 8, 16, 16),), 6 ), "model_10_dwconv_batchnorm_linear_cat": TransposeCountCase( Model10DwConvBatchNormLinearCat(), (torch.randn(2, 8, 64),), 3 @@ -468,17 +468,17 @@ def forward(self, x): "conv3d_rank4_channels_last": TransposeCountCase( Conv3dModule(), (torch.randn(2, 6, 6, 6).to(memory_format=torch.channels_last),), - 4, + 2, ), "conv3d_rank5_channels_last": TransposeCountCase( Conv3dModule(), (torch.randn(1, 2, 6, 6, 6).to(memory_format=torch.channels_last_3d),), - 1, + 3, ), "linear_rank4_channels_last": TransposeCountCase( LinearModule(), (torch.randn(1, 2, 2, 8).to(memory_format=torch.channels_last),), - -1, # The test crashes before reaching the transpose count + 1, ), "matmul_rank4_channels_last": TransposeCountCase( MatmulModule(), @@ -486,17 +486,17 @@ def forward(self, x): torch.randn(2, 2, 2, 3).to(memory_format=torch.channels_last), torch.randn(2, 2, 3, 4).to(memory_format=torch.channels_last), ), - -1, # The test crashes before reaching the transpose count + 2, ), "pixel_shuffle_channels_last": TransposeCountCase( PixelShuffleModule(), (torch.randn(1, 8, 2, 2).to(memory_format=torch.channels_last),), - 5, + 3, ), "grouped_conv_channels_last": TransposeCountCase( GroupedConvModule(), (torch.randn(1, 4, 8, 8).to(memory_format=torch.channels_last),), - 0, + 3, ), "transpose_conv_channels_last": TransposeCountCase( TransposeConvModule(), @@ -511,7 +511,7 @@ def forward(self, x): "transposes_channels_last": TransposeCountCase( TransposesModule(), (torch.randn(1, 2, 3, 4).to(memory_format=torch.channels_last),), - 3, + 1, ), "maxpool2d_dilation_channels_last": TransposeCountCase( MaxPool2dDilatedModule(), @@ -521,12 +521,12 @@ def forward(self, x): "groupnorm_channels_last": TransposeCountCase( GroupNormModule(), (torch.randn(1, 4, 4, 4).to(memory_format=torch.channels_last),), - 4, + 3, ), "cumsum_rank4_dim3_channels_last": TransposeCountCase( CumsumModule(), (torch.randn(1, 2, 3, 4).to(memory_format=torch.channels_last), 3), - -1, # The test crashes before reaching the transpose count + 1, ), } @@ -540,19 +540,12 @@ def test_transpose_counts_tosa_FP(case: TransposeCountCase) -> None: xfails = { "conv3d_rank5_channels_last": "Numerical error", - "linear_rank4_channels_last": "DecomposeLinearPass: Tries inserting a view not supported in channels last format", - "matmul_rank4_channels_last": "ToTosaMemoryFormatPass: Tries inserting view not supported in channels last format", "views_channels_last": "Torch.export: View not supported by torch.export in channels last format", - "cumsum_rank4_dim3_channels_last": "DecomposeCumssumPass: Tries inserting a view not supported in channels last format", } @common.parametrize("case", cases_channels_last, xfails=xfails) # type: ignore[arg-type] def test_transpose_counts_tosa_FP_channels_last(case: TransposeCountCase) -> None: - pipeline = TosaPipelineFP[InputT]( - case.module, - case.inputs, - aten_op=[], - ) + pipeline = TosaPipelineFP[InputT](case.module, case.inputs, aten_op=[]) pipeline.count_tosa_ops({"TRANSPOSE": case.expected_transposes}) pipeline.run() diff --git a/backends/arm/test/ops/test_pow.py b/backends/arm/test/ops/test_pow.py index be452ed70ac..2d007fa7e68 100644 --- a/backends/arm/test/ops/test_pow.py +++ b/backends/arm/test/ops/test_pow.py @@ -145,13 +145,18 @@ def test_pow_tensor_tensor_vgf_no_quant(test_data: Pow_TensorTensor.input_t): "exp_two": "TOSA constraints: If x <0 .", } +x_fail_FP = { + "exp_two": "TOSA constraints: If x <0 .", + "exp_zero": "MLETORCH-2041 : Invalid inputs.", +} + @common.parametrize( "test_data", Pow_TensorScalar.test_data | Pow_TensorScalar.test_data_fp16 | Pow_TensorScalar.test_data_bf16, - xfails=x_fail, + xfails=x_fail_FP, strict=False, ) def test_pow_tensor_scalar_tosa_FP(test_data: Pow_TensorScalar.input_t): @@ -207,7 +212,7 @@ def test_pow_tensor_scalar_u85_INT(test_data: Pow_TensorScalar.input_t): @common.parametrize( "test_data", Pow_TensorScalar.test_data | Pow_TensorScalar.test_data_fp16, - x_fail, + x_fail_FP, strict=False, ) @common.SkipIfNoModelConverter diff --git a/backends/arm/test/ops/test_unfold_copy.py b/backends/arm/test/ops/test_unfold_copy.py index 5765fcb8d02..2b502a9be10 100644 --- a/backends/arm/test/ops/test_unfold_copy.py +++ b/backends/arm/test/ops/test_unfold_copy.py @@ -161,12 +161,6 @@ def test_unfold_copy_u55_INT(test_data: input_params): @common.parametrize( "test_data", test_data_int | test_data_fp, - xfails={ - "test_int8_3d_dim_neg1": "MLETORCH-1732: rand test fails", - "test_int32_4d_dim_neg1": "MLETORCH-1732: rand test fails", - "test_fp32_3d_dim_neg1": "MLETORCH-1732: rand test fails", - "test_fp32_4d_dim_neg1": "MLETORCH-1732: rand test fails", - }, ) @common.XfailIfNoCorstone320 def test_unfold_copy_u85_INT(test_data: input_params): diff --git a/backends/arm/test/passes/test_rewrite_avg_pool2d_pass.py b/backends/arm/test/passes/test_rewrite_avg_pool2d_pass.py new file mode 100644 index 00000000000..42214ba59b3 --- /dev/null +++ b/backends/arm/test/passes/test_rewrite_avg_pool2d_pass.py @@ -0,0 +1,126 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import cast, Dict, Protocol, Tuple + +import torch +from executorch.backends.arm._passes.rewrite_avg_pool2d_pass import RewriteAvgPool2dPass +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline +from executorch.backends.test.harness.stages import StageType +from executorch.exir.dialects._ops import ops as exir_ops + +input_t = Tuple[torch.Tensor] + + +class ModuleWithInputs(Protocol): + def get_inputs(self) -> input_t: ... + + +class AvgPool2dWithStride(torch.nn.Module): + def get_inputs(self) -> input_t: + return (torch.rand(1, 3, 8, 8),) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + + +class AvgPool2dWithoutStride(torch.nn.Module): + def get_inputs(self) -> input_t: + return (torch.rand(1, 3, 8, 8),) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.avg_pool2d(x, kernel_size=3) + + +class AvgPool2dListKernel(torch.nn.Module): + def get_inputs(self) -> input_t: + return (torch.rand(1, 3, 8, 8),) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.avg_pool2d(x, kernel_size=[2, 3]) + + +class AvgPool2dScalarPadding(torch.nn.Module): + def get_inputs(self) -> input_t: + return (torch.rand(1, 3, 8, 8),) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.avg_pool2d(x, kernel_size=3, stride=2, padding=1) + + +class AvgPool2dWithEmptyStride(torch.nn.Module): + def get_inputs(self) -> input_t: + return (torch.rand(1, 3, 8, 8),) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.avg_pool2d(x, kernel_size=[2, 3], stride=[]) + + +modules: Dict[str, ModuleWithInputs] = { + "avg_pool2d_with_stride": AvgPool2dWithStride(), + "avg_pool2d_without_stride": AvgPool2dWithoutStride(), + "avg_pool2d_list_kernel": AvgPool2dListKernel(), +} + + +@common.parametrize("module", modules) +def test_rewrite_avg_pool2d_tosa(module: ModuleWithInputs) -> None: + nn_module = cast(torch.nn.Module, module) + pipeline = PassPipeline[input_t]( + nn_module, + module.get_inputs(), + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1, + }, + ops_after_pass={ + "executorch_exir_dialects_backend__ops_tosa_AVG_POOL2D_default": 1, + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 2, + }, + pass_list=[RewriteAvgPool2dPass], + ) + pipeline.pop_stage( + "run_method_and_compare_outputs" + ) # Cannot run aten graph with tosa dialect ops + pipeline.run() + + +def _get_tosa_avg_pool2d_node( + pipeline: PassPipeline[input_t], +) -> torch.fx.Node: + exported_program = pipeline.tester.get_artifact( + StageType.RUN_PASSES + ).exported_program() + graph_module = exported_program.graph_module + + tosa_nodes = [ + node + for node in graph_module.graph.nodes + if node.op == "call_function" + and node.target == exir_ops.backend.tosa.AVG_POOL2D.default + ] + assert len(tosa_nodes) == 1 + return tosa_nodes[0] + + +def test_rewrite_avg_pool2d_tosa_empty_stride_uses_kernel_size() -> None: + module = AvgPool2dWithEmptyStride() + pipeline = PassPipeline[input_t]( + module, + module.get_inputs(), + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1, + }, + ops_after_pass={ + "executorch_exir_dialects_backend__ops_tosa_AVG_POOL2D_default": 1, + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 2, + }, + pass_list=[RewriteAvgPool2dPass], + ) + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() + + tosa_node = _get_tosa_avg_pool2d_node(pipeline) + assert tosa_node.args[4] == [2, 3] diff --git a/backends/arm/test/passes/test_rewrite_conv_pass.py b/backends/arm/test/passes/test_rewrite_conv_pass.py index 5e8593a38af..09176f26f28 100644 --- a/backends/arm/test/passes/test_rewrite_conv_pass.py +++ b/backends/arm/test/passes/test_rewrite_conv_pass.py @@ -15,6 +15,7 @@ ) from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass from executorch.backends.arm.quantizer.arm_quantizer import ( + get_symmetric_a16w8_quantization_config, get_symmetric_quantization_config, VgfQuantizer, ) @@ -22,6 +23,7 @@ DWConvsModule, ) from executorch.backends.arm.test.tester.test_pipeline import PassPipeline +from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.backends.arm.tosa.specification import ( TosaLoweringContext, TosaSpecification, @@ -59,6 +61,10 @@ def _compile_spec() -> VgfCompileSpec: return VgfCompileSpec("TOSA-1.0+INT+FP") +def _compile_spec_int16() -> VgfCompileSpec: + return VgfCompileSpec("TOSA-1.0+INT+FP+int16") + + def _quantizer() -> VgfQuantizer: quantizer = VgfQuantizer(_compile_spec()) quantizer.set_global( @@ -80,6 +86,14 @@ def _export_quantized(model: nn.Module): return torch.export.export(quantized, inputs) +def _export_quantized_a16w8(model: nn.Module, inputs: tuple[torch.Tensor, ...]): + exported = torch.export.export(model.eval(), inputs).module(check_guards=False) + quantizer = VgfQuantizer(_compile_spec_int16()) + quantizer.set_global(get_symmetric_a16w8_quantization_config()) + quantized = quantizer._quantize_with_submodules(exported, [inputs]) + return torch.export.export(quantized, inputs) + + def _run_pre_rewrite_passes(exported_program: torch.export.ExportedProgram): gm = exported_program.graph_module for pass_ in ( @@ -110,6 +124,62 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.conv(x) +class Conv2dBiasModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(4, 6, kernel_size=3, stride=1, padding=1, bias=True) + + def get_inputs(self) -> tuple[torch.Tensor]: + return (torch.randn(1, 4, 8, 8),) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + +class DepthwiseConv2dBiasModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(4, 4, kernel_size=3, padding=1, groups=4, bias=True) + + def get_inputs(self) -> tuple[torch.Tensor]: + return (torch.randn(1, 4, 8, 8),) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + +class Conv3dBiasModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv3d(3, 5, kernel_size=3, stride=1, padding=1, bias=True) + + def get_inputs(self) -> tuple[torch.Tensor]: + return (torch.randn(1, 3, 6, 6, 6),) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + +class TransposeConv2dBiasModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.ConvTranspose2d( + 3, + 4, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + bias=True, + ) + + def get_inputs(self) -> tuple[torch.Tensor]: + return (torch.randn(1, 3, 6, 6),) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + def _make_rewrite_pass( example_inputs: tuple[torch.Tensor, ...], dynamic_shapes: dict[int, object] | None = None, @@ -190,6 +260,48 @@ def test_rewrite_conv_vgf_quant_infers_quantized_bias_dtype_from_inputs() -> Non assert bias_nodes[0].meta["val"].dtype == torch.int32 +@pytest.mark.parametrize( + "module,target_op", + [ + (Conv2dBiasModule(), exir_ops.backend.tosa.CONV2D.default), + (DepthwiseConv2dBiasModule(), exir_ops.backend.tosa.DEPTHWISE_CONV2D.default), + (Conv3dBiasModule(), exir_ops.backend.tosa.CONV3D.default), + (TransposeConv2dBiasModule(), exir_ops.backend.tosa.TRANSPOSE_CONV2D.default), + ], +) +def test_rewrite_conv_int16_bias_lowers_to_single_tosa_conv( + module: ( + Conv2dBiasModule + | DepthwiseConv2dBiasModule + | Conv3dBiasModule + | TransposeConv2dBiasModule + ), + target_op, +) -> None: + exported_program = _export_quantized_a16w8(module, module.get_inputs()) + edge_program = to_edge( + exported_program, compile_config=EdgeCompileConfig(_check_ir_validity=False) + ).exported_program() + gm = _run_pre_rewrite_passes(edge_program) + + with TosaLoweringContext(_compile_spec_int16().tosa_spec): + result = RewriteConvPass(edge_program)(gm) + assert result is not None + gm = result.graph_module + + tosa_conv_nodes = [ + node + for node in gm.graph.nodes + if node.op == "call_function" and node.target == target_op + ] + assert len(tosa_conv_nodes) == 1 + assert all(node.target != exir_ops.edge.aten.add.Tensor for node in gm.graph.nodes) + + bias_node = tosa_conv_nodes[0].args[2] + assert isinstance(bias_node, torch.fx.Node) + assert bias_node.meta.get(TosaSpecialDtype.meta_key()) == TosaSpecialDtype.INT48 + + def test_rewrite_conv_dynamic_keeps_static_padding_when_symbolic_remainder_is_zero(): model = ConvModule() example_inputs = (torch.randn(1, 3, 9, 12),) diff --git a/backends/arm/test/passes/test_rewrite_max_pool2d_pass.py b/backends/arm/test/passes/test_rewrite_max_pool2d_pass.py index c77f315903f..4b770b3ee20 100644 --- a/backends/arm/test/passes/test_rewrite_max_pool2d_pass.py +++ b/backends/arm/test/passes/test_rewrite_max_pool2d_pass.py @@ -10,6 +10,8 @@ from executorch.backends.arm._passes.rewrite_max_pool2d_pass import RewriteMaxPool2dPass from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import PassPipeline +from executorch.backends.test.harness.stages import StageType +from executorch.exir.dialects._ops import ops as exir_ops input_t = Tuple[torch.Tensor] @@ -42,6 +44,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.max_pool2d(x, kernel_size=[2, 3]) +class MaxPool2dWithEmptyStride(torch.nn.Module): + def get_inputs(self) -> input_t: + return (torch.rand(1, 3, 8, 8),) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.max_pool2d(x, kernel_size=[2, 3], stride=[]) + + modules: Dict[str, ModuleWithInputs] = { "max_pool2d_with_stride": MaxPool2dWithStride(), "max_pool2d_without_stride": MaxPool2dWithoutStride(), @@ -67,3 +77,41 @@ def test_rewrite_max_pool2d_tosa(module: ModuleWithInputs) -> None: "run_method_and_compare_outputs" ) # Cannnot run aten graph with tosa dialect ops pipeline.run() + + +def _get_tosa_max_pool2d_node( + pipeline: PassPipeline[input_t], +) -> torch.fx.Node: + exported_program = pipeline.tester.get_artifact( + StageType.RUN_PASSES + ).exported_program() + graph_module = exported_program.graph_module + + tosa_nodes = [ + node + for node in graph_module.graph.nodes + if node.op == "call_function" + and node.target == exir_ops.backend.tosa.MAX_POOL2D.default + ] + assert len(tosa_nodes) == 1 + return tosa_nodes[0] + + +def test_rewrite_max_pool2d_tosa_empty_stride_uses_kernel_size() -> None: + module = MaxPool2dWithEmptyStride() + pipeline = PassPipeline[input_t]( + module, + module.get_inputs(), + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1, + }, + ops_after_pass={ + "executorch_exir_dialects_backend__ops_tosa_MAX_POOL2D_default": 1, + }, + pass_list=[RemoveGetItemPass, RewriteMaxPool2dPass], + ) + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() + + tosa_node = _get_tosa_max_pool2d_node(pipeline) + assert tosa_node.args[2] == [2, 3] diff --git a/backends/arm/tosa/backend.py b/backends/arm/tosa/backend.py index 0d1dfb4dfa1..9b62e081ad4 100644 --- a/backends/arm/tosa/backend.py +++ b/backends/arm/tosa/backend.py @@ -38,7 +38,6 @@ from executorch.backends.arm.tosa.mapping import TOSA_TENSOR_NAME_META from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.backend.compile_spec_schema import CompileSpec -from executorch.exir.dim_order_utils import get_memory_format from torch.export.exported_program import ExportedProgram from torch.fx import Graph, GraphModule, Node @@ -119,12 +118,8 @@ def _sort_key(t: Node) -> int: def _get_matching_fake_tensor(node: Node): - """Return a fake tensor with the same properties as node, but with - .dim_order() == node.meta["tosa_dim_order"] - """ - fake_tensor = node.meta["val"] - desired_dim_order = node.meta["tosa_dim_order"] - return fake_tensor.to(memory_format=get_memory_format(list(desired_dim_order))) + """Return the fake tensor of node.""" + return node.meta["val"] def arm_get_first_delegation_tag(graph_module) -> str: diff --git a/backends/arm/tosa/dialect/ops/avg_pool2d.py b/backends/arm/tosa/dialect/ops/avg_pool2d.py index 1a9192048a8..e05db45d3a4 100644 --- a/backends/arm/tosa/dialect/ops/avg_pool2d.py +++ b/backends/arm/tosa/dialect/ops/avg_pool2d.py @@ -78,15 +78,16 @@ def AVG_POOL2D( f"must be one of {acc_allowed}", op="AVG_POOL2D", ) - # Unpack dimensions and parameters; zero-points are not used for shape - n, c, h, w = x.shape + # Unpack dimensions and parameters in NHWC order; + # zero-points are not used for shape. + n, h, w, c = x.shape k_h, k_w = kernel s_h, s_w = stride - p_top, p_left, p_bot, p_right = pad + p_top, p_bottom, p_left, p_right = pad # Compute output spatial dimensions (floor division) - h_out = (h + p_top + p_left - k_h) // s_h + 1 - w_out = (w + p_bot + p_right - k_w) // s_w + 1 + h_out = (h + p_top + p_bottom - k_h) // s_h + 1 + w_out = (w + p_left + p_right - k_w) // s_w + 1 # Return a tensor with the computed shape and dtype - return torch.empty(size=[n, c, h_out, w_out], dtype=x.dtype) + return torch.empty(size=[n, h_out, w_out, c], dtype=x.dtype) diff --git a/backends/arm/tosa/dialect/ops/conv2d.py b/backends/arm/tosa/dialect/ops/conv2d.py index 2b991600994..841a1d90876 100644 --- a/backends/arm/tosa/dialect/ops/conv2d.py +++ b/backends/arm/tosa/dialect/ops/conv2d.py @@ -101,14 +101,14 @@ def CONV2D( torch_pad = [pad[0], pad[2]] N = x.shape[0] C_out = weight.shape[0] - H_in, W_in = x.shape[2:] + H_in, W_in = x.shape[1], x.shape[2] H_out = math.floor( - (H_in + 2 * torch_pad[0] - dilation[0] * (weight.shape[2] - 1) - 1) / stride[0] + (H_in + 2 * torch_pad[0] - dilation[0] * (weight.shape[1] - 1) - 1) / stride[0] + 1 ) W_out = math.floor( - (W_in + 2 * torch_pad[1] - dilation[1] * (weight.shape[3] - 1) - 1) / stride[1] + (W_in + 2 * torch_pad[1] - dilation[1] * (weight.shape[2] - 1) - 1) / stride[1] + 1 ) - output_shape = [N, C_out, H_out, W_out] + output_shape = [N, H_out, W_out, C_out] return torch.empty(size=output_shape, dtype=output_dtype) diff --git a/backends/arm/tosa/dialect/ops/conv3d.py b/backends/arm/tosa/dialect/ops/conv3d.py index bf316c3d52a..67ceb0596c6 100644 --- a/backends/arm/tosa/dialect/ops/conv3d.py +++ b/backends/arm/tosa/dialect/ops/conv3d.py @@ -54,18 +54,18 @@ def CONV3D( torch_pad = [pad[0], pad[2], pad[4]] N = x.shape[0] C_out = weight.shape[0] - D_in, H_in, W_in = x.shape[2:] + D_in, H_in, W_in = x.shape[1], x.shape[2], x.shape[3] D_out = math.floor( - (D_in + 2 * torch_pad[0] - dilation[0] * (weight.shape[2] - 1) - 1) / stride[0] + (D_in + 2 * torch_pad[0] - dilation[0] * (weight.shape[1] - 1) - 1) / stride[0] + 1 ) H_out = math.floor( - (H_in + 2 * torch_pad[1] - dilation[1] * (weight.shape[3] - 1) - 1) / stride[1] + (H_in + 2 * torch_pad[1] - dilation[1] * (weight.shape[2] - 1) - 1) / stride[1] + 1 ) W_out = math.floor( - (W_in + 2 * torch_pad[2] - dilation[2] * (weight.shape[4] - 1) - 1) / stride[2] + (W_in + 2 * torch_pad[2] - dilation[2] * (weight.shape[3] - 1) - 1) / stride[2] + 1 ) - output_shape = [N, C_out, D_out, H_out, W_out] + output_shape = [N, D_out, H_out, W_out, C_out] return torch.empty(size=output_shape, dtype=output_dtype) diff --git a/backends/arm/tosa/dialect/ops/depthwise_conv2d.py b/backends/arm/tosa/dialect/ops/depthwise_conv2d.py index 7d8d5f9edc8..ae864f29d62 100644 --- a/backends/arm/tosa/dialect/ops/depthwise_conv2d.py +++ b/backends/arm/tosa/dialect/ops/depthwise_conv2d.py @@ -39,15 +39,16 @@ def DEPTHWISE_CONV2D( ) torch_pad = [pad[0], pad[2]] - kernel_h, kernel_w = weight.shape[0], weight.shape[2] - C_out = weight.shape[1] * x.shape[1] + # Weight format is [KH, KW, IC, M], where C_out = IC * M. + kernel_h, kernel_w = weight.shape[0], weight.shape[1] + C_out = weight.shape[2] * weight.shape[3] N = x.shape[0] - H_in, W_in = x.shape[2:] + H_in, W_in = x.shape[1], x.shape[2] H_out = math.floor( (H_in + 2 * torch_pad[0] - dilation[0] * (kernel_h - 1) - 1) / stride[0] + 1 ) W_out = math.floor( (W_in + 2 * torch_pad[1] - dilation[1] * (kernel_w - 1) - 1) / stride[1] + 1 ) - output_shape = [N, C_out, H_out, W_out] + output_shape = [N, H_out, W_out, C_out] return torch.empty(size=output_shape, dtype=output_dtype) diff --git a/backends/arm/tosa/dialect/ops/max_pool2d.py b/backends/arm/tosa/dialect/ops/max_pool2d.py index a0559937719..161a74ef170 100644 --- a/backends/arm/tosa/dialect/ops/max_pool2d.py +++ b/backends/arm/tosa/dialect/ops/max_pool2d.py @@ -64,7 +64,7 @@ def MAX_POOL2D( op="MAX_POOL2D", ) - n, c, h, w = x.shape + n, h, w, c = x.shape k_h, k_w = kernel s_h, s_w = stride # TOSA MAX_POOL2D pad order is [top, bottom, left, right] @@ -72,4 +72,4 @@ def MAX_POOL2D( h_out = (h + p_top + p_bot - k_h) // s_h + 1 w_out = (w + p_left + p_right - k_w) // s_w + 1 - return torch.empty(size=[n, c, h_out, w_out], dtype=x.dtype) + return torch.empty(size=[n, h_out, w_out, c], dtype=x.dtype) diff --git a/backends/arm/tosa/dialect/ops/resize.py b/backends/arm/tosa/dialect/ops/resize.py index c4e720cd849..f8b078c8690 100644 --- a/backends/arm/tosa/dialect/ops/resize.py +++ b/backends/arm/tosa/dialect/ops/resize.py @@ -86,7 +86,7 @@ def RESIZE( scale_y_n, scale_y_d, scale_x_n, scale_x_d = scale offset_y, offset_x = offset border_y, border_x = border - H, W = input_shape[2], input_shape[3] + H, W = input_shape[1], input_shape[2] # RESIZE first upscales the input by an integer value, to "upscale space". H_upscaled = (H - 1) * scale_y_n # offset and border are provided in this scale, therefore adjust for these while in this space. @@ -97,6 +97,8 @@ def RESIZE( W_upscaled = (W - 1) * scale_x_n W_shifted = W_upscaled - offset_x + border_x OW = (W_shifted // scale_x_d) + 1 - fake_aten_tensor = torch.empty(size=(*input_shape[:2], OH, OW), dtype=output_dtype) + fake_aten_tensor = torch.empty( + size=(input_shape[0], OH, OW, input_shape[3]), dtype=output_dtype + ) return fake_aten_tensor diff --git a/backends/arm/tosa/dialect/ops/transpose_conv2d.py b/backends/arm/tosa/dialect/ops/transpose_conv2d.py index 9a85b6e379c..9df0245dd14 100644 --- a/backends/arm/tosa/dialect/ops/transpose_conv2d.py +++ b/backends/arm/tosa/dialect/ops/transpose_conv2d.py @@ -46,12 +46,13 @@ def TRANSPOSE_CONV2D( ) N = x.shape[0] - C_out = weight.shape[1] - H_in, W_in = x.shape[2:] - kernel_h = weight.shape[2] - kernel_w = weight.shape[3] + # Weight format is [OC, KH, KW, IC]. + C_out = weight.shape[0] + H_in, W_in = x.shape[1], x.shape[2] + kernel_h = weight.shape[1] + kernel_w = weight.shape[2] H_out = (H_in - 1) * stride[0] + out_pad[0] + out_pad[1] + kernel_h W_out = (W_in - 1) * stride[1] + out_pad[2] + out_pad[3] + kernel_w - output_shape = [N, C_out, H_out, W_out] + output_shape = [N, H_out, W_out, C_out] return torch.empty(size=output_shape, dtype=output_dtype) diff --git a/backends/arm/tosa/mapping.py b/backends/arm/tosa/mapping.py index 6c7d1532218..5d95399e5d5 100644 --- a/backends/arm/tosa/mapping.py +++ b/backends/arm/tosa/mapping.py @@ -132,7 +132,8 @@ def extract_tensor_meta(meta): special_dtype = meta.get(TosaSpecialDtype.meta_key()) if special_dtype == TosaSpecialDtype.SHAPE: shape_len = len(meta["val"]) - return (ts.DType.SHAPE, (shape_len,), meta["tosa_dim_order"]) + dim_order = tuple(range(shape_len)) + return (ts.DType.SHAPE, (shape_len,), dim_order) if meta.get("val") is None: raise ValueError("Expected node.meta['val'] to be set to a FakeTensor") @@ -153,10 +154,7 @@ def extract_tensor_meta(meta): dtype = map_dtype(val.dtype) shape = tuple(val.size()) - if meta.get("tosa_dim_order") is not None: - dim_order = meta["tosa_dim_order"] - else: - dim_order = tuple(range(len(shape))) + dim_order = tuple(range(len(shape))) return (dtype, shape, dim_order) diff --git a/backends/arm/tosa/utils.py b/backends/arm/tosa/utils.py index 602a9548791..6fddff42f9a 100644 --- a/backends/arm/tosa/utils.py +++ b/backends/arm/tosa/utils.py @@ -164,23 +164,20 @@ def build_reshape_tosa( def tosa_shape(shape, dim_order): - """Reorder a shape tuple into TOSA layout while resolving symints. + """Convert a shape tuple to a TOSA shape list while resolving symints. Args: shape (Sequence[int | torch.SymInt]): Original tensor shape, possibly containing ``torch.SymInt``. - dim_order (Sequence[int]): Desired dimension order for the output - shape. + dim_order (Sequence[int]): Kept for API compatibility. Shape lowering + now uses the original tensor order directly. Returns: - list[int]: List containing the reordered dimensions where symbolic + list[int]: List containing dimensions in original order where symbolic values become ``-1``. """ - reordered = tuple([shape[dim] for dim in dim_order]) - # Dynamic shapes in executorch are represented with torch.SymInt objects in the shapes, - # in TOSA we do not have this concept and instead use -1. - removed_symints = tuple( - [-1 if isinstance(d, torch.SymInt) else d for d in reordered] - ) + # Dynamic shapes in executorch are represented with torch.SymInt objects in + # the shapes, in TOSA we do not have this concept and instead use -1. + removed_symints = tuple([-1 if isinstance(d, torch.SymInt) else d for d in shape]) return list(removed_symints) diff --git a/backends/transforms/remove_permutes_around_elementwise_ops.py b/backends/transforms/remove_permutes_around_elementwise_ops.py index dd28b13045d..9892fdb28f0 100644 --- a/backends/transforms/remove_permutes_around_elementwise_ops.py +++ b/backends/transforms/remove_permutes_around_elementwise_ops.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -59,6 +60,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: op="call_function", target=exir_ops.edge.aten.permute_copy.default ): start_permute = self.get_permutation(node) + if start_permute is None: + continue # Expected end permutation for the subgraph. end_permute = [start_permute.index(i) for i in range(len(start_permute))] @@ -264,9 +267,22 @@ def update_slice_copy(self, node: torch.fx.Node, start_permute: list[int]) -> No else: node.update_kwarg("dim", start_permute[cast(int, node.kwargs["dim"])]) - def get_permutation(self, permute_node: torch.fx.Node) -> list[int]: + def get_permutation(self, permute_node: torch.fx.Node) -> list[int] | None: assert permute_node.target == exir_ops.edge.aten.permute_copy.default + raw_permute: list[int] if len(permute_node.args) >= 2: - return cast(list[int], permute_node.args[1]) - assert "dim" in permute_node.kwargs - return cast(list[int], permute_node.kwargs["dim"]) + raw_permute = list(cast(list[int], permute_node.args[1])) + else: + raw_dims = permute_node.kwargs.get("dims", permute_node.kwargs.get("dim")) + if raw_dims is None: + return None + raw_permute = list(cast(list[int], raw_dims)) + + rank = len(raw_permute) + normalized_permute = [d + rank if d < 0 else d for d in raw_permute] + + if not all(0 <= d < rank for d in normalized_permute): + return None + if sorted(normalized_permute) != list(range(rank)): + return None + return normalized_permute