Skip to content

Commit 91be26d

Browse files
authored
Arm backend: Add TOSA dialect unary elementwise ops (pytorch#20017)
Added TOSA dialect operators for: - ABS - BITWISE_NOT - CEIL - CLZ - COS - EXP - FLOOR - LOG - LOGICAL_NOT - NEGATE - RECIPROCAL - RSQRT - SIN Signed-off-by: Saoirse Stewart <saoirse.stewart@arm.com>
1 parent 332cb65 commit 91be26d

3 files changed

Lines changed: 619 additions & 0 deletions

File tree

Lines changed: 394 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,394 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import executorch.backends.arm.tosa.dialect # noqa: F401
7+
import pytest
8+
import torch
9+
from executorch.backends.arm.tosa.dialect.lib import TosaValueError
10+
from executorch.backends.arm.tosa.dialect.ops_registration import (
11+
get_registered_tosa_ops,
12+
)
13+
from executorch.backends.arm.tosa.specification import (
14+
TosaLoweringContext,
15+
TosaSpecification,
16+
)
17+
from executorch.exir.dialects._ops import ops as exir_ops
18+
from torch._subclasses.fake_tensor import FakeTensorMode
19+
20+
21+
@pytest.mark.parametrize(
22+
("op_name", "spec", "input_tensor"),
23+
[
24+
pytest.param(
25+
"ABS",
26+
"TOSA-1.1+INT",
27+
torch.randint(1, 16, (2, 3), dtype=torch.int32),
28+
id="ABS",
29+
),
30+
pytest.param(
31+
"BITWISE_NOT",
32+
"TOSA-1.1+INT",
33+
torch.randint(-8, 8, (2, 3), dtype=torch.int8),
34+
id="BITWISE_NOT",
35+
),
36+
pytest.param(
37+
"BITWISE_NOT",
38+
"TOSA-1.1+INT",
39+
torch.randint(-8, 8, (2, 3), dtype=torch.int16),
40+
id="BITWISE_NOT_INT16",
41+
),
42+
pytest.param(
43+
"CEIL",
44+
"TOSA-1.1+FP",
45+
torch.randn((2, 3), dtype=torch.float32),
46+
id="CEIL",
47+
),
48+
pytest.param(
49+
"CLZ",
50+
"TOSA-1.1+INT",
51+
torch.randint(1, 16, (2, 3), dtype=torch.int32),
52+
id="CLZ",
53+
),
54+
pytest.param(
55+
"COS",
56+
"TOSA-1.1+FP",
57+
torch.randn((2, 3), dtype=torch.float32),
58+
id="COS",
59+
),
60+
pytest.param(
61+
"EXP",
62+
"TOSA-1.1+FP",
63+
torch.randn((2, 3), dtype=torch.float32),
64+
id="EXP",
65+
),
66+
pytest.param(
67+
"FLOOR",
68+
"TOSA-1.1+FP",
69+
torch.randn((2, 3), dtype=torch.float32),
70+
id="FLOOR",
71+
),
72+
pytest.param(
73+
"LOG",
74+
"TOSA-1.1+FP",
75+
torch.randn((2, 3), dtype=torch.float32).abs() + 1.0,
76+
id="LOG",
77+
),
78+
pytest.param(
79+
"LOGICAL_NOT",
80+
"TOSA-1.1+FP",
81+
torch.tensor([[True, False], [False, True]], dtype=torch.bool),
82+
id="LOGICAL_NOT",
83+
),
84+
pytest.param(
85+
"NEGATE",
86+
"TOSA-1.1+INT",
87+
torch.randint(-8, 8, (2, 3), dtype=torch.int32),
88+
id="NEGATE",
89+
),
90+
pytest.param(
91+
"NEGATE",
92+
"TOSA-1.1+INT",
93+
torch.randint(-8, 8, (2, 3), dtype=torch.int16),
94+
id="NEGATE_INT16",
95+
),
96+
pytest.param(
97+
"NEGATE",
98+
"TOSA-1.1+FP",
99+
torch.randn((2, 3), dtype=torch.float32),
100+
id="NEGATE_FP32",
101+
),
102+
pytest.param(
103+
"RECIPROCAL",
104+
"TOSA-1.1+FP",
105+
torch.randn((2, 3), dtype=torch.float32).abs() + 1.0,
106+
id="RECIPROCAL",
107+
),
108+
pytest.param(
109+
"RSQRT",
110+
"TOSA-1.1+FP",
111+
torch.randn((2, 3), dtype=torch.float32).abs() + 1.0,
112+
id="RSQRT",
113+
),
114+
pytest.param(
115+
"SIN",
116+
"TOSA-1.1+FP",
117+
torch.randn((2, 3), dtype=torch.float32),
118+
id="SIN",
119+
),
120+
],
121+
)
122+
def test_tosa_unary_ops(
123+
op_name: str,
124+
spec: str,
125+
input_tensor: torch.Tensor,
126+
) -> None:
127+
with TosaLoweringContext(
128+
TosaSpecification.create_from_string(spec)
129+
), FakeTensorMode() as mode:
130+
output = getattr(exir_ops.backend.tosa, op_name).default(
131+
mode.from_tensor(input_tensor)
132+
)
133+
134+
assert output.dtype == input_tensor.dtype
135+
assert tuple(output.shape) == tuple(input_tensor.shape)
136+
137+
138+
@pytest.mark.parametrize(
139+
("op", "spec", "expected"),
140+
[
141+
pytest.param(
142+
exir_ops.backend.tosa.BITWISE_NOT.default,
143+
"TOSA-1.1+INT",
144+
True,
145+
id="bitwise_not_int",
146+
),
147+
pytest.param(
148+
exir_ops.backend.tosa.BITWISE_NOT.default,
149+
"TOSA-1.1+FP",
150+
False,
151+
id="bitwise_not_fp",
152+
),
153+
pytest.param(
154+
exir_ops.backend.tosa.CLZ.default,
155+
"TOSA-1.1+INT",
156+
True,
157+
id="clz_int",
158+
),
159+
pytest.param(
160+
exir_ops.backend.tosa.CLZ.default,
161+
"TOSA-1.1+FP",
162+
False,
163+
id="clz_fp",
164+
),
165+
],
166+
)
167+
def test_tosa_integer_unary_ops_registered_for_int_profile_only(
168+
op,
169+
spec: str,
170+
expected: bool,
171+
) -> None:
172+
with TosaLoweringContext(TosaSpecification.create_from_string(spec)):
173+
registered_ops = get_registered_tosa_ops()
174+
175+
assert (op in registered_ops) is expected
176+
177+
178+
@pytest.mark.parametrize(
179+
("op", "spec", "expected"),
180+
[
181+
pytest.param(
182+
exir_ops.backend.tosa.CEIL.default,
183+
"TOSA-1.1+INT",
184+
False,
185+
id="ceil_int",
186+
),
187+
pytest.param(
188+
exir_ops.backend.tosa.CEIL.default,
189+
"TOSA-1.1+FP",
190+
True,
191+
id="ceil_fp",
192+
),
193+
pytest.param(
194+
exir_ops.backend.tosa.COS.default,
195+
"TOSA-1.1+INT",
196+
False,
197+
id="cos_int",
198+
),
199+
pytest.param(
200+
exir_ops.backend.tosa.COS.default,
201+
"TOSA-1.1+FP",
202+
True,
203+
id="cos_fp",
204+
),
205+
pytest.param(
206+
exir_ops.backend.tosa.EXP.default,
207+
"TOSA-1.1+INT",
208+
False,
209+
id="exp_int",
210+
),
211+
pytest.param(
212+
exir_ops.backend.tosa.EXP.default,
213+
"TOSA-1.1+FP",
214+
True,
215+
id="exp_fp",
216+
),
217+
pytest.param(
218+
exir_ops.backend.tosa.FLOOR.default,
219+
"TOSA-1.1+INT",
220+
False,
221+
id="floor_int",
222+
),
223+
pytest.param(
224+
exir_ops.backend.tosa.FLOOR.default,
225+
"TOSA-1.1+FP",
226+
True,
227+
id="floor_fp",
228+
),
229+
pytest.param(
230+
exir_ops.backend.tosa.LOG.default,
231+
"TOSA-1.1+INT",
232+
False,
233+
id="log_int",
234+
),
235+
pytest.param(
236+
exir_ops.backend.tosa.LOG.default,
237+
"TOSA-1.1+FP",
238+
True,
239+
id="log_fp",
240+
),
241+
pytest.param(
242+
exir_ops.backend.tosa.RECIPROCAL.default,
243+
"TOSA-1.1+INT",
244+
False,
245+
id="reciprocal_int",
246+
),
247+
pytest.param(
248+
exir_ops.backend.tosa.RECIPROCAL.default,
249+
"TOSA-1.1+FP",
250+
True,
251+
id="reciprocal_fp",
252+
),
253+
pytest.param(
254+
exir_ops.backend.tosa.RSQRT.default,
255+
"TOSA-1.1+INT",
256+
False,
257+
id="rsqrt_int",
258+
),
259+
pytest.param(
260+
exir_ops.backend.tosa.RSQRT.default,
261+
"TOSA-1.1+FP",
262+
True,
263+
id="rsqrt_fp",
264+
),
265+
pytest.param(
266+
exir_ops.backend.tosa.SIN.default,
267+
"TOSA-1.1+INT",
268+
False,
269+
id="sin_int",
270+
),
271+
pytest.param(
272+
exir_ops.backend.tosa.SIN.default,
273+
"TOSA-1.1+FP",
274+
True,
275+
id="sin_fp",
276+
),
277+
],
278+
)
279+
def test_tosa_float_unary_ops_registered_for_fp_profile_only(
280+
op,
281+
spec: str,
282+
expected: bool,
283+
) -> None:
284+
with TosaLoweringContext(TosaSpecification.create_from_string(spec)):
285+
registered_ops = get_registered_tosa_ops()
286+
287+
assert (op in registered_ops) is expected
288+
289+
290+
@pytest.mark.parametrize(
291+
("spec", "expected"),
292+
[
293+
pytest.param("TOSA-1.1+INT", True, id="negate_int"),
294+
pytest.param("TOSA-1.1+FP", True, id="negate_fp"),
295+
],
296+
)
297+
def test_tosa_negate_registered_for_int_and_fp_profiles(
298+
spec: str,
299+
expected: bool,
300+
) -> None:
301+
with TosaLoweringContext(TosaSpecification.create_from_string(spec)):
302+
registered_ops = get_registered_tosa_ops()
303+
304+
assert (exir_ops.backend.tosa.NEGATE.default in registered_ops) is expected
305+
306+
307+
@pytest.mark.parametrize(
308+
("op_name", "input_tensor"),
309+
[
310+
pytest.param(
311+
"CEIL",
312+
torch.randn((2, 3), dtype=torch.bfloat16),
313+
id="CEIL",
314+
),
315+
pytest.param(
316+
"COS",
317+
torch.randn((2, 3), dtype=torch.bfloat16),
318+
id="COS",
319+
),
320+
pytest.param(
321+
"EXP",
322+
torch.randn((2, 3), dtype=torch.bfloat16),
323+
id="EXP",
324+
),
325+
pytest.param(
326+
"FLOOR",
327+
torch.randn((2, 3), dtype=torch.bfloat16),
328+
id="FLOOR",
329+
),
330+
pytest.param(
331+
"LOG",
332+
torch.randn((2, 3), dtype=torch.bfloat16).abs() + 1.0,
333+
id="LOG",
334+
),
335+
pytest.param(
336+
"NEGATE",
337+
torch.randn((2, 3), dtype=torch.bfloat16),
338+
id="NEGATE",
339+
),
340+
],
341+
)
342+
def test_tosa_float_unary_ops_accept_bfloat16_with_bf16_extension(
343+
op_name: str,
344+
input_tensor: torch.Tensor,
345+
) -> None:
346+
with TosaLoweringContext(
347+
TosaSpecification.create_from_string("TOSA-1.1+FP+bf16")
348+
), FakeTensorMode() as mode:
349+
output = getattr(exir_ops.backend.tosa, op_name).default(
350+
mode.from_tensor(input_tensor)
351+
)
352+
353+
assert output.dtype == torch.bfloat16
354+
assert tuple(output.shape) == tuple(input_tensor.shape)
355+
356+
357+
def test_negate_rejects_bfloat16_without_bf16_extension() -> None:
358+
sample_input = torch.randn((2, 3), dtype=torch.bfloat16)
359+
360+
with TosaLoweringContext(
361+
TosaSpecification.create_from_string("TOSA-1.1+FP")
362+
), FakeTensorMode() as mode:
363+
with pytest.raises(TosaValueError, match="doesn't support bfloat16"):
364+
exir_ops.backend.tosa.NEGATE.default(mode.from_tensor(sample_input))
365+
366+
367+
def test_abs_rejects_int8() -> None:
368+
sample_input = torch.randint(-8, 8, (2, 3), dtype=torch.int8)
369+
370+
with TosaLoweringContext(
371+
TosaSpecification.create_from_string("TOSA-1.1+INT")
372+
), FakeTensorMode() as mode:
373+
with pytest.raises(TosaValueError, match="Unsupported dtype"):
374+
exir_ops.backend.tosa.ABS.default(mode.from_tensor(sample_input))
375+
376+
377+
def test_floor_requires_float_profile() -> None:
378+
sample_input = torch.randn((2, 3), dtype=torch.float32)
379+
380+
with TosaLoweringContext(
381+
TosaSpecification.create_from_string("TOSA-1.1+INT")
382+
), FakeTensorMode() as mode:
383+
with pytest.raises(TosaValueError, match="doesn't support"):
384+
exir_ops.backend.tosa.FLOOR.default(mode.from_tensor(sample_input))
385+
386+
387+
def test_logical_not_rejects_non_bool() -> None:
388+
sample_input = torch.randint(-8, 8, (2, 3), dtype=torch.int8)
389+
390+
with TosaLoweringContext(
391+
TosaSpecification.create_from_string("TOSA-1.1+INT")
392+
), FakeTensorMode() as mode:
393+
with pytest.raises(TosaValueError, match="requires bool inputs"):
394+
exir_ops.backend.tosa.LOGICAL_NOT.default(mode.from_tensor(sample_input))

0 commit comments

Comments
 (0)