Skip to content

Commit 332cb65

Browse files
authored
Arm backend: Add TOSA dialect activation ops (pytorch#20019)
Added TOSA dialect operators for: - CLAMP, - ERF, - SIGMOID, - TANH Signed-off-by: Saoirse Stewart <saoirse.stewart@arm.com>
1 parent ba2a221 commit 332cb65

4 files changed

Lines changed: 352 additions & 0 deletions

File tree

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import executorch.backends.arm.tosa.dialect # noqa: F401
7+
import pytest
8+
import torch
9+
from executorch.backends.arm.tosa.dialect.lib import TosaValueError
10+
from executorch.backends.arm.tosa.dialect.ops_registration import (
11+
get_registered_tosa_ops,
12+
)
13+
from executorch.backends.arm.tosa.specification import (
14+
TosaLoweringContext,
15+
TosaSpecification,
16+
)
17+
from executorch.exir.dialects._ops import ops as exir_ops
18+
from torch._subclasses.fake_tensor import FakeTensorMode
19+
20+
21+
def _to_fake(mode: FakeTensorMode, *values):
22+
return [
23+
mode.from_tensor(value) if isinstance(value, torch.Tensor) else value
24+
for value in values
25+
]
26+
27+
28+
@pytest.mark.parametrize(
29+
("op_name", "spec", "input_tensor", "args", "kwargs"),
30+
[
31+
pytest.param(
32+
"CLAMP",
33+
"TOSA-1.1+INT",
34+
torch.randint(-8, 8, (2, 3, 4), dtype=torch.int8),
35+
(-3, 3),
36+
{},
37+
id="CLAMP",
38+
),
39+
pytest.param(
40+
"ERF",
41+
"TOSA-1.1+FP",
42+
torch.randn((2, 3, 4), dtype=torch.float32),
43+
(),
44+
{},
45+
id="ERF",
46+
),
47+
pytest.param(
48+
"SIGMOID",
49+
"TOSA-1.1+FP",
50+
torch.randn((2, 3, 4), dtype=torch.float32),
51+
(),
52+
{},
53+
id="SIGMOID",
54+
),
55+
pytest.param(
56+
"TANH",
57+
"TOSA-1.1+FP",
58+
torch.randn((2, 3, 4), dtype=torch.float32),
59+
(),
60+
{},
61+
id="TANH",
62+
),
63+
],
64+
)
65+
def test_tosa_activation_ops(
66+
op_name: str,
67+
spec: str,
68+
input_tensor: torch.Tensor,
69+
args: tuple[object, ...],
70+
kwargs: dict[str, object],
71+
) -> None:
72+
with TosaLoweringContext(
73+
TosaSpecification.create_from_string(spec)
74+
), FakeTensorMode() as mode:
75+
output = getattr(exir_ops.backend.tosa, op_name).default(
76+
*_to_fake(mode, input_tensor, *args),
77+
**kwargs,
78+
)
79+
80+
assert output.dtype == input_tensor.dtype
81+
assert tuple(output.shape) == tuple(input_tensor.shape)
82+
83+
84+
@pytest.mark.parametrize(
85+
("op", "spec", "expected"),
86+
[
87+
pytest.param(
88+
exir_ops.backend.tosa.ERF.default, "TOSA-1.1+INT", False, id="erf_int"
89+
),
90+
pytest.param(
91+
exir_ops.backend.tosa.SIGMOID.default,
92+
"TOSA-1.1+INT",
93+
False,
94+
id="sigmoid_int",
95+
),
96+
pytest.param(
97+
exir_ops.backend.tosa.TANH.default, "TOSA-1.1+INT", False, id="tanh_int"
98+
),
99+
pytest.param(
100+
exir_ops.backend.tosa.ERF.default, "TOSA-1.1+FP", True, id="erf_fp"
101+
),
102+
pytest.param(
103+
exir_ops.backend.tosa.SIGMOID.default, "TOSA-1.1+FP", True, id="sigmoid_fp"
104+
),
105+
pytest.param(
106+
exir_ops.backend.tosa.TANH.default, "TOSA-1.1+FP", True, id="tanh_fp"
107+
),
108+
],
109+
)
110+
def test_tosa_transcendentals_registered_for_fp_profile_only(
111+
op,
112+
spec: str,
113+
expected: bool,
114+
) -> None:
115+
with TosaLoweringContext(TosaSpecification.create_from_string(spec)):
116+
registered_ops = get_registered_tosa_ops()
117+
118+
assert (op in registered_ops) is expected
119+
120+
121+
@pytest.mark.parametrize(
122+
("op_name", "input_tensor"),
123+
[
124+
pytest.param(
125+
"ERF",
126+
torch.randn((2, 3, 4), dtype=torch.bfloat16),
127+
id="ERF",
128+
),
129+
pytest.param(
130+
"SIGMOID",
131+
torch.randn((2, 3, 4), dtype=torch.bfloat16),
132+
id="SIGMOID",
133+
),
134+
pytest.param(
135+
"TANH",
136+
torch.randn((2, 3, 4), dtype=torch.bfloat16),
137+
id="TANH",
138+
),
139+
],
140+
)
141+
def test_tosa_transcendentals_accept_bfloat16_with_bf16_extension(
142+
op_name: str,
143+
input_tensor: torch.Tensor,
144+
) -> None:
145+
with TosaLoweringContext(
146+
TosaSpecification.create_from_string("TOSA-1.1+FP+bf16")
147+
), FakeTensorMode() as mode:
148+
output = getattr(exir_ops.backend.tosa, op_name).default(
149+
mode.from_tensor(input_tensor)
150+
)
151+
152+
assert output.dtype == torch.bfloat16
153+
assert tuple(output.shape) == tuple(input_tensor.shape)
154+
155+
156+
def test_clamp_rejects_invalid_range() -> None:
157+
sample_input = torch.randint(-8, 8, (2, 3, 4), dtype=torch.int8)
158+
159+
with TosaLoweringContext(
160+
TosaSpecification.create_from_string("TOSA-1.1+INT")
161+
), FakeTensorMode() as mode:
162+
with pytest.raises(
163+
TosaValueError,
164+
match="max_val must be greater than or equal to min_val",
165+
):
166+
exir_ops.backend.tosa.CLAMP.default(
167+
mode.from_tensor(sample_input),
168+
4,
169+
-4,
170+
)
171+
172+
173+
@pytest.mark.parametrize(
174+
("min_val", "max_val", "match"),
175+
[
176+
pytest.param(-1.5, 1.5, "must be an integer", id="non_integral"),
177+
pytest.param(-200, 200, "must be in \\[-128, 127\\]", id="out_of_range"),
178+
],
179+
)
180+
def test_clamp_rejects_invalid_integer_bounds(
181+
min_val: int | float,
182+
max_val: int | float,
183+
match: str,
184+
) -> None:
185+
sample_input = torch.randint(-8, 8, (2, 3, 4), dtype=torch.int8)
186+
187+
with TosaLoweringContext(
188+
TosaSpecification.create_from_string("TOSA-1.1+INT")
189+
), FakeTensorMode() as mode:
190+
with pytest.raises(TosaValueError, match=match):
191+
exir_ops.backend.tosa.CLAMP.default(
192+
mode.from_tensor(sample_input),
193+
min_val,
194+
max_val,
195+
)

backends/arm/tosa/dialect/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66
from executorch.backends.arm.tosa.dialect.ops import ( # noqa F401
7+
activation,
78
avg_pool2d,
89
avg_pool2d_adaptive,
910
conv2d,
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from executorch.backends.arm.tosa.dialect.lib import TosaValueError
7+
8+
_VALID_NAN_MODES = {"PROPAGATE", "IGNORE"}
9+
10+
11+
def validate_nan_mode(nan_mode: str, op: str) -> None:
12+
if nan_mode not in _VALID_NAN_MODES:
13+
raise TosaValueError(
14+
f"Unsupported nan_mode {nan_mode}. Expected one of {_VALID_NAN_MODES}",
15+
op=op,
16+
)
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import math
7+
8+
import torch
9+
from executorch.backends.arm.tosa.dialect.lib import TosaValueError
10+
from executorch.backends.arm.tosa.dialect.ops._common import validate_nan_mode
11+
from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op
12+
from executorch.backends.arm.tosa.specification import (
13+
get_context_spec,
14+
TosaSpecification,
15+
)
16+
17+
FP_SPECS = TosaSpecification.all_versions_for_profile("FP")
18+
19+
20+
def _validate_clamp_dtype(dtype: torch.dtype, op: str) -> None:
21+
tosa_spec = get_context_spec()
22+
23+
if dtype == torch.int8:
24+
if not tosa_spec.support_integer():
25+
raise TosaValueError(
26+
f"TOSA spec {tosa_spec} doesn't support int8 for {op}",
27+
op=op,
28+
)
29+
return
30+
31+
if dtype == torch.int16:
32+
if not (tosa_spec.support_integer() and tosa_spec.support_extension("int16")):
33+
raise TosaValueError(
34+
f"TOSA spec {tosa_spec} doesn't support int16 for {op}",
35+
op=op,
36+
)
37+
return
38+
39+
_validate_float_dtype(dtype, op)
40+
return
41+
42+
raise TosaValueError(f"Unsupported dtype {dtype} for {op}", op=op)
43+
44+
45+
def _validate_float_dtype(dtype: torch.dtype, op: str) -> None:
46+
tosa_spec = get_context_spec()
47+
48+
if dtype in (torch.float16, torch.float32):
49+
if not tosa_spec.support_float():
50+
raise TosaValueError(
51+
f"TOSA spec {tosa_spec} doesn't support {dtype} for {op}",
52+
op=op,
53+
)
54+
return
55+
56+
if dtype == torch.bfloat16:
57+
if not (tosa_spec.support_float() and tosa_spec.support_extension("bf16")):
58+
raise TosaValueError(
59+
f"TOSA spec {tosa_spec} doesn't support bfloat16 for {op}",
60+
op=op,
61+
)
62+
return
63+
64+
raise TosaValueError(f"Unsupported dtype {dtype} for {op}", op=op)
65+
66+
67+
def _validate_integer_clamp_bounds(
68+
dtype: torch.dtype,
69+
min_val,
70+
max_val,
71+
) -> None:
72+
if dtype not in (torch.int8, torch.int16):
73+
return
74+
75+
dtype_info = torch.iinfo(dtype)
76+
for name, value in (("min_val", min_val), ("max_val", max_val)):
77+
if not isinstance(value, int) or isinstance(value, bool):
78+
raise TosaValueError(
79+
f"{name} must be an integer for {dtype} CLAMP",
80+
op="CLAMP",
81+
)
82+
if value < dtype_info.min or value > dtype_info.max:
83+
raise TosaValueError(
84+
f"{name} must be in [{dtype_info.min}, {dtype_info.max}] for {dtype} CLAMP",
85+
op="CLAMP",
86+
)
87+
88+
89+
@register_fake_tosa_op(
90+
'CLAMP(Tensor input, Scalar min_val, Scalar max_val, *, str nan_mode="PROPAGATE") -> Tensor',
91+
TosaSpecification.all_versions_and_profiles(),
92+
)
93+
def CLAMP(
94+
input: torch.Tensor,
95+
min_val,
96+
max_val,
97+
*,
98+
nan_mode: str = "PROPAGATE",
99+
) -> torch.Tensor:
100+
validate_nan_mode(nan_mode, "CLAMP")
101+
_validate_clamp_dtype(input.dtype, "CLAMP")
102+
_validate_integer_clamp_bounds(input.dtype, min_val, max_val)
103+
104+
if isinstance(min_val, float) and math.isnan(min_val):
105+
raise TosaValueError("min_val cannot be NaN", op="CLAMP")
106+
if isinstance(max_val, float) and math.isnan(max_val):
107+
raise TosaValueError("max_val cannot be NaN", op="CLAMP")
108+
if min_val > max_val:
109+
raise TosaValueError(
110+
"max_val must be greater than or equal to min_val", op="CLAMP"
111+
)
112+
113+
return torch.empty_like(input, dtype=input.dtype)
114+
115+
116+
@register_fake_tosa_op(
117+
"ERF(Tensor input) -> Tensor",
118+
FP_SPECS,
119+
)
120+
def ERF(input: torch.Tensor) -> torch.Tensor:
121+
_validate_float_dtype(input.dtype, "ERF")
122+
return torch.empty_like(input, dtype=input.dtype)
123+
124+
125+
@register_fake_tosa_op(
126+
"SIGMOID(Tensor input) -> Tensor",
127+
FP_SPECS,
128+
)
129+
def SIGMOID(input: torch.Tensor) -> torch.Tensor:
130+
_validate_float_dtype(input.dtype, "SIGMOID")
131+
return torch.empty_like(input, dtype=input.dtype)
132+
133+
134+
@register_fake_tosa_op(
135+
"TANH(Tensor input) -> Tensor",
136+
FP_SPECS,
137+
)
138+
def TANH(input: torch.Tensor) -> torch.Tensor:
139+
_validate_float_dtype(input.dtype, "TANH")
140+
return torch.empty_like(input, dtype=input.dtype)

0 commit comments

Comments
 (0)