Skip to content

Commit b871398

Browse files
Arm backend: Add Amax-support for Ethos-U55 (#17372)
Additionally start deprecation of the unstable softmax decomp for u55 to be replaced by the now fully supported stable one. cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
1 parent ccb13dd commit b871398

4 files changed

Lines changed: 50 additions & 6 deletions

File tree

backends/arm/common/pipeline_config.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -8,6 +8,8 @@
88
from enum import auto, Enum
99
from typing import Any
1010

11+
from executorch.exir._warnings import deprecated
12+
1113

1214
class SoftmaxDecompositionConfig(Enum):
1315
MASKED = auto()
@@ -24,7 +26,16 @@ class ArmPassPipelineConfig:
2426
softmax: SoftmaxDecompositionConfig = SoftmaxDecompositionConfig.MASKED
2527
fuse_duplicate_users: FuseDuplicateUsersConfig = FuseDuplicateUsersConfig.ENABLED
2628

29+
@deprecated(
30+
"The stable softmax decomposition is now supported by all arm targets and will be made default in a future release. Overwrite the default config using `compile_spec.set_pass_pipeline_config(ArmPassPipelineConfig())` to use the stable algorithm and avoid this error."
31+
)
2732
def disable_masked_softmax(self) -> None:
33+
"""
34+
.. warning::
35+
36+
The stable softmax decomposition is now supported by all arm targets and will be made default in a future release. Overwrite the default config using `compile_spec.set_pass_pipeline_config(ArmPassPipelineConfig())` to use the stable algorithm and avoid this error."
37+
"""
38+
2839
self.softmax = SoftmaxDecompositionConfig.UNSTABLE
2940

3041
def disable_fuse_duplicate_users(self) -> None:

backends/arm/operator_support/ethos_u55_support.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def __init__(self, reporter: WhyNoPartitionReporter):
8686
exir_ops.edge.aten.permute_copy.default,
8787
]
8888

89-
target_ops_i8 = tuple(TableOps.included_ops())
89+
target_ops_i8_i16 = (*TableOps.included_ops(), exir_ops.edge.aten.amax.default)
9090

9191
def is_node_supported( # noqa: C901
9292
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
@@ -117,7 +117,7 @@ def is_node_supported( # noqa: C901
117117
)
118118
return False
119119

120-
if node.target in self.target_ops_i8:
120+
if node.target in self.target_ops_i8_i16:
121121
if dtype not in (torch.int8, torch.int16):
122122
self.reporter.report_reject(
123123
node, f"Unsupported dtype {dtype} (Supports i8, i16)."
@@ -187,7 +187,6 @@ class EthosU55NotSupported(OperatorSupportBase):
187187
exir_ops.edge.aten.logical_or.default,
188188
exir_ops.edge.aten.logical_xor.default,
189189
exir_ops.edge.aten.logical_not.default,
190-
exir_ops.edge.aten.amax.default, # REDUCE_MAX
191190
exir_ops.edge.aten.amin.default, # REDUCE_MIN
192191
exir_ops.edge.aten.conv3d.default, # CONV3D
193192
exir_ops.edge.aten.conv3d.padding, # CONV3D (deprecated alias)

backends/arm/test/ops/test_amax.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import torch
1111
from executorch.backends.arm.test import common
1212
from executorch.backends.arm.test.tester.test_pipeline import (
13+
EthosU55PipelineINT,
1314
EthosU85PipelineINT,
1415
OpNotSupportedPipeline,
1516
TosaPipelineFP,
@@ -136,8 +137,19 @@ def test_amax_tosa_INT(test_data: Amax.input_t):
136137
pipeline.run()
137138

138139

140+
@common.parametrize("test_data", Amax.test_data)
141+
def test_amax_u55_INT(test_data: Amax.input_t):
142+
data, dim, keep_dims = test_data()
143+
pipeline = EthosU55PipelineINT[Amax.input_t](
144+
Amax(dim, keep_dims),
145+
data,
146+
Amax.aten_op,
147+
)
148+
pipeline.run()
149+
150+
139151
def test_amax_u55_INT_not_delegated():
140-
data, dim, keep_dims = Amax.test_data["rank_4_all_dim"]()
152+
data, dim, keep_dims = ((torch.ones([2, 2], dtype=torch.int32),), 1, False)
141153
pipeline = OpNotSupportedPipeline[Amax.input_t](
142154
Amax(dim, keep_dims),
143155
data,

backends/arm/test/ops/test_softmax.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3-
# Copyright 2024-2025 Arm Limited and/or its affiliates.
3+
# Copyright 2024-2026 Arm Limited and/or its affiliates.
44
#
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

88
from typing import Tuple
99

1010
import torch
11+
from executorch.backends.arm.common.pipeline_config import ArmPassPipelineConfig
1112
from executorch.backends.arm.test import common
1213
from executorch.backends.arm.test.tester.test_pipeline import (
1314
EthosU55PipelineINT,
@@ -70,7 +71,28 @@ def test_softmax_u55_INT(test_data):
7071
data,
7172
[],
7273
)
74+
pipeline.add_stage_after(
75+
"quantize", pipeline.tester.check_not, [aten_op, "torch.ops.aten.amax.default"]
76+
)
77+
pipeline.change_args("run_method_and_compare_outputs", qtol=1)
78+
pipeline.run()
79+
80+
81+
@common.parametrize("test_data", Softmax.test_data)
82+
@common.XfailIfNoCorstone300
83+
def test_softmax_u55_INT_stable(test_data):
84+
data, dim = test_data()
85+
pipeline = EthosU55PipelineINT[input_t1](
86+
Softmax(dim),
87+
data,
88+
[],
89+
)
90+
# Override ArmPassPipelineConfig to disable the DecomposeSoftmaxUnstablePass
91+
pipeline.tester.compile_spec.set_pass_pipeline_config(ArmPassPipelineConfig())
7392
pipeline.add_stage_after("quantize", pipeline.tester.check_not, [aten_op])
93+
pipeline.add_stage_after(
94+
"quantize", pipeline.tester.check, ["torch.ops.aten.amax.default"]
95+
)
7496
pipeline.change_args("run_method_and_compare_outputs", qtol=1)
7597
pipeline.run()
7698

0 commit comments

Comments
 (0)