Skip to content

Commit f0d9991

Browse files
Arm backend: Add TOSA dialect reduction ops (pytorch#19937)
Register fake TOSA dialect implementations for REDUCE_ALL, REDUCE_ANY, REDUCE_MAX, REDUCE_MIN, REDUCE_PRODUCT, and REDUCE_SUM. The new fake ops preserve the reduced axis in the output shape and validate input rank, axis bounds, supported dtypes, profile and extension gating, and NaN propagation mode where required by the TOSA spec. Add reduction-op dialect tests covering valid shape propagation and the main rejection cases for invalid bool, integer, and narrow-integer inputs. cc @digantdesai @freddan80 @per @zingo @mansnils @Sebastian-Larsson @robell @rascani Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
1 parent 7871a9b commit f0d9991

3 files changed

Lines changed: 321 additions & 0 deletions

File tree

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import executorch.backends.arm.tosa.dialect # noqa: F401
7+
import pytest
8+
import torch
9+
from executorch.backends.arm.tosa.dialect.lib import TosaValueError
10+
from executorch.backends.arm.tosa.specification import (
11+
TosaLoweringContext,
12+
TosaSpecification,
13+
)
14+
from executorch.exir.dialects._ops import ops as exir_ops
15+
from torch._subclasses.fake_tensor import FakeTensorMode
16+
17+
18+
@pytest.mark.parametrize(
19+
"op_name,input_tensor,kwargs,expected_shape",
20+
[
21+
(
22+
"REDUCE_ALL",
23+
torch.tensor([[[True, False], [True, True]]]),
24+
{"axis": 1},
25+
(1, 1, 2),
26+
),
27+
(
28+
"REDUCE_ANY",
29+
torch.tensor([[[True, False], [False, False]]]),
30+
{"axis": 2},
31+
(1, 2, 1),
32+
),
33+
(
34+
"REDUCE_MAX",
35+
torch.randint(-8, 8, (2, 3, 4), dtype=torch.int32),
36+
{"axis": 0, "nan_mode": "PROPAGATE"},
37+
(1, 3, 4),
38+
),
39+
(
40+
"REDUCE_MIN",
41+
torch.randn((2, 3, 4), dtype=torch.float32),
42+
{"axis": 2, "nan_mode": "IGNORE"},
43+
(2, 3, 1),
44+
),
45+
(
46+
"REDUCE_PRODUCT",
47+
torch.randn((2, 3, 4), dtype=torch.float32),
48+
{"axis": 1},
49+
(2, 1, 4),
50+
),
51+
(
52+
"REDUCE_SUM",
53+
torch.randint(-8, 8, (2, 3, 4), dtype=torch.int32),
54+
{"axis": 1},
55+
(2, 1, 4),
56+
),
57+
],
58+
)
59+
def test_reduction_ops(op_name, input_tensor, kwargs, expected_shape):
60+
spec = (
61+
"TOSA-1.1+FP+bf16+int64"
62+
if input_tensor.dtype.is_floating_point
63+
else "TOSA-1.1+INT+int16+int64"
64+
)
65+
with TosaLoweringContext(
66+
TosaSpecification.create_from_string(spec)
67+
), FakeTensorMode() as mode:
68+
op = getattr(exir_ops.backend.tosa, op_name).default
69+
output = op(mode.from_tensor(input_tensor), **kwargs)
70+
71+
assert output.dtype == input_tensor.dtype
72+
assert tuple(output.shape) == expected_shape
73+
74+
75+
def test_reduce_all_rejects_non_bool():
76+
with TosaLoweringContext(
77+
TosaSpecification.create_from_string("TOSA-1.1+INT")
78+
), FakeTensorMode() as mode:
79+
with pytest.raises(TosaValueError, match="requires bool input"):
80+
exir_ops.backend.tosa.REDUCE_ALL.default(
81+
mode.from_tensor(torch.ones((2, 2), dtype=torch.int32)), axis=1
82+
)
83+
84+
85+
def test_reduce_product_rejects_integer_input():
86+
with TosaLoweringContext(
87+
TosaSpecification.create_from_string("TOSA-1.1+INT")
88+
), FakeTensorMode() as mode:
89+
with pytest.raises(TosaValueError, match="floating-point input"):
90+
exir_ops.backend.tosa.REDUCE_PRODUCT.default(
91+
mode.from_tensor(torch.ones((2, 2), dtype=torch.int32)), axis=1
92+
)
93+
94+
95+
@pytest.mark.parametrize(
96+
"op_name,dtype", [("REDUCE_MAX", torch.float32), ("REDUCE_MIN", torch.int32)]
97+
)
98+
def test_reduce_minmax_default_nan_mode(op_name: str, dtype: torch.dtype):
99+
spec = "TOSA-1.1+FP" if dtype.is_floating_point else "TOSA-1.1+INT"
100+
with TosaLoweringContext(
101+
TosaSpecification.create_from_string(spec)
102+
), FakeTensorMode() as mode:
103+
op = getattr(exir_ops.backend.tosa, op_name).default
104+
output = op(mode.from_tensor(torch.ones((2, 2), dtype=dtype)), axis=1)
105+
106+
assert output.dtype == dtype
107+
assert tuple(output.shape) == (2, 1)
108+
109+
110+
@pytest.mark.parametrize("op_name", ["REDUCE_MAX", "REDUCE_MIN"])
111+
def test_reduce_minmax_rejects_invalid_nan_mode(op_name: str):
112+
with TosaLoweringContext(
113+
TosaSpecification.create_from_string("TOSA-1.1+FP")
114+
), FakeTensorMode() as mode:
115+
op = getattr(exir_ops.backend.tosa, op_name).default
116+
with pytest.raises(TosaValueError, match="Invalid nan_mode"):
117+
op(
118+
mode.from_tensor(torch.ones((2, 2), dtype=torch.float32)),
119+
axis=1,
120+
nan_mode="INVALID_MODE",
121+
)
122+
123+
124+
@pytest.mark.parametrize("dtype", [torch.int8, torch.int16])
125+
def test_reduce_sum_rejects_narrow_integer_inputs(dtype: torch.dtype):
126+
spec = "TOSA-1.1+INT+int16" if dtype == torch.int16 else "TOSA-1.1+INT"
127+
with TosaLoweringContext(
128+
TosaSpecification.create_from_string(spec)
129+
), FakeTensorMode() as mode:
130+
with pytest.raises(TosaValueError, match="Unsupported dtype"):
131+
exir_ops.backend.tosa.REDUCE_SUM.default(
132+
mode.from_tensor(torch.ones((2, 2), dtype=dtype)),
133+
axis=1,
134+
)

backends/arm/tosa/dialect/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
max_pool2d,
1717
max_pool2d_adaptive,
1818
pad,
19+
reduction_ops,
1920
rescale,
2021
resize,
2122
scatter,
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
8+
from executorch.backends.arm.tosa.dialect.lib import TosaValueError
9+
from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op
10+
from executorch.backends.arm.tosa.specification import (
11+
get_context_spec,
12+
TosaSpecification,
13+
)
14+
15+
16+
def _validate_axis(x: torch.Tensor, axis: int, op: str) -> None:
17+
if x.dim() < 1:
18+
raise TosaValueError(f"{op} requires rank >= 1 input", op=op)
19+
if axis < 0 or axis >= x.dim():
20+
raise TosaValueError(
21+
f"{op} axis {axis} is out of range for rank {x.dim()}",
22+
op=op,
23+
)
24+
25+
26+
def _reduce_shape(x: torch.Tensor, axis: int) -> list[int | torch.SymInt]:
27+
output_shape: list[int | torch.SymInt] = list(x.shape)
28+
output_shape[axis] = 1
29+
return output_shape
30+
31+
32+
def _validate_bool_dtype(x: torch.Tensor, op: str) -> None:
33+
if x.dtype != torch.bool:
34+
raise TosaValueError(f"{op} requires bool input, got {x.dtype}", op=op)
35+
36+
37+
def _validate_float_integer_dtype(x: torch.Tensor, op: str) -> None:
38+
tosa_spec = get_context_spec()
39+
supported_int_dtypes = {torch.int8, torch.int16, torch.int32}
40+
supported_float_dtypes = {torch.float16, torch.float32}
41+
42+
if tosa_spec.support_extension("int64"):
43+
supported_int_dtypes.add(torch.int64)
44+
if tosa_spec.support_extension("bf16"):
45+
supported_float_dtypes.add(torch.bfloat16)
46+
47+
if x.dtype in supported_int_dtypes:
48+
if not tosa_spec.support_integer():
49+
raise TosaValueError(
50+
f"TOSA spec {tosa_spec} doesn't support integer reductions",
51+
op=op,
52+
)
53+
return
54+
55+
if x.dtype in supported_float_dtypes:
56+
if not tosa_spec.support_float():
57+
raise TosaValueError(
58+
f"TOSA spec {tosa_spec} doesn't support floating-point reductions",
59+
op=op,
60+
)
61+
return
62+
63+
raise TosaValueError(f"Unsupported dtype {x.dtype} for {op}", op=op)
64+
65+
66+
def _validate_reduce_sum_dtype(x: torch.Tensor) -> None:
67+
tosa_spec = get_context_spec()
68+
supported_int_dtypes = {torch.int32}
69+
supported_float_dtypes = {torch.float16, torch.float32}
70+
71+
if tosa_spec.support_extension("int64"):
72+
supported_int_dtypes.add(torch.int64)
73+
if tosa_spec.support_extension("bf16"):
74+
supported_float_dtypes.add(torch.bfloat16)
75+
76+
if x.dtype in supported_int_dtypes:
77+
if not tosa_spec.support_integer():
78+
raise TosaValueError(
79+
f"TOSA spec {tosa_spec} doesn't support integer reductions",
80+
op="REDUCE_SUM",
81+
)
82+
return
83+
84+
if x.dtype in supported_float_dtypes:
85+
if not tosa_spec.support_float():
86+
raise TosaValueError(
87+
f"TOSA spec {tosa_spec} doesn't support floating-point reductions",
88+
op="REDUCE_SUM",
89+
)
90+
return
91+
92+
raise TosaValueError(
93+
f"Unsupported dtype {x.dtype} for REDUCE_SUM",
94+
op="REDUCE_SUM",
95+
)
96+
97+
98+
def _validate_product_dtype(x: torch.Tensor, op: str) -> None:
99+
tosa_spec = get_context_spec()
100+
supported_dtypes = {torch.float16, torch.float32}
101+
if tosa_spec.support_extension("bf16"):
102+
supported_dtypes.add(torch.bfloat16)
103+
104+
if x.dtype not in supported_dtypes:
105+
raise TosaValueError(
106+
f"{op} requires floating-point input, got {x.dtype}", op=op
107+
)
108+
if not tosa_spec.support_float():
109+
raise TosaValueError(
110+
f"TOSA spec {tosa_spec} doesn't support floating-point reductions",
111+
op=op,
112+
)
113+
114+
115+
def _validate_nan_mode(nan_mode: str, op: str) -> None:
116+
if nan_mode not in ("PROPAGATE", "IGNORE"):
117+
raise TosaValueError(
118+
f"Invalid nan_mode {nan_mode}, must be PROPAGATE or IGNORE",
119+
op=op,
120+
)
121+
122+
123+
@register_fake_tosa_op(
124+
"REDUCE_ALL(Tensor input, *, int axis) -> Tensor",
125+
TosaSpecification.all_versions_and_profiles(),
126+
)
127+
def REDUCE_ALL(x: torch.Tensor, *, axis: int) -> torch.Tensor:
128+
_validate_axis(x, axis, "REDUCE_ALL")
129+
_validate_bool_dtype(x, "REDUCE_ALL")
130+
return torch.empty(size=_reduce_shape(x, axis), dtype=x.dtype)
131+
132+
133+
@register_fake_tosa_op(
134+
"REDUCE_ANY(Tensor input, *, int axis) -> Tensor",
135+
TosaSpecification.all_versions_and_profiles(),
136+
)
137+
def REDUCE_ANY(x: torch.Tensor, *, axis: int) -> torch.Tensor:
138+
_validate_axis(x, axis, "REDUCE_ANY")
139+
_validate_bool_dtype(x, "REDUCE_ANY")
140+
return torch.empty(size=_reduce_shape(x, axis), dtype=x.dtype)
141+
142+
143+
@register_fake_tosa_op(
144+
'REDUCE_MAX(Tensor input, *, int axis, str nan_mode="PROPAGATE") -> Tensor',
145+
TosaSpecification.all_versions_and_profiles(),
146+
)
147+
def REDUCE_MAX(
148+
x: torch.Tensor, *, axis: int, nan_mode: str = "PROPAGATE"
149+
) -> torch.Tensor:
150+
_validate_axis(x, axis, "REDUCE_MAX")
151+
_validate_float_integer_dtype(x, "REDUCE_MAX")
152+
_validate_nan_mode(nan_mode, "REDUCE_MAX")
153+
return torch.empty(size=_reduce_shape(x, axis), dtype=x.dtype)
154+
155+
156+
@register_fake_tosa_op(
157+
'REDUCE_MIN(Tensor input, *, int axis, str nan_mode="PROPAGATE") -> Tensor',
158+
TosaSpecification.all_versions_and_profiles(),
159+
)
160+
def REDUCE_MIN(
161+
x: torch.Tensor, *, axis: int, nan_mode: str = "PROPAGATE"
162+
) -> torch.Tensor:
163+
_validate_axis(x, axis, "REDUCE_MIN")
164+
_validate_float_integer_dtype(x, "REDUCE_MIN")
165+
_validate_nan_mode(nan_mode, "REDUCE_MIN")
166+
return torch.empty(size=_reduce_shape(x, axis), dtype=x.dtype)
167+
168+
169+
@register_fake_tosa_op(
170+
"REDUCE_PRODUCT(Tensor input, *, int axis) -> Tensor",
171+
TosaSpecification.all_versions_and_profiles(),
172+
)
173+
def REDUCE_PRODUCT(x: torch.Tensor, *, axis: int) -> torch.Tensor:
174+
_validate_axis(x, axis, "REDUCE_PRODUCT")
175+
_validate_product_dtype(x, "REDUCE_PRODUCT")
176+
return torch.empty(size=_reduce_shape(x, axis), dtype=x.dtype)
177+
178+
179+
@register_fake_tosa_op(
180+
"REDUCE_SUM(Tensor input, *, int axis) -> Tensor",
181+
TosaSpecification.all_versions_and_profiles(),
182+
)
183+
def REDUCE_SUM(x: torch.Tensor, *, axis: int) -> torch.Tensor:
184+
_validate_axis(x, axis, "REDUCE_SUM")
185+
_validate_reduce_sum_dtype(x)
186+
return torch.empty(size=_reduce_shape(x, axis), dtype=x.dtype)

0 commit comments

Comments
 (0)