Skip to content

Commit ee61747

Browse files
Arm backend: Improve ELU support (pytorch#19694)
Arm backend: Improve ELU support - Adds support for different scale and input_scale values - Adds support for related SELU and CELU operators, corresponding to ELU with particular scales. - Use initial float values for alpha, scale and input_scale rather than rounded values. --------- Signed-off-by: Adrian Lundell <adrian.lundell@arm.com> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
1 parent ea412d8 commit ee61747

11 files changed

Lines changed: 326 additions & 100 deletions

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
from .decompose_div_pass import DecomposeDivPass # noqa
4444
from .decompose_div_tensor_mode import DecomposeDivTensorModePass # noqa
4545
from .decompose_einsum_pass import DecomposeEinsumPass # noqa
46-
from .decompose_elu_pass import DecomposeEluPass # noqa
46+
from .decompose_elu_pass import ConvertEluFamilyToEluPass, DecomposeEluPass # noqa
4747
from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa
4848
from .decompose_erfinv_pass import DecomposeErfinvPass # noqa
4949
from .decompose_expm1_pass import DecomposeExpm1Pass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
ConstantFoldingPass,
2121
ControlFlowConstInlinePass,
2222
Conv1dUnsqueezePass,
23+
ConvertEluFamilyToEluPass,
2324
ConvertELUParamsPass,
2425
ConvertExpandCopyToRepeatPass,
2526
ConvertFullLikeToFullPass,
@@ -403,6 +404,7 @@ def _tosa_pipeline(
403404
DecomposeLayerNormPass(),
404405
DecomposeVarPass(),
405406
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec),
407+
ConvertEluFamilyToEluPass(),
406408
ConvertELUParamsPass(),
407409
ControlFlowConstInlinePass(),
408410
NormalizeWhileInitialArgsPass(use_exir_clone=True),
@@ -607,6 +609,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
607609
RewriteInplaceArithmeticPass(tfa_pass=True),
608610
DecomposeAddSubAlphaPass(tfa_pass=True),
609611
DecomposeLeakyReLUPass(tfa_pass=True),
612+
ConvertEluFamilyToEluPass(tfa_pass=True),
610613
DecomposeGroupNormPass(tfa_pass=True),
611614
DecomposeLayerNormPass(tfa_pass=True),
612615
DecomposeVarPass(tfa_pass=True),

backends/arm/_passes/convert_elu_params.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,20 @@
1414

1515

1616
class ConvertELUParamsPass(ArmPass):
17-
"""Pass to convert the input_scale kwarg of ELU operator from float to int.
17+
"""The int8 ELU operator crashes when the alpha, scale or input scale
18+
parameters are not integers.
1819
19-
It has been set to 2 as the outputs seem to stay the same regardless of what
20-
the value of input_scale is, as long as that value is not 1.
20+
This pass temporarily converts quantized ELU parameters to int and stores
21+
the original float values in the meta dict to be able to recover them later.
2122
2223
"""
2324

24-
_passes_required_after: Set[Type[ExportPass]] = set()
25+
@property
26+
def _passes_required_after(self) -> Set[Type[ExportPass]]:
27+
# Lazy import to avoid circular dependency between passes
28+
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
29+
30+
return {InsertTableOpsPass}
2531

2632
def call(self, graph_module: torch.fx.GraphModule):
2733
modified_graph = False
@@ -36,29 +42,45 @@ def call(self, graph_module: torch.fx.GraphModule):
3642
)
3743
if not is_quantized or not self.allowed_to_transform(node.meta):
3844
continue
45+
3946
with graph.inserting_after(node):
4047
replace_node = create_node(
4148
graph, exir_ops.edge.aten.elu.default, from_node=node
4249
)
43-
old_args = list(node.args)
4450

45-
alpha = old_args[1] if len(old_args) > 1 else 1.0
46-
scale = 1.0
47-
input_scale = 2.0
51+
old_args = list(node.args)
52+
alpha = (
53+
old_args[1] if len(old_args) > 1 else node.kwargs.get("alpha", 1.0)
54+
)
55+
scale = (
56+
old_args[2] if len(old_args) > 2 else node.kwargs.get("scale", 1.0)
57+
)
58+
input_scale = (
59+
old_args[3]
60+
if len(old_args) > 3
61+
else node.kwargs.get("input_scale", 1.0)
62+
)
4863

4964
replace_node.args = (old_args[0],)
5065

66+
# Set placeholder int values
5167
updated_kwargs = dict(node.kwargs)
52-
updated_kwargs["alpha"] = int(alpha)
53-
updated_kwargs["scale"] = int(scale)
54-
updated_kwargs["input_scale"] = int(input_scale)
55-
68+
updated_kwargs["alpha"] = 1
69+
updated_kwargs["scale"] = 1
70+
updated_kwargs["input_scale"] = (
71+
2 # Keep input_scale away from 1 to avoid fake execution type checks.
72+
)
5673
replace_node.kwargs = updated_kwargs
5774

75+
# Save correct parameters
76+
replace_node.meta["float_alpha"] = alpha
77+
replace_node.meta["float_scale"] = scale
78+
replace_node.meta["float_input_scale"] = input_scale
79+
5880
node.replace_all_uses_with(replace_node)
5981
graph.erase_node(node)
60-
6182
modified_graph = True
83+
6284
if modified_graph:
6385
graph_module.recompile()
6486
graph_module = super().call(graph_module).graph_module

backends/arm/_passes/decompose_elu_pass.py

Lines changed: 89 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,22 @@
55

66
from typing import Set, Type
77

8+
import torch
89
from executorch.backends.arm._passes import ArmPass
910
from executorch.exir.dialects._ops import ops as exir_ops
1011
from executorch.exir.pass_base import ExportPass
1112

1213
edge_elu_ops = (exir_ops.edge.aten.elu.default,)
14+
edge_selu_ops = (exir_ops.edge.aten.selu.default,)
15+
edge_celu_ops = (exir_ops.edge.aten.celu.default,)
16+
edge_elu_family_ops = edge_elu_ops + edge_selu_ops + edge_celu_ops
17+
torch_selu_ops = (torch.ops.aten.selu.default,)
18+
torch_celu_ops = (torch.ops.aten.celu.default,)
19+
selu_ops = edge_selu_ops + torch_selu_ops
20+
celu_ops = edge_celu_ops + torch_celu_ops
21+
22+
SELU_ALPHA = 1.6732632423543772
23+
SELU_SCALE = 1.0507009873554805
1324

1425

1526
def get_elu_decomposition(op) -> tuple:
@@ -29,7 +40,7 @@ def get_elu_decomposition(op) -> tuple:
2940
3041
"""
3142

32-
if op in edge_elu_ops:
43+
if op in edge_elu_family_ops:
3344
return (
3445
exir_ops.edge.aten.expm1.default,
3546
exir_ops.edge.aten.ge.Scalar,
@@ -40,15 +51,64 @@ def get_elu_decomposition(op) -> tuple:
4051
raise RuntimeError(f"Can't get elu decomposition for op {op}")
4152

4253

54+
def _get_elu_parameter(args, kwargs, index, name):
55+
if len(args) > index:
56+
return args[index]
57+
58+
return kwargs.get(name, 1.0)
59+
60+
61+
def _get_elu_parameters(op, args, kwargs):
62+
if op in selu_ops:
63+
return SELU_ALPHA, SELU_SCALE, 1.0
64+
if op in celu_ops:
65+
alpha = _get_elu_parameter(args, kwargs, 1, "alpha")
66+
return alpha, 1.0, 1.0 / alpha
67+
68+
alpha = _get_elu_parameter(args, kwargs, 1, "alpha")
69+
scale = _get_elu_parameter(args, kwargs, 2, "scale")
70+
input_scale = _get_elu_parameter(args, kwargs, 3, "input_scale")
71+
return alpha, scale, input_scale
72+
73+
74+
class ConvertEluFamilyToEluPass(ArmPass):
75+
"""Convert SELU/CELU ops to equivalent parameterized ELU ops."""
76+
77+
_passes_required_after: Set[Type[ExportPass]] = set()
78+
79+
def call_operator(self, op, args, kwargs, meta):
80+
if op not in selu_ops + celu_ops or not self.allowed_to_transform(meta):
81+
return super().call_operator(op, args, kwargs, meta, updated=False)
82+
83+
input_ = args[0]
84+
alpha, scale, input_scale = _get_elu_parameters(op, args, kwargs)
85+
elu_op = (
86+
torch.ops.aten.elu.default
87+
if op in torch_selu_ops + torch_celu_ops
88+
else exir_ops.edge.aten.elu.default
89+
)
90+
return super().call_operator(
91+
elu_op,
92+
(input_, alpha, scale, input_scale),
93+
{},
94+
meta,
95+
updated=True,
96+
)
97+
98+
4399
class DecomposeEluPass(ArmPass):
44100
"""A transformation pass that decomposes unsupported 'aten.elu' operations
45101
into a combination of supported TOSA-equivalent operations.
46102
47103
Since TOSA does not provide a native ELU operator, this pass rewrites:
48-
elu(x) → where(greater_or_eq(x, 0), (alpha*(exp(x)-1)), x)
104+
elu(x) → scale * where(
105+
greater_or_eq(x, 0), x, alpha * expm1(input_scale * x)
106+
)
49107
50108
Supported input ops:
51-
- exir_ops.edge.aten.elu.Tensor(x)
109+
- exir_ops.edge.aten.elu.default
110+
- exir_ops.edge.aten.selu.default
111+
- exir_ops.edge.aten.celu.default
52112
53113
These are replaced with:
54114
- exir_ops.edge.aten.expm1.default
@@ -61,7 +121,7 @@ class DecomposeEluPass(ArmPass):
61121
_passes_required_after: Set[Type[ExportPass]] = set()
62122

63123
def call_operator(self, op, args, kwargs, meta):
64-
if op not in edge_elu_ops:
124+
if op not in edge_elu_family_ops:
65125
return super().call_operator(op, args, kwargs, meta, updated=False)
66126

67127
if self._is_quantized_meta(meta):
@@ -76,11 +136,11 @@ def call_operator(self, op, args, kwargs, meta):
76136
) = get_elu_decomposition(op)
77137

78138
input = args[0]
79-
alpha = args[1] if len(args) > 1 else 1.0
139+
alpha, scale, input_scale = _get_elu_parameters(op, args, kwargs)
80140

81141
if alpha == 0:
82142
relu_op = exir_ops.edge.aten.clamp.default
83-
return super().call_operator(
143+
relu_node = super().call_operator(
84144
relu_op,
85145
(
86146
input,
@@ -90,14 +150,35 @@ def call_operator(self, op, args, kwargs, meta):
90150
meta,
91151
updated=True,
92152
)
153+
if scale == 1:
154+
return relu_node
93155

94-
expm1_node = super().call_operator(expm1_op, (input,), {}, meta, updated=True)
156+
return super().call_operator(
157+
mul_op, (relu_node, scale), {}, meta, updated=True
158+
)
159+
160+
expm1_input = input
161+
if input_scale != 1:
162+
expm1_input = super().call_operator(
163+
mul_op, (input, input_scale), {}, meta, updated=True
164+
)
165+
expm1_node = super().call_operator(
166+
expm1_op, (expm1_input,), {}, meta, updated=True
167+
)
95168
mul_node = super().call_operator(
96169
mul_op, (expm1_node, alpha), {}, meta, updated=True
97170
)
98171
ge_node = super().call_operator(ge_op, (input, 0.0), {}, meta, updated=True)
172+
positive_node = input
173+
if scale != 1:
174+
positive_node = super().call_operator(
175+
mul_op, (input, scale), {}, meta, updated=True
176+
)
177+
mul_node = super().call_operator(
178+
mul_op, (mul_node, scale), {}, meta, updated=True
179+
)
99180
where_node = super().call_operator(
100-
where_op, (ge_node, input, mul_node), {}, meta, updated=True
181+
where_op, (ge_node, positive_node, mul_node), {}, meta, updated=True
101182
)
102183

103184
return where_node

backends/arm/_passes/insert_table_ops.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,11 @@ def __getitem__(self, node: Node):
100100
x, approximate=approximate
101101
).flatten()
102102
case exir_ops.edge.aten.elu.default:
103-
input_alpha = cast(int, node.kwargs["alpha"])
104-
return lambda x: torch.nn.functional.elu(
105-
x, alpha=input_alpha
103+
input_alpha = cast(float, node.meta["float_alpha"])
104+
input_scale = cast(float, node.meta.get("float_input_scale", 1.0))
105+
scale = cast(float, node.meta.get("float_scale", 1.0))
106+
return lambda x: torch.ops.aten.elu.default(
107+
x, input_alpha, scale, input_scale
106108
).flatten()
107109
case exir_ops.edge.aten.remainder.Scalar:
108110
divisor = cast(float | int, node.args[1])

backends/arm/operator_support/tosa_profile_supported_op_lists.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@
121121
exir_ops.edge.aten.cosh.default,
122122
exir_ops.edge.aten.acos.default,
123123
exir_ops.edge.aten.elu.default,
124+
exir_ops.edge.aten.selu.default,
125+
exir_ops.edge.aten.celu.default,
124126
exir_ops.edge.aten.bitwise_not.default,
125127
exir_ops.edge.aten.copy.default,
126128
exir_ops.edge.aten.tan.default,
@@ -244,6 +246,8 @@
244246
exir_ops.edge.aten.logit.default,
245247
exir_ops.edge.aten.acos.default,
246248
exir_ops.edge.aten.elu.default,
249+
exir_ops.edge.aten.selu.default,
250+
exir_ops.edge.aten.celu.default,
247251
exir_ops.edge.aten.copy.default,
248252
exir_ops.edge.aten.floor_divide.default,
249253
exir_ops.edge.aten.tan.default,

backends/arm/quantizer/quantization_annotator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,8 @@ def _match_pattern(
481481
torch.ops.aten.exp.default,
482482
torch.ops.aten.expm1.default,
483483
torch.ops.aten.elu.default,
484+
torch.ops.aten.selu.default,
485+
torch.ops.aten.celu.default,
484486
torch.ops.aten.floor.default,
485487
torch.ops.aten.log.default,
486488
torch.ops.aten.reciprocal.default,

backends/arm/scripts/docgen/ethos-u/ethos-u-getting-started-tutorial.md.in

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ In this tutorial you will learn how to export a simple PyTorch model for the Exe
2020
```{tip}
2121
If you are already familiar with this delegate, you may want to jump directly to the examples:
2222
* [Examples in the ExecuTorch repository](https://github.com/pytorch/executorch/tree/main/examples/arm)
23-
* [A commandline compiler for example models](https://github.com/pytorch/executorch/blob/main/backends/arm/scripts/aot_arm_compiler.py)
23+
* [A commandline compiler for quick tests and example models](https://github.com/pytorch/executorch/blob/main/backends/arm/scripts/aot_arm_compiler.py)
2424
```
2525

2626
This tutorial serves as an introduction to using ExecuTorch to deploy PyTorch models on Arm&reg; Ethos&trade;-U targets. It is based on `ethos_u_minimal_example.ipynb`, provided in Arm’s examples folder.
@@ -69,9 +69,10 @@ The example below shows how to quantize a model consisting of a single addition,
6969
$MINIMAL_EXAMPLE
7070

7171
```{tip}
72-
For a quick start, you can use the script `backends/arm/scripts/aot_arm_compiler.py` to produce the pte.
72+
For a quick test, you can use the script `backends/arm/scripts/aot_arm_compiler.py` to produce the pte.
7373
To produce a pte file equivalent to the one above, run
74-
`python -m backends.arm.scripts.aot_arm_compiler --model_name=add --delegate --quantize --output=ethos_u_minimal_example.pte`
74+
`python -m backends.arm.scripts.aot_arm_compiler --model_name=add --delegate --quantize --output=ethos_u_minimal_example.pte`.
75+
For production use, you should instead use the stable Python API shown above.
7576
```
7677

7778
### Runtime:

backends/arm/scripts/docgen/vgf/vgf-getting-started-tutorial.md.in

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ You may encounter some rough edges and features which may be documented or plann
2626
```{tip}
2727
If you are already familiar with this delegate, you may want to jump directly to the examples:
2828
* [Examples in the ExecuTorch repository](https://github.com/pytorch/executorch/tree/main/examples/arm)
29-
* [A commandline compiler for example models](https://github.com/pytorch/executorch/blob/main/backends/arm/scripts/aot_arm_compiler.py)
29+
* [A commandline compiler for quick tests and example models](https://github.com/pytorch/executorch/blob/main/backends/arm/scripts/aot_arm_compiler.py)
3030
```
3131

3232
This tutorial serves as an introduction to using ExecuTorch to deploy PyTorch models on VGF targets. The tutorial is based on `vgf_minimal_example.ipynb`, provided in Arm's examples folder.
@@ -78,9 +78,10 @@ The example below shows how to quantize a model consisting of a single addition,
7878
$MINIMAL_EXAMPLE
7979

8080
```{tip}
81-
For a quick start, you can use the script `backends/arm/scripts/aot_arm_compiler.py` to produce the pte.
81+
For a quick test, you can use the script `backends/arm/scripts/aot_arm_compiler.py` to produce the pte.
8282
To produce a pte file equivalent to the one above, run
83-
`python -m backends.arm.scripts.aot_arm_compiler --model_name=add --delegate --quantize --output=simple_example.pte --target=vgf`
83+
`python -m backends.arm.scripts.aot_arm_compiler --model_name=add --delegate --quantize --output=simple_example.pte --target=vgf`.
84+
For production use, you should instead use the stable Python API shown above.
8485
```
8586

8687
## Runtime

0 commit comments

Comments
 (0)