Skip to content

Commit e638059

Browse files
committed
Arm backend: Decompose SDPA safe_softmax on U55
Make the U55 annotation and lowering pipelines handle SDPA aten._safe_softmax consistently by decomposing it instead of leaving it in the graph. Previously, transform_for_annotation_pipeline still used skip_safe_softmax for U55, which left aten._safe_softmax in annotated SDPA graphs and prevented delegation.

Add a regression test that verifies U55 SDPA graphs no longer contain aten._safe_softmax after the annotation pipeline runs.

Also warn when aten._safe_softmax is decomposed as regular softmax in the annotation pipeline, since this is only semantics-preserving when no row is fully masked at runtime.

Remove the unstable softmax decomposition path and its remaining references now that the Arm backend uses the stable decomposition path. Update the related pipeline-config and softmax tests accordingly.

Signed-off-by: per.held@arm.com
Change-Id: I7a5147d5492974ead52ea92326352f7f4407bd67
1 parent e109ac8 commit e638059

File tree

9 files changed

+62
-179
lines changed

9 files changed

+62
-179
lines changed

backends/arm/_passes/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@
8585
from .decompose_sinh_pass import DecomposeSinhPass # noqa
8686
from .decompose_slice_scatter_pass import DecomposeSliceScatterPass # noqa
8787
from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa
88-
from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass # noqa
8988
from .decompose_sqrt_pass import DecomposeSqrtPass # noqa
9089
from .decompose_strided_slice_copy_pass import DecomposeStridedSliceCopyPass # noqa
9190
from .decompose_sum_pass import DecomposeSumPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@
8686
DecomposeSinhPass,
8787
DecomposeSliceScatterPass,
8888
DecomposeSoftmaxPass,
89-
DecomposeSoftmaxUnstablePass,
9089
DecomposeSqrtPass,
9190
DecomposeStridedSliceCopyPass,
9291
DecomposeSumPass,
@@ -196,12 +195,8 @@ def configure_skip_passes(
196195

197196
match config.softmax:
198197
case SoftmaxDecompositionConfig.MASKED:
199-
skip_set.add(DecomposeSoftmaxUnstablePass)
200-
case SoftmaxDecompositionConfig.UNSTABLE:
201-
skip_set.add(DecomposeSoftmaxPass)
202-
skip_set.add(DecomposeMaskedFillPass)
198+
pass
203199
case SoftmaxDecompositionConfig.STABLE:
204-
skip_set.add(DecomposeSoftmaxUnstablePass)
205200
skip_set.add(DecomposeMaskedFillPass)
206201

207202
if config.fuse_duplicate_users is FuseDuplicateUsersConfig.DISABLED:
@@ -461,9 +456,7 @@ def _tosa_pipeline(
461456
ConvertMmToBmmPass(),
462457
DecomposeGluPass(),
463458
DecomposeDivPass(),
464-
# _safe_softmax results in a ReduceMax
465-
# which is not currently supported by TOSA in U55
466-
DecomposeSoftmaxPass(skip_safe_softmax=self.tosa_spec.is_U55_subset),
459+
DecomposeSoftmaxPass(),
467460
ConvertMinMaxPass(),
468461
DecomposeAnyPass(),
469462
DecomposeAdaptiveAvgPool2dPass(),
@@ -593,9 +586,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
593586
DecomposeSqrtPass(tfa_pass=True),
594587
DecomposeAdaptiveAvgPool2dPass(tfa_pass=True),
595588
DecomposeAvgPool2dPass(tfa_pass=True),
596-
DecomposeSoftmaxUnstablePass(tfa_pass=True),
597589
DecomposeSoftmaxPass(
598-
skip_safe_softmax=self.tosa_spec.is_U55_subset,
599590
tfa_pass=True,
600591
),
601592
ConvertMinMaxPass(tfa_pass=True),

backends/arm/_passes/decompose_softmax_pass.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import logging
67
from typing import Set, Type
78

89
import torch
@@ -25,6 +26,8 @@
2526
)
2627
log_softmax = (torch.ops.aten.log_softmax.int, exir_ops.edge.aten._log_softmax.default)
2728

29+
logger = logging.getLogger(__name__)
30+
2831

2932
def _get_logsoftmax_ops(op) -> tuple:
3033
"""Returns the (log_op, sub_op, amax_op, expo_op, sum_op, reciprocal_op),
@@ -78,6 +81,7 @@ class DecomposeSoftmaxPass(ArmPass):
7881
def __init__(self, skip_safe_softmax: bool = False, **kwargs):
7982
super().__init__(**kwargs)
8083
self._skip_safe_softmax = skip_safe_softmax
84+
self._warned_safe_softmax = False
8185

8286
def call_operator(self, op, args, kwargs, meta):
8387
if op not in torch_softmax + edge_softmax or not self.allowed_to_transform(
@@ -88,6 +92,18 @@ def call_operator(self, op, args, kwargs, meta):
8892
if self._skip_safe_softmax and op == torch.ops.aten._safe_softmax.default:
8993
return super().call_operator(op, args, kwargs, meta)
9094

95+
if (
96+
self.is_tfa_pass
97+
and op == torch.ops.aten._safe_softmax.default
98+
and not self._warned_safe_softmax
99+
):
100+
logger.warning(
101+
"aten._safe_softmax is being decomposed as regular softmax in "
102+
"the annotation pipeline; this is only semantics-preserving "
103+
"when no row is fully masked at runtime."
104+
)
105+
self._warned_safe_softmax = True
106+
91107
log_op, sub_op, max_op, exp_op, sum_op, reciprocal_op, mul_op = (
92108
_get_logsoftmax_ops(op)
93109
)

backends/arm/_passes/decompose_softmax_unstable_pass.py

Lines changed: 0 additions & 85 deletions
This file was deleted.

backends/arm/common/arm_compile_spec.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616
from dataclasses import dataclass, field
1717
from enum import Enum
1818

19-
from executorch.backends.arm.common.pipeline_config import ArmPassPipelineConfig
19+
from executorch.backends.arm.common.pipeline_config import (
20+
ArmPassPipelineConfig,
21+
SoftmaxDecompositionConfig,
22+
)
2023
from executorch.backends.arm.tosa import TosaSpecification
2124
from executorch.exir._warnings import deprecated
2225

@@ -250,7 +253,10 @@ def set_pass_pipeline_config(self, config: ArmPassPipelineConfig) -> None:
250253
def _create_default_pipeline_config(self) -> ArmPassPipelineConfig:
251254
config = ArmPassPipelineConfig()
252255
if self.tosa_spec.is_U55_subset:
253-
config.disable_masked_softmax()
256+
# Keep U55 on STABLE instead of the generic MASKED default:
257+
# MASKED also enables masked_fill decomposition, which lowers to
258+
# where/full_like and is not a good default fit for U55.
259+
config.softmax = SoftmaxDecompositionConfig.STABLE
254260
return config
255261

256262
def _get_intermediate_path(self) -> str | None:

backends/arm/common/pipeline_config.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,9 @@
88
from enum import auto, Enum
99
from typing import Any
1010

11-
from executorch.exir._warnings import deprecated
12-
1311

1412
class SoftmaxDecompositionConfig(Enum):
1513
MASKED = auto() # Stable softmax + masked fill decomposition
16-
UNSTABLE = auto() # Unstable softmax, no masked fill decomposition
1714
STABLE = auto() # Stable softmax, no masked fill decomposition
1815

1916

@@ -27,18 +24,6 @@ class ArmPassPipelineConfig:
2724
softmax: SoftmaxDecompositionConfig = SoftmaxDecompositionConfig.MASKED
2825
fuse_duplicate_users: FuseDuplicateUsersConfig = FuseDuplicateUsersConfig.ENABLED
2926

30-
@deprecated(
31-
"The stable softmax decomposition is now supported by all arm targets and will be made default in a future release. Overwrite the default config using `compile_spec.set_pass_pipeline_config(ArmPassPipelineConfig())` to use the stable algorithm and avoid this error."
32-
)
33-
def disable_masked_softmax(self) -> None:
34-
"""
35-
.. warning::
36-
37-
The stable softmax decomposition is now supported by all arm targets and will be made default in a future release. Overwrite the default config using `compile_spec.set_pass_pipeline_config(ArmPassPipelineConfig())` to use the stable algorithm and avoid this error."
38-
"""
39-
40-
self.softmax = SoftmaxDecompositionConfig.STABLE
41-
4227
def disable_fuse_duplicate_users(self) -> None:
4328
self.fuse_duplicate_users = FuseDuplicateUsersConfig.DISABLED
4429

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

66
from executorch.backends.arm._passes import (
77
DecomposeMaskedFillPass,
88
DecomposeSoftmaxPass,
9-
DecomposeSoftmaxUnstablePass,
109
FuseDuplicateUsersPass,
1110
)
1211
from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager
@@ -25,7 +24,7 @@ def test_pipeline_config_override_outside_compile_spec_no_target():
2524
default_manager = ArmPassManager(compile_spec)
2625
default_skip_passes = default_manager._skip_pass_types
2726
assert FuseDuplicateUsersPass not in default_skip_passes
28-
assert DecomposeSoftmaxUnstablePass in default_skip_passes
27+
assert DecomposeSoftmaxPass not in default_skip_passes
2928

3029
override_compile_spec = TosaCompileSpec(
3130
TosaSpecification.create_from_string("TOSA-1.00+INT")
@@ -37,10 +36,10 @@ def test_pipeline_config_override_outside_compile_spec_no_target():
3736
skip_passes = override_manager._skip_pass_types
3837

3938
assert FuseDuplicateUsersPass in skip_passes
40-
assert DecomposeSoftmaxUnstablePass in skip_passes
39+
assert DecomposeSoftmaxPass not in skip_passes
4140

4241

43-
def test_softmax_config_masked():
42+
def test_softmax_config_masked_no_target():
4443
"""Test MASKED config: stable softmax, masked fill decomposition enabled."""
4544
compile_spec = TosaCompileSpec(
4645
TosaSpecification.create_from_string("TOSA-1.00+INT")
@@ -50,31 +49,13 @@ def test_softmax_config_masked():
5049
manager = ArmPassManager(compile_spec)
5150
skip_passes = manager._skip_pass_types
5251

53-
# MASKED: skip unstable softmax, use stable softmax
54-
assert DecomposeSoftmaxUnstablePass in skip_passes
52+
# MASKED: use stable softmax
5553
assert DecomposeSoftmaxPass not in skip_passes
5654
# MASKED: masked fill decomposition is enabled (not skipped)
5755
assert DecomposeMaskedFillPass not in skip_passes
5856

5957

60-
def test_softmax_config_unstable():
61-
"""Test UNSTABLE config: unstable softmax, no masked fill decomposition."""
62-
compile_spec = TosaCompileSpec(
63-
TosaSpecification.create_from_string("TOSA-1.00+INT")
64-
)
65-
config = ArmPassPipelineConfig(softmax=SoftmaxDecompositionConfig.UNSTABLE)
66-
compile_spec.set_pass_pipeline_config(config)
67-
manager = ArmPassManager(compile_spec)
68-
skip_passes = manager._skip_pass_types
69-
70-
# UNSTABLE: skip stable softmax, use unstable softmax
71-
assert DecomposeSoftmaxPass in skip_passes
72-
assert DecomposeSoftmaxUnstablePass not in skip_passes
73-
# UNSTABLE: masked fill decomposition is disabled (skipped)
74-
assert DecomposeMaskedFillPass in skip_passes
75-
76-
77-
def test_softmax_config_stable():
58+
def test_softmax_config_stable_no_target():
7859
"""Test STABLE config: stable softmax, no masked fill decomposition."""
7960
compile_spec = TosaCompileSpec(
8061
TosaSpecification.create_from_string("TOSA-1.00+INT")
@@ -84,8 +65,7 @@ def test_softmax_config_stable():
8465
manager = ArmPassManager(compile_spec)
8566
skip_passes = manager._skip_pass_types
8667

87-
# STABLE: skip unstable softmax, use stable softmax
88-
assert DecomposeSoftmaxUnstablePass in skip_passes
68+
# STABLE: use stable softmax
8969
assert DecomposeSoftmaxPass not in skip_passes
9070
# STABLE: masked fill decomposition is disabled (skipped)
9171
assert DecomposeMaskedFillPass in skip_passes

backends/arm/test/ops/test_sdpa.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
import torch
1010

11+
from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager
12+
from executorch.backends.arm.ethosu import EthosUCompileSpec
1113
from executorch.backends.arm.test import common
1214
from executorch.backends.arm.test.tester.test_pipeline import (
1315
EthosU55PipelineINT,
@@ -16,6 +18,7 @@
1618
TosaPipelineINT,
1719
VgfPipeline,
1820
)
21+
from torch.export import export
1922

2023

2124
class SDPA(torch.nn.Module):
@@ -106,13 +109,6 @@ def test_sdpa_vgf_quant(test_case: test_case_t):
106109

107110
@common.parametrize("test_case", test_suite)
108111
def test_sdpa_u55_INT(test_case: test_case_t):
109-
"""Verify SDPA compiles on U55.
110-
111-
_safe_softmax from SDPA is skipped by DecomposeSoftmaxPass
112-
(skip_safe_softmax=True for U55) and runs on CPU, avoiding REDUCE_MAX which
113-
fails Vela compilation.
114-
115-
"""
116112
model, test_input = test_case()
117113
pipeline = EthosU55PipelineINT[input_t](model, test_input, [], [])
118114
pipeline.pop_stage("check.quant_nodes")
@@ -122,14 +118,34 @@ def test_sdpa_u55_INT(test_case: test_case_t):
122118

123119

124120
@common.parametrize("test_case", test_suite)
125-
@common.XfailIfNoCorstone320
126-
def test_sdpa_u85_INT(test_case: test_case_t):
127-
"""Verify SDPA compiles on U85.
121+
def test_sdpa_u55_INT_annotation_pipeline_decomposes_safe_softmax(
122+
test_case: test_case_t,
123+
):
124+
"""Verify the U55 annotation pipeline decomposes SDPA _safe_softmax.
128125
129-
_safe_softmax is decomposed with stable softmax (including amax/REDUCE_MAX)
130-
which is supported on U85.
126+
U55 now matches U85 and VGF: the annotation pipeline lowers
127+
_safe_softmax to the stable softmax primitive sequence instead of leaving
128+
it in the graph for partitioning.
131129
132130
"""
131+
model, test_input = test_case()
132+
exported_program = export(model, test_input)
133+
graph_module = ArmPassManager(
134+
EthosUCompileSpec("ethos-u55-128")
135+
).transform_for_annotation_pipeline(exported_program.graph_module)
136+
137+
softmax_targets = {
138+
str(node.target)
139+
for node in graph_module.graph.nodes
140+
if node.op == "call_function" and "softmax" in str(node.target)
141+
}
142+
143+
assert "aten._safe_softmax.default" not in softmax_targets
144+
145+
146+
@common.parametrize("test_case", test_suite)
147+
@common.XfailIfNoCorstone320
148+
def test_sdpa_u85_INT(test_case: test_case_t):
133149
model, test_input = test_case()
134150
pipeline = EthosU85PipelineINT[input_t](model, test_input, [], [])
135151
pipeline.pop_stage("check.quant_nodes")

0 commit comments

Comments
 (0)