1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
@@ -116,6 +116,7 @@
 from .fuse_quantized_activation_pass import FuseQuantizedActivationPass  # noqa
 from .fuse_view_copy_transform_pass import FuseViewCopyTransformPass  # noqa
 from .insert_const_shapes import InsertConstShapesPass  # noqa
+from .insert_data_layout_casts_pass import InsertDataLayoutCastsPass  # noqa
 from .insert_dynamic_padding import InsertDynamicPaddingPass  # noqa
 from .insert_int32_casts_after_int64_placeholders import (  # noqa
     InsertInt32CastsAfterInt64PlaceholdersPass,
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
@@ -109,6 +109,7 @@
     FuseViewCopyTransformPass,
     InsertConstShapesPass,
     InsertControlFlowRescalesPass,
+    InsertDataLayoutCastsPass,
     InsertInt32CastsAfterInt64PlaceholdersPass,
     InsertRescaleInt32Pass,
     InsertRescalePass,
@@ -545,6 +546,7 @@ def _tosa_pipeline(
                 ToTosaMemoryFormatPass(exported_program),
                 RemoveNoopPass(),
                 InsertRescalePass(),
+                InsertDataLayoutCastsPass(),
             ]
         )

129 changes: 129 additions & 0 deletions backends/arm/_passes/insert_data_layout_casts_pass.py
@@ -0,0 +1,129 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Set, Type

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm.tosa.specification import get_context_spec
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, NodeMetadata


class InsertDataLayoutCastsPass(ArmPass):
    """Insert casts around data layout operators when their dtype is not
    supported by the active TOSA specification.

    This pass targets operators that lower to TOSA data layout operators:
    CONCAT, PAD, RESHAPE, REVERSE, SLICE, TILE, and TRANSPOSE.

    Example:
        Before pass:
            y = transpose(x)  # x.data.dtype == torch.int32
        After pass:
            xfp32 = _to_dim_order_copy(x, dtype=torch.float32)
            yfp32 = transpose(xfp32)
            y = _to_dim_order_copy(yfp32, dtype=torch.int32)

    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    _cast_op = exir_ops.edge.dim_order_ops._to_dim_order_copy.default

    _concat_ops = {
        exir_ops.edge.aten.cat.default,
        exir_ops.edge.aten.concatenate.default,
    }
    _single_input_ops = {
        exir_ops.backend.tosa.TRANSPOSE.default,
        exir_ops.edge.aten.constant_pad_nd.default,
        exir_ops.edge.aten.view_copy.default,
        exir_ops.edge.aten.repeat.default,
        exir_ops.edge.aten.permute_copy.default,
Comment on lines +41 to +45
Copilot AI (Apr 20, 2026):
InsertDataLayoutCastsPass is intended to cover PAD and SLICE (per the docstring and PR description), but targeted_ops currently only includes the edge forms (aten.constant_pad_nd, aten.slice_copy), while the main Arm TOSA pipeline rewrites these to exir_ops.backend.tosa.PAD.default and exir_ops.backend.tosa.SLICE.default (via RewritePadPass/RewriteSlicePass) before this pass runs. As a result, casts will not be inserted around the PAD/SLICE nodes that actually reach serialization in the real pipeline. Include the backend TOSA PAD/SLICE targets in targeted_ops (and ideally add a unit test that runs RewritePadPass/RewriteSlicePass before this pass to prevent regressions); a sketch of such a test follows this file's diff.

Suggested change:

         exir_ops.backend.tosa.TRANSPOSE.default,
+        exir_ops.backend.tosa.PAD.default,
         exir_ops.edge.aten.constant_pad_nd.default,
         exir_ops.edge.aten.view_copy.default,
         exir_ops.edge.aten.repeat.default,
         exir_ops.edge.aten.permute_copy.default,
+        exir_ops.backend.tosa.SLICE.default,
        exir_ops.edge.aten.slice_copy.Tensor,
        exir_ops.edge.aten.flip.default,
    }
    targeted_ops = _concat_ops | _single_input_ops

    _fp_to_int_map = {
        torch.float16: torch.int16,
        torch.bfloat16: torch.int16,
        torch.float32: torch.int32,
    }

    _int_to_fp_map = {
        # int8 -> float16 doubles the size after casting, but int8 is very
        # unlikely to occur in practice: in PRO-FP it is only ever used by
        # LOGICAL_SHIFT and CAST/RESCALE ops.
        torch.int8: torch.float16,
        torch.int16: torch.float16,
        torch.int32: torch.float32,
    }

    def call_operator(self, op, args, kwargs, meta):
        if op not in self.targeted_ops:
            return super().call_operator(op, args, kwargs, meta)

        if op in self._concat_ops:
            # Cast to the largest input dtype
            dtypes = [arg.data.dtype for arg in args[0]]
            dtype_sizes = [dtype.itemsize for dtype in dtypes]
            dtype = dtypes[dtype_sizes.index(max(dtype_sizes))]
        else:
            dtype = args[0].data.dtype

        spec = get_context_spec()
        dtype_is_integer = not dtype.is_floating_point and dtype != torch.bool
        if dtype_is_integer and not spec.support_integer():
            supported_dtype = self._int_to_fp_map.get(dtype, None)
        elif dtype.is_floating_point and not spec.support_float():
            supported_dtype = self._fp_to_int_map.get(dtype, None)
        else:
            return super().call_operator(op, args, kwargs, meta)

        # Unlike the other targeted ops, CONCATENATE does not support int16
        # without the INT16 extension
        if (
            op in self._concat_ops
            and supported_dtype == torch.int16
            and not spec.support_extension("int16")
        ):
            supported_dtype = None

        if supported_dtype is None:
            raise TypeError(
                f"Data type {dtype} of operator {op} is not supported by"
                f" {spec}, and casting is currently not supported by {self.__class__.__name__}."
            )

        if op in self._concat_ops:
            x_casted = []
            for arg in args[0]:
                x_casted.append(
                    super().call_operator(
                        self._cast_op,
                        (arg,),
                        {"dtype": supported_dtype},
                        NodeMetadata(arg.node.meta),
                        updated=True,
                    )
                )
            y_casted = super().call_operator(
                op, (x_casted, *args[1:]), kwargs, meta, updated=True
            )
        else:
            x_casted = super().call_operator(
                self._cast_op,
                (args[0],),
                {"dtype": supported_dtype},
                NodeMetadata(args[0].node.meta),
                updated=True,
            )
            y_casted = super().call_operator(
                op, (x_casted, *args[1:]), kwargs, meta, updated=True
            )

        y = super().call_operator(
            self._cast_op, (y_casted,), {"dtype": dtype}, meta, updated=True
        )
        return y
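
As a follow-up to the review comment above, here is a minimal regression-test sketch. It assumes executorch is installed, that targeted_ops remains a class attribute as shown in this diff, and that the backend TOSA op handles match the reviewer's suggestion; it is not part of the PR itself.

# Hedged sketch: this fails until the backend TOSA PAD/SLICE targets from
# the review suggestion are added to targeted_ops.
from executorch.backends.arm._passes.insert_data_layout_casts_pass import (
    InsertDataLayoutCastsPass,
)
from executorch.exir.dialects._ops import ops as exir_ops


def test_targeted_ops_cover_rewritten_pad_and_slice():
    # RewritePadPass/RewriteSlicePass replace the edge ops with these
    # backend targets before this pass runs, so they must be targeted too.
    targeted = InsertDataLayoutCastsPass.targeted_ops
    assert exir_ops.backend.tosa.PAD.default in targeted
    assert exir_ops.backend.tosa.SLICE.default in targeted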
19 changes: 9 additions & 10 deletions backends/arm/operators/op_cat.py
@@ -35,16 +35,15 @@ def define_node(
         inputs: List[TosaArg],
         output: TosaArg,
     ) -> None:
-        supported_dtypes = [
-            ts.DType.BOOL,
-            ts.DType.INT8,
-            ts.DType.INT32,
-            ts.DType.FP16,
-            ts.DType.FP32,
-            ts.DType.BF16,
-        ]
-        if self.tosa_spec.support_extension("int16"):
-            supported_dtypes.append(ts.DType.INT16)
+        supported_dtypes = [ts.DType.BOOL]
+        if self.tosa_spec.support_integer():
+            supported_dtypes.extend([ts.DType.INT8, ts.DType.INT32])
+        if self.tosa_spec.support_extension("int16"):
+            supported_dtypes.append(ts.DType.INT16)
+        if self.tosa_spec.support_float():
+            supported_dtypes.extend([ts.DType.FP16, ts.DType.FP32])
+        if self.tosa_spec.support_extension("bf16"):
+            supported_dtypes.append(ts.DType.BF16)
         validate_num_inputs(self.target, inputs, [1, 2])
         input_tosa_args = [TosaArg(arg, self.tosa_spec) for arg in inputs[0].special]
         validate_same_dtype(self.target, [*input_tosa_args, output], ts)
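
The same capability-gated pattern (seed the list with BOOL, then extend it according to the active spec's profiles and extensions) recurs in the permute, repeat, pad, slice, transpose, and view visitors below. A standalone sketch of the gating logic follows; FakeSpec is a hypothetical stand-in, not the backend's real tosa_spec object.

from dataclasses import dataclass


@dataclass
class FakeSpec:
    # Hypothetical stand-in for the backend's tosa_spec, which exposes
    # support_integer/support_float/support_extension.
    integer: bool = True
    fp: bool = True
    extensions: frozenset = frozenset()

    def support_integer(self) -> bool:
        return self.integer

    def support_float(self) -> bool:
        return self.fp

    def support_extension(self, name: str) -> bool:
        return name in self.extensions


def supported_dtypes(spec) -> list:
    # Mirrors the visitors' gating: BOOL is always allowed; everything
    # else depends on the active profile and extensions.
    dtypes = ["BOOL"]
    if spec.support_integer():
        dtypes += ["INT8", "INT16", "INT32"]
    if spec.support_float():
        dtypes += ["FP16", "FP32"]
    if spec.support_extension("bf16"):
        dtypes.append("BF16")
    return dtypes


# An FP-only spec rejects INT32, which is exactly why the new
# InsertDataLayoutCastsPass casts int32 inputs to float32 first.
print(supported_dtypes(FakeSpec(integer=False)))  # ['BOOL', 'FP16', 'FP32']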
18 changes: 9 additions & 9 deletions backends/arm/operators/op_permute.py
@@ -110,20 +110,20 @@ def define_node(
         inputs: List[TosaArg],
         output: TosaArg,
     ) -> None:
+        supported_dtypes = [ts.DType.BOOL]
+        if self.tosa_spec.support_integer():
+            supported_dtypes.extend([ts.DType.INT8, ts.DType.INT16, ts.DType.INT32])
+        if self.tosa_spec.support_float():
+            supported_dtypes.extend([ts.DType.FP16, ts.DType.FP32])
+        if self.tosa_spec.support_extension("bf16"):
+            supported_dtypes.append(ts.DType.BF16)
+
         validate_num_inputs(self.target, inputs, 2)
         validate_same_dtype(self.target, [inputs[0], output], ts)
         validate_valid_dtype(
             self.target,
             [inputs[0], output],
-            [
-                ts.DType.BOOL,
-                ts.DType.INT8,
-                ts.DType.INT16,
-                ts.DType.INT32,
-                ts.DType.FP16,
-                ts.DType.FP32,
-                ts.DType.BF16,
-            ],
+            supported_dtypes,
             self.tosa_spec,
         )

18 changes: 9 additions & 9 deletions backends/arm/operators/op_repeat.py
@@ -35,20 +35,20 @@ def define_node(
         inputs: list[TosaArg],
         output: TosaArg,
     ) -> None:
+        supported_dtypes = [ts.DType.BOOL]
+        if self.tosa_spec.support_integer():
+            supported_dtypes.extend([ts.DType.INT8, ts.DType.INT16, ts.DType.INT32])
+        if self.tosa_spec.support_float():
+            supported_dtypes.extend([ts.DType.FP16, ts.DType.FP32])
+        if self.tosa_spec.support_extension("bf16"):
+            supported_dtypes.append(ts.DType.BF16)
+
         validate_num_inputs(self.target, inputs, 2)
         validate_same_dtype(self.target, [inputs[0], output], ts)
         validate_valid_dtype(
             self.target,
             [inputs[0], output],
-            [
-                ts.DType.BOOL,
-                ts.DType.INT8,
-                ts.DType.INT16,
-                ts.DType.INT32,
-                ts.DType.FP16,
-                ts.DType.FP32,
-                ts.DType.BF16,
-            ],
+            supported_dtypes,
             self.tosa_spec,
         )

22 changes: 22 additions & 0 deletions backends/arm/operators/op_tosa_pad.py
@@ -13,6 +13,11 @@
     NodeVisitor,
     register_node_visitor,
 )
+from executorch.backends.arm.operators.operator_validation_utils import (
+    validate_num_inputs,
+    validate_same_dtype,
+    validate_valid_dtype,
+)
 from executorch.backends.arm.tosa.mapping import TosaArg


@@ -29,6 +34,23 @@ def define_node(
         inputs: List[TosaArg],
         output: TosaArg,
     ) -> None:
+        supported_dtypes = [ts.DType.BOOL]
+        if self.tosa_spec.support_integer():
+            supported_dtypes.extend([ts.DType.INT8, ts.DType.INT16, ts.DType.INT32])
+        if self.tosa_spec.support_float():
+            supported_dtypes.extend([ts.DType.FP16, ts.DType.FP32])
+        if self.tosa_spec.support_extension("bf16"):
+            supported_dtypes.append(ts.DType.BF16)
+
+        validate_num_inputs(self.target, inputs, 2)
+        validate_same_dtype(self.target, [inputs[0], output], ts)
+        validate_valid_dtype(
+            self.target,
+            [inputs[0], output],
+            supported_dtypes,
+            self.tosa_spec,
+        )
+
         pad_const = tosa_graph.addConst(
             [1],
             output.dtype,
21 changes: 21 additions & 0 deletions backends/arm/operators/op_tosa_slice.py
@@ -12,6 +12,11 @@
     NodeVisitor,
     register_node_visitor,
 )
+from executorch.backends.arm.operators.operator_validation_utils import (
+    validate_num_inputs,
+    validate_same_dtype,
+    validate_valid_dtype,
+)
 from executorch.backends.arm.tosa.mapping import TosaArg
from torch.fx import Node

@@ -30,6 +35,22 @@ def define_node(
         inputs: List[TosaArg],
         output: TosaArg,
     ) -> None:
+        supported_dtypes = [ts.DType.BOOL]
+        if self.tosa_spec.support_integer():
+            supported_dtypes.extend([ts.DType.INT8, ts.DType.INT16, ts.DType.INT32])
+        if self.tosa_spec.support_float():
+            supported_dtypes.extend([ts.DType.FP16, ts.DType.FP32])
+        if self.tosa_spec.support_extension("bf16"):
+            supported_dtypes.append(ts.DType.BF16)
+
+        validate_num_inputs(self.target, inputs, 3)
+        validate_same_dtype(self.target, [inputs[0], output], ts)
+        validate_valid_dtype(
+            self.target,
+            [inputs[0], output],
+            supported_dtypes,
+            self.tosa_spec,
+        )

         input_node, starts, sizes = inputs
         attr = ts.TosaSerializerAttribute()
18 changes: 9 additions & 9 deletions backends/arm/operators/op_tosa_transpose.py
@@ -40,20 +40,20 @@ def define_node(
         inputs: List[TosaArg],
         output: TosaArg,
     ) -> None:
+        supported_dtypes = [ts.DType.BOOL]
+        if self.tosa_spec.support_integer():
+            supported_dtypes.extend([ts.DType.INT8, ts.DType.INT16, ts.DType.INT32])
+        if self.tosa_spec.support_float():
+            supported_dtypes.extend([ts.DType.FP16, ts.DType.FP32])
+        if self.tosa_spec.support_extension("bf16"):
+            supported_dtypes.append(ts.DType.BF16)
+
         validate_num_inputs(self.target, inputs, 2)
         validate_same_dtype(self.target, [inputs[0], output], ts)
         validate_valid_dtype(
             self.target,
             [inputs[0], output],
-            [
-                ts.DType.BOOL,
-                ts.DType.INT8,
-                ts.DType.INT16,
-                ts.DType.INT32,
-                ts.DType.FP16,
-                ts.DType.FP32,
-                ts.DType.BF16,
-            ],
+            supported_dtypes,
             self.tosa_spec,
         )

18 changes: 9 additions & 9 deletions backends/arm/operators/op_view.py
@@ -35,20 +35,20 @@ def define_node(
         inputs: List[TosaArg],
         output: TosaArg,
     ) -> None:
+        supported_dtypes = [ts.DType.BOOL]
+        if self.tosa_spec.support_integer():
+            supported_dtypes.extend([ts.DType.INT8, ts.DType.INT16, ts.DType.INT32])
+        if self.tosa_spec.support_float():
+            supported_dtypes.extend([ts.DType.FP16, ts.DType.FP32])
+        if self.tosa_spec.support_extension("bf16"):
+            supported_dtypes.append(ts.DType.BF16)
+
         validate_num_inputs(self.target, inputs, 2)
         validate_same_dtype(self.target, [inputs[0], output], ts)
         validate_valid_dtype(
             self.target,
             [inputs[0], output],
-            [
-                ts.DType.INT8,
-                ts.DType.INT16,
-                ts.DType.INT32,
-                ts.DType.FP16,
-                ts.DType.FP32,
-                ts.DType.BF16,
-                ts.DType.BOOL,
-            ],
+            supported_dtypes,
             self.tosa_spec,
         )
