Skip to content

Commit 063f9c9

Browse files
Arm backend: Ensure data layout ops dtype validity (#18540)
Previously dtypes were not checked per tosa spec, leading to data layout ops (TRANSPOSE, RESHAPE...) sometimes being inserted as INT inf PRO-FP, which is invalid. This patch adds checks and a new pass inserting dtype casts before/after all data layout ops to ensure validity. Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
1 parent 490ec5c commit 063f9c9

11 files changed

Lines changed: 331 additions & 46 deletions

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@
116116
from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa
117117
from .fuse_view_copy_transform_pass import FuseViewCopyTransformPass # noqa
118118
from .insert_const_shapes import InsertConstShapesPass # noqa
119+
from .insert_data_layout_casts_pass import InsertDataLayoutCastsPass # noqa
119120
from .insert_dynamic_padding import InsertDynamicPaddingPass # noqa
120121
from .insert_int32_casts_after_int64_placeholders import ( # noqa
121122
InsertInt32CastsAfterInt64PlaceholdersPass,

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@
109109
FuseViewCopyTransformPass,
110110
InsertConstShapesPass,
111111
InsertControlFlowRescalesPass,
112+
InsertDataLayoutCastsPass,
112113
InsertInt32CastsAfterInt64PlaceholdersPass,
113114
InsertRescaleInt32Pass,
114115
InsertRescalePass,
@@ -545,6 +546,7 @@ def _tosa_pipeline(
545546
ToTosaMemoryFormatPass(exported_program),
546547
RemoveNoopPass(),
547548
InsertRescalePass(),
549+
InsertDataLayoutCastsPass(),
548550
]
549551
)
550552

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Set, Type
7+
8+
import torch
9+
from executorch.backends.arm._passes import ArmPass
10+
from executorch.backends.arm.tosa.specification import get_context_spec
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.pass_base import ExportPass, NodeMetadata
13+
14+
15+
class InsertDataLayoutCastsPass(ArmPass):
    """Insert casts around data layout operators when their dtype is not
    supported by the active TOSA specification.

    This pass targets operators that lower to TOSA data layout operators:
    CONCAT, PAD, RESHAPE, REVERSE, SLICE, TILE, and TRANSPOSE.

    Example:
        Before pass:
            y = transpose(x) # x.data.dtype == torch.int32
        After pass:
            xfp32 = _to_dim_order_copy(x, dtype=torch.float32)
            yfp32 = transpose(xfp32)
            y = _to_dim_order_copy(yfp32, dtype=torch.int32)

    """

    # No other passes need to run after this one.
    _passes_required_after: Set[Type[ExportPass]] = set()

    # Operator used to materialize the dtype casts inserted by this pass.
    _cast_op = exir_ops.edge.dim_order_ops._to_dim_order_copy.default

    # Multi-tensor-input ops: every tensor in the input list must be cast.
    _concat_ops = {
        exir_ops.edge.aten.cat.default,
        exir_ops.edge.aten.concatenate.default,
    }
    # Single-tensor-input ops that lower to TOSA data layout operators.
    _single_input_ops = {
        exir_ops.backend.tosa.TRANSPOSE.default,
        exir_ops.edge.aten.constant_pad_nd.default,
        exir_ops.edge.aten.view_copy.default,
        exir_ops.edge.aten.repeat.default,
        exir_ops.edge.aten.permute_copy.default,
        exir_ops.edge.aten.slice_copy.Tensor,
        exir_ops.edge.aten.flip.default,
    }
    targeted_ops = _concat_ops | _single_input_ops

    # Fallback dtype when the spec lacks FP support: cast to an integer
    # dtype of the same size so the data layout op stays bit-preserving.
    _fp_to_int_map = {
        torch.float16: torch.int16,
        torch.bfloat16: torch.int16,
        torch.float32: torch.int32,
    }

    # Fallback dtype when the spec lacks INT support.
    _int_to_fp_map = {
        torch.int8: torch.float16,  # This doubles the size after casting, but is very unlikely to occur in practice since int8 is only ever used by LOGICAL_SHIFT and CAST/RESCALE ops in PRO-FP.
        torch.int16: torch.float16,
        torch.int32: torch.float32,
    }

    def call_operator(self, op, args, kwargs, meta):
        """Wrap unsupported-dtype data layout ops in cast-in/cast-out pairs.

        Non-targeted ops, and targeted ops whose dtype is already valid for
        the active spec, are forwarded unchanged. Otherwise the input(s) are
        cast to a supported dtype, the op is emitted on the cast values, and
        the result is cast back to the original dtype.

        Raises:
            TypeError: if no supported cast target exists for the dtype
                under the active TOSA specification.
        """
        if op not in self.targeted_ops:
            return super().call_operator(op, args, kwargs, meta)

        if op in self._concat_ops:
            # Cast to largest dtype
            dtypes = [arg.data.dtype for arg in args[0]]
            dtype_sizes = [dtype.itemsize for dtype in dtypes]
            dtype = dtypes[dtype_sizes.index(max(dtype_sizes))]
        else:
            dtype = args[0].data.dtype

        spec = get_context_spec()
        # bool is excluded: it is neither the INT nor FP profile's concern here.
        dtype_is_integer = not dtype.is_floating_point and dtype != torch.bool
        if dtype_is_integer and not spec.support_integer():
            supported_dtype = self._int_to_fp_map.get(dtype, None)
        elif dtype.is_floating_point and not spec.support_float():
            supported_dtype = self._fp_to_int_map.get(dtype, None)
        else:
            # Dtype already valid for the active spec; nothing to insert.
            return super().call_operator(op, args, kwargs, meta)

        # CONCATENATE does not support int16 w/o INT16 extension like other ops
        if (
            op in self._concat_ops
            and supported_dtype == torch.int16
            and not spec.support_extension("int16")
        ):
            supported_dtype = None

        if supported_dtype is None:
            raise TypeError(
                f"Data type {dtype} of operator {op} is not supported by"
                f" {spec}, and casting is currently not supported by {self.__class__.__name__}."
            )

        if op in self._concat_ops:
            # Cast each list element individually, then emit the concat on
            # the cast values.
            x_casted = []
            for arg in args[0]:
                x_casted.append(
                    super().call_operator(
                        self._cast_op,
                        (arg,),
                        {"dtype": supported_dtype},
                        NodeMetadata(arg.node.meta),
                        updated=True,
                    )
                )
            y_casted = super().call_operator(
                op, (x_casted, *args[1:]), kwargs, meta, updated=True
            )

        else:
            x_casted = super().call_operator(
                self._cast_op,
                (args[0],),
                {"dtype": supported_dtype},
                NodeMetadata(args[0].node.meta),
                updated=True,
            )
            y_casted = super().call_operator(
                op, (x_casted, *args[1:]), kwargs, meta, updated=True
            )

        # Cast the result back to the original (largest, for concat) dtype so
        # downstream consumers see an unchanged interface.
        y = super().call_operator(
            self._cast_op, (y_casted,), {"dtype": dtype}, meta, updated=True
        )
        return y

backends/arm/operators/op_cat.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,15 @@ def define_node(
3535
inputs: List[TosaArg],
3636
output: TosaArg,
3737
) -> None:
38-
supported_dtypes = [
39-
ts.DType.BOOL,
40-
ts.DType.INT8,
41-
ts.DType.INT32,
42-
ts.DType.FP16,
43-
ts.DType.FP32,
44-
ts.DType.BF16,
45-
]
46-
if self.tosa_spec.support_extension("int16"):
47-
supported_dtypes.append(ts.DType.INT16)
38+
supported_dtypes = [ts.DType.BOOL]
39+
if self.tosa_spec.support_integer():
40+
supported_dtypes.extend([ts.DType.INT8, ts.DType.INT32])
41+
if self.tosa_spec.support_extension("int16"):
42+
supported_dtypes.append(ts.DType.INT16)
43+
if self.tosa_spec.support_float():
44+
supported_dtypes.extend([ts.DType.FP16, ts.DType.FP32])
45+
if self.tosa_spec.support_extension("bf16"):
46+
supported_dtypes.append(ts.DType.BF16)
4847
validate_num_inputs(self.target, inputs, [1, 2])
4948
input_tosa_args = [TosaArg(arg, self.tosa_spec) for arg in inputs[0].special]
5049
validate_same_dtype(self.target, [*input_tosa_args, output], ts)

backends/arm/operators/op_permute.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -110,20 +110,20 @@ def define_node(
110110
inputs: List[TosaArg],
111111
output: TosaArg,
112112
) -> None:
113+
supported_dtypes = [ts.DType.BOOL]
114+
if self.tosa_spec.support_integer():
115+
supported_dtypes.extend([ts.DType.INT8, ts.DType.INT16, ts.DType.INT32])
116+
if self.tosa_spec.support_float():
117+
supported_dtypes.extend([ts.DType.FP16, ts.DType.FP32])
118+
if self.tosa_spec.support_extension("bf16"):
119+
supported_dtypes.append(ts.DType.BF16)
120+
113121
validate_num_inputs(self.target, inputs, 2)
114122
validate_same_dtype(self.target, [inputs[0], output], ts)
115123
validate_valid_dtype(
116124
self.target,
117125
[inputs[0], output],
118-
[
119-
ts.DType.BOOL,
120-
ts.DType.INT8,
121-
ts.DType.INT16,
122-
ts.DType.INT32,
123-
ts.DType.FP16,
124-
ts.DType.FP32,
125-
ts.DType.BF16,
126-
],
126+
supported_dtypes,
127127
self.tosa_spec,
128128
)
129129

backends/arm/operators/op_repeat.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,20 @@ def define_node(
3535
inputs: list[TosaArg],
3636
output: TosaArg,
3737
) -> None:
38+
supported_dtypes = [ts.DType.BOOL]
39+
if self.tosa_spec.support_integer():
40+
supported_dtypes.extend([ts.DType.INT8, ts.DType.INT16, ts.DType.INT32])
41+
if self.tosa_spec.support_float():
42+
supported_dtypes.extend([ts.DType.FP16, ts.DType.FP32])
43+
if self.tosa_spec.support_extension("bf16"):
44+
supported_dtypes.append(ts.DType.BF16)
45+
3846
validate_num_inputs(self.target, inputs, 2)
3947
validate_same_dtype(self.target, [inputs[0], output], ts)
4048
validate_valid_dtype(
4149
self.target,
4250
[inputs[0], output],
43-
[
44-
ts.DType.BOOL,
45-
ts.DType.INT8,
46-
ts.DType.INT16,
47-
ts.DType.INT32,
48-
ts.DType.FP16,
49-
ts.DType.FP32,
50-
ts.DType.BF16,
51-
],
51+
supported_dtypes,
5252
self.tosa_spec,
5353
)
5454

backends/arm/operators/op_tosa_pad.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
NodeVisitor,
1414
register_node_visitor,
1515
)
16+
from executorch.backends.arm.operators.operator_validation_utils import (
17+
validate_num_inputs,
18+
validate_same_dtype,
19+
validate_valid_dtype,
20+
)
1621
from executorch.backends.arm.tosa.mapping import TosaArg
1722

1823

@@ -29,6 +34,23 @@ def define_node(
2934
inputs: List[TosaArg],
3035
output: TosaArg,
3136
) -> None:
37+
supported_dtypes = [ts.DType.BOOL]
38+
if self.tosa_spec.support_integer():
39+
supported_dtypes.extend([ts.DType.INT8, ts.DType.INT16, ts.DType.INT32])
40+
if self.tosa_spec.support_float():
41+
supported_dtypes.extend([ts.DType.FP16, ts.DType.FP32])
42+
if self.tosa_spec.support_extension("bf16"):
43+
supported_dtypes.append(ts.DType.BF16)
44+
45+
validate_num_inputs(self.target, inputs, 2)
46+
validate_same_dtype(self.target, [inputs[0], output], ts)
47+
validate_valid_dtype(
48+
self.target,
49+
[inputs[0], output],
50+
supported_dtypes,
51+
self.tosa_spec,
52+
)
53+
3254
pad_const = tosa_graph.addConst(
3355
[1],
3456
output.dtype,

backends/arm/operators/op_tosa_slice.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
NodeVisitor,
1313
register_node_visitor,
1414
)
15+
from executorch.backends.arm.operators.operator_validation_utils import (
16+
validate_num_inputs,
17+
validate_same_dtype,
18+
validate_valid_dtype,
19+
)
1520
from executorch.backends.arm.tosa.mapping import TosaArg
1621
from torch.fx import Node
1722

@@ -30,6 +35,22 @@ def define_node(
3035
inputs: List[TosaArg],
3136
output: TosaArg,
3237
) -> None:
38+
supported_dtypes = [ts.DType.BOOL]
39+
if self.tosa_spec.support_integer():
40+
supported_dtypes.extend([ts.DType.INT8, ts.DType.INT16, ts.DType.INT32])
41+
if self.tosa_spec.support_float():
42+
supported_dtypes.extend([ts.DType.FP16, ts.DType.FP32])
43+
if self.tosa_spec.support_extension("bf16"):
44+
supported_dtypes.append(ts.DType.BF16)
45+
46+
validate_num_inputs(self.target, inputs, 3)
47+
validate_same_dtype(self.target, [inputs[0], output], ts)
48+
validate_valid_dtype(
49+
self.target,
50+
[inputs[0], output],
51+
supported_dtypes,
52+
self.tosa_spec,
53+
)
3354

3455
input_node, starts, sizes = inputs
3556
attr = ts.TosaSerializerAttribute()

backends/arm/operators/op_tosa_transpose.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,20 +40,20 @@ def define_node(
4040
inputs: List[TosaArg],
4141
output: TosaArg,
4242
) -> None:
43+
supported_dtypes = [ts.DType.BOOL]
44+
if self.tosa_spec.support_integer():
45+
supported_dtypes.extend([ts.DType.INT8, ts.DType.INT16, ts.DType.INT32])
46+
if self.tosa_spec.support_float():
47+
supported_dtypes.extend([ts.DType.FP16, ts.DType.FP32])
48+
if self.tosa_spec.support_extension("bf16"):
49+
supported_dtypes.append(ts.DType.BF16)
50+
4351
validate_num_inputs(self.target, inputs, 2)
4452
validate_same_dtype(self.target, [inputs[0], output], ts)
4553
validate_valid_dtype(
4654
self.target,
4755
[inputs[0], output],
48-
[
49-
ts.DType.BOOL,
50-
ts.DType.INT8,
51-
ts.DType.INT16,
52-
ts.DType.INT32,
53-
ts.DType.FP16,
54-
ts.DType.FP32,
55-
ts.DType.BF16,
56-
],
56+
supported_dtypes,
5757
self.tosa_spec,
5858
)
5959

backends/arm/operators/op_view.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,20 @@ def define_node(
3535
inputs: List[TosaArg],
3636
output: TosaArg,
3737
) -> None:
38+
supported_dtypes = [ts.DType.BOOL]
39+
if self.tosa_spec.support_integer():
40+
supported_dtypes.extend([ts.DType.INT8, ts.DType.INT16, ts.DType.INT32])
41+
if self.tosa_spec.support_float():
42+
supported_dtypes.extend([ts.DType.FP16, ts.DType.FP32])
43+
if self.tosa_spec.support_extension("bf16"):
44+
supported_dtypes.append(ts.DType.BF16)
45+
3846
validate_num_inputs(self.target, inputs, 2)
3947
validate_same_dtype(self.target, [inputs[0], output], ts)
4048
validate_valid_dtype(
4149
self.target,
4250
[inputs[0], output],
43-
[
44-
ts.DType.INT8,
45-
ts.DType.INT16,
46-
ts.DType.INT32,
47-
ts.DType.FP16,
48-
ts.DType.FP32,
49-
ts.DType.BF16,
50-
ts.DType.BOOL,
51-
],
51+
supported_dtypes,
5252
self.tosa_spec,
5353
)
5454

0 commit comments

Comments
 (0)