Skip to content

Commit e98e74e

Browse files
authored
Arm backend: Add TOSA dialect ARGMAX node visitor (#20418)
- Adds ARGMAX to ExirToTosaPass - Uses _DIALECT_SUBSTITUTIONS for TOSA Activation ops Signed-off-by: Saoirse Stewart <saoirse.stewart@arm.com>
1 parent 06f05ca commit e98e74e

8 files changed

Lines changed: 370 additions & 28 deletions

File tree

backends/arm/_passes/aten_to_tosa_activation_functions.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,21 @@ def rewrite_clamp(node: Node, pass_: AtenToDialectPass) -> DialectNodeSpec | Non
128128
exir_ops.backend.tosa.CLAMP.default,
129129
(node.args[0], *min_max_args),
130130
)
131+
132+
133+
def get_activation_replacement(
134+
node: Node, pass_: AtenToDialectPass
135+
) -> DialectNodeSpec | None:
136+
# Dispatch activation rewrites from their ATen target to the matching TOSA
137+
# dialect node builder.
138+
match node.target:
139+
case exir_ops.edge.aten.clamp.default:
140+
return rewrite_clamp(node, pass_)
141+
case exir_ops.edge.aten.erf.default:
142+
return rewrite_erf(node, pass_)
143+
case exir_ops.edge.aten.sigmoid.default:
144+
return rewrite_sigmoid(node, pass_)
145+
case exir_ops.edge.aten.tanh.default:
146+
return rewrite_tanh(node, pass_)
147+
case _:
148+
return None
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import cast
7+
8+
from executorch.backends.transforms.aten_to_dialect_pass import (
9+
AtenToDialectPass,
10+
DialectNodeSpec,
11+
)
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
from torch.fx import Node
14+
15+
16+
def rewrite_argmax(node: Node, pass_: AtenToDialectPass) -> DialectNodeSpec:
17+
input_node = cast(Node, node.args[0])
18+
dim = cast(int, node.kwargs["dim"] if "dim" in node.kwargs else node.args[1])
19+
if dim < 0:
20+
dim += len(input_node.meta["val"].shape)
21+
22+
return DialectNodeSpec(
23+
exir_ops.backend.tosa.ARGMAX.default,
24+
(input_node, dim),
25+
{},
26+
)

backends/arm/_passes/exir_to_tosa_pass.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,37 +5,38 @@
55

66
import executorch.backends.arm.tosa.dialect # noqa: F401
77
from executorch.backends.arm._passes.aten_to_tosa_activation_functions import (
8-
rewrite_clamp,
9-
rewrite_erf,
10-
rewrite_sigmoid,
11-
rewrite_tanh,
8+
get_activation_replacement,
9+
)
10+
from executorch.backends.arm._passes.aten_to_tosa_tensor_operators import rewrite_argmax
11+
from executorch.backends.transforms.aten_to_dialect_pass import (
12+
AtenToDialectPass,
13+
DialectNodeSpec,
1214
)
13-
from executorch.backends.transforms.aten_to_dialect_pass import AtenToDialectPass
1415
from executorch.exir.dialects._ops import ops as exir_ops
16+
from torch.fx import Node
1517

1618

1719
class ExirToTosaPass(AtenToDialectPass):
1820
"""Rewrite simple EXIR ops to equivalent backend TOSA dialect ops.
1921
20-
Rewrite functions are grouped by op category and registered with the shared
21-
ATen-to-dialect pass infrastructure.
22+
Rewrite functions are registered with the shared ATen-to-dialect pass
23+
infrastructure.
2224
2325
"""
2426

2527

26-
_ACTIVATION_FUNCTION_REWRITES = {
27-
exir_ops.edge.aten.clamp.default: rewrite_clamp,
28-
exir_ops.edge.aten.erf.default: rewrite_erf,
29-
exir_ops.edge.aten.sigmoid.default: rewrite_sigmoid,
30-
exir_ops.edge.aten.tanh.default: rewrite_tanh,
31-
}
28+
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.argmax.default)
29+
def _get_tensor_operators_replacement(
30+
node: Node, pass_: AtenToDialectPass
31+
) -> DialectNodeSpec:
32+
return rewrite_argmax(node, pass_)
3233

33-
_DIRECT_REWRITE_CATEGORIES = {
34-
"activation_functions": _ACTIVATION_FUNCTION_REWRITES,
35-
}
3634

37-
# Register each category's ATen targets with the function that builds the
38-
# corresponding TOSA dialect node spec.
39-
for _rewrite_category in _DIRECT_REWRITE_CATEGORIES.values():
40-
for _edge_target, _rewrite_fn in _rewrite_category.items():
41-
ExirToTosaPass.register_dialect_substitution(_edge_target)(_rewrite_fn)
35+
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.clamp.default)
36+
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.erf.default)
37+
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.sigmoid.default)
38+
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.tanh.default)
39+
def _get_activation_replacement(
40+
node: Node, pass_: AtenToDialectPass
41+
) -> DialectNodeSpec | None:
42+
return get_activation_replacement(node, pass_)

backends/arm/operator_support/tosa_profile_supported_op_lists.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@
9999
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
100100
exir_ops.edge.aten.pad.default,
101101
exir_ops.edge.aten.constant_pad_nd.default,
102+
exir_ops.edge.aten.argmax.default,
102103
exir_ops.edge.aten.amax.default,
103104
exir_ops.edge.aten.amin.default,
104105
exir_ops.edge.aten.eye.default,
@@ -238,6 +239,7 @@
238239
operator.getitem,
239240
exir_ops.edge.aten.pad.default,
240241
exir_ops.edge.aten.constant_pad_nd.default,
242+
exir_ops.edge.aten.argmax.default,
241243
exir_ops.edge.aten.amax.default,
242244
exir_ops.edge.aten.amin.default,
243245
exir_ops.edge.aten.eye.default,

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 105 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ def _negative_checks(
336336
checks: list[OperatorSupportBase] = [RankCheck(reporter, MAX_RANK)]
337337

338338
if not tosa_spec.support_extension("int64"):
339-
checks.append(CheckInt64InputsAndOutputs(exported_program, reporter))
339+
checks.append(CheckInt64InputsAndOutputs(exported_program, reporter, tosa_spec))
340340

341341
checks.extend(_wrapped_additional_checks(additional_checks, reporter))
342342

@@ -683,7 +683,10 @@ class CheckInt64InputsAndOutputs(OperatorSupportBase):
683683
"""
684684

685685
def __init__(
686-
self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter
686+
self,
687+
exported_program: ExportedProgram,
688+
reporter: WhyNoPartitionReporter,
689+
tosa_spec: TosaSpecification,
687690
):
688691
"""Initialize the check with program context and reporter."""
689692
self.input_names = [
@@ -692,6 +695,7 @@ def __init__(
692695
if spec.kind == InputKind.USER_INPUT
693696
]
694697
self.reporter = reporter
698+
self.tosa_spec = tosa_spec
695699
self.int32_min = torch.iinfo(torch.int32).min
696700
self.int32_max = torch.iinfo(torch.int32).max
697701
super().__init__()
@@ -704,6 +708,104 @@ def inside_int32_bounds(self, node: torch.fx.Node) -> bool:
704708
min_val, max_val = int(torch.min(data)), int(torch.max(data))
705709
return min_val >= self.int32_min and max_val <= self.int32_max
706710

711+
def has_rejected_int64_output(
712+
self, node: torch.fx.Node, tensor_list: Sequence[typing.Any]
713+
) -> bool:
714+
if node.target in (
715+
torch.ops.aten.argmax.default,
716+
exir_ops.edge.aten.argmax.default,
717+
):
718+
return not self._is_tosa_argmax_supported(node)
719+
return any(
720+
tensor.dtype == torch.int64
721+
for tensor in tensor_list
722+
if isinstance(tensor, FakeTensor)
723+
)
724+
725+
def _is_tosa_argmax_dtype_supported(
726+
self, node: torch.fx.Node, input_dtype: torch.dtype
727+
) -> bool:
728+
if input_dtype == torch.int8:
729+
if not self.tosa_spec.support_integer():
730+
self.reporter.report_reject(
731+
node, "TOSA ARGMAX requires PRO-INT for int8 input."
732+
)
733+
return False
734+
elif input_dtype == torch.int16:
735+
if not (
736+
self.tosa_spec.support_integer()
737+
and self.tosa_spec.support_extension("int16")
738+
):
739+
self.reporter.report_reject(
740+
node, "TOSA ARGMAX requires EXT-INT16 for int16 input."
741+
)
742+
return False
743+
elif input_dtype in (torch.float16, torch.float32):
744+
if not self.tosa_spec.support_float():
745+
self.reporter.report_reject(
746+
node, f"TOSA ARGMAX requires PRO-FP for {input_dtype} input."
747+
)
748+
return False
749+
elif input_dtype == torch.bfloat16:
750+
if not (
751+
self.tosa_spec.support_float()
752+
and self.tosa_spec.support_extension("bf16")
753+
):
754+
self.reporter.report_reject(
755+
node, "TOSA ARGMAX requires EXT-BF16 for bfloat16 input."
756+
)
757+
return False
758+
else:
759+
self.reporter.report_reject(
760+
node, f"TOSA ARGMAX does not support {input_dtype} input."
761+
)
762+
return False
763+
return True
764+
765+
def _is_tosa_argmax_supported(self, node: torch.fx.Node) -> bool:
766+
dim = node.kwargs.get("dim", node.args[1] if len(node.args) > 1 else None)
767+
if dim is None:
768+
self.reporter.report_reject(
769+
node, "TOSA ARGMAX requires an explicit reduction dimension."
770+
)
771+
return False
772+
if not isinstance(dim, int):
773+
self.reporter.report_reject(
774+
node, "TOSA ARGMAX requires a statically known reduction dimension."
775+
)
776+
return False
777+
778+
input_node = typing.cast(torch.fx.Node, node.args[0])
779+
input_tensor = get_first_fake_tensor(input_node)
780+
if not self._is_tosa_argmax_dtype_supported(node, input_tensor.dtype):
781+
return False
782+
783+
input_rank = len(input_tensor.shape)
784+
if input_rank == 0:
785+
self.reporter.report_reject(
786+
node, "TOSA ARGMAX requires an input with rank at least 1."
787+
)
788+
return False
789+
790+
axis = dim + input_rank if dim < 0 else dim
791+
if axis < 0 or axis >= input_rank:
792+
self.reporter.report_reject(
793+
node,
794+
f"TOSA ARGMAX axis must be in [0, {input_rank - 1}] but got {dim}.",
795+
)
796+
return False
797+
798+
keepdim = node.kwargs.get(
799+
"keepdim", node.args[2] if len(node.args) > 2 else False
800+
)
801+
if keepdim:
802+
self.reporter.report_reject(
803+
node, "TOSA ARGMAX does not support keepdim=True."
804+
)
805+
return False
806+
807+
return True
808+
707809
def _check_int64_input_nodes(self, node: torch.fx.Node) -> bool:
708810
"""Check if all int64 input nodes are constant and will be
709811
partitioned.
@@ -747,11 +849,7 @@ def is_node_supported(
747849
vals = node.meta["val"]
748850
tensor_list = vals if isinstance(vals, (list, tuple)) else [vals]
749851

750-
any_int64 = any(
751-
tensor.dtype == torch.int64
752-
for tensor in tensor_list
753-
if isinstance(tensor, FakeTensor)
754-
)
852+
any_int64 = self.has_rejected_int64_output(node, tensor_list)
755853
# Don't partition nodes with int64 output...
756854
if any_int64:
757855
# ... Except for constant ops that are directly cast to something non-int64.

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
op_sub,
4343
op_sum,
4444
op_to_dim_order_copy,
45+
op_tosa_argmax,
4546
op_tosa_avg_pool2d,
4647
op_tosa_avg_pool2d_adaptive,
4748
op_tosa_cast_to_block_scaled,
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Any, List
7+
8+
import torch.fx
9+
import tosa_serializer as ts
10+
11+
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
12+
from executorch.backends.arm.operators.node_visitor import (
13+
NodeVisitor,
14+
register_node_visitor,
15+
)
16+
from executorch.backends.arm.operators.operator_validation_utils import (
17+
validate_num_inputs,
18+
validate_valid_dtype,
19+
)
20+
from executorch.backends.arm.tosa.mapping import TosaArg
21+
22+
23+
@register_node_visitor
24+
class ArgMaxVisitor(NodeVisitor):
25+
target = "tosa.ARGMAX.default"
26+
27+
def define_node(
28+
self,
29+
node: torch.fx.Node,
30+
tosa_graph: Any,
31+
inputs: List[TosaArg],
32+
output: TosaArg,
33+
) -> None:
34+
validate_num_inputs(self.target, inputs, 2)
35+
validate_valid_dtype(
36+
self.target,
37+
inputs[0],
38+
[
39+
ts.DType.INT8,
40+
ts.DType.INT16,
41+
ts.DType.FP16,
42+
ts.DType.FP32,
43+
ts.DType.BF16,
44+
],
45+
self.tosa_spec,
46+
)
47+
validate_valid_dtype(self.target, output, ts.DType.INT32, self.tosa_spec)
48+
49+
axis = inputs[1].number
50+
if axis < 0:
51+
tensor = get_first_fake_tensor(node)
52+
axis += len(tensor.size())
53+
54+
attr = ts.TosaSerializerAttribute()
55+
attr.ArgMaxAttribute(axis, ts.NanPropagationMode.PROPAGATE)
56+
self._serialize_operator(
57+
node,
58+
tosa_graph,
59+
ts.Op.ARGMAX,
60+
[inputs[0].name],
61+
[output.name],
62+
attr,
63+
)

0 commit comments

Comments
 (0)