Skip to content

Commit fa81941

Browse files
Arm backend: Support quantization of torch.einsum (#18539)
- Refactor GetDecompositionPass to expose small override hooks for trace inputs, placeholder mapping, and output extraction.
- Implement DecomposeEinsumPass as a GetDecompositionPass subclass to handle aten.einsum.default's (equation, [operands]) signature.

Change-Id: I0546572f44903e400dfd75ea067ed2785e2a326e

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell

Signed-off-by: Yufeng Shi <yufeng.shi@arm.com>
1 parent 3d63ad8 commit fa81941

6 files changed

Lines changed: 176 additions & 12 deletions

File tree

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from .decompose_cumsum_pass import DecomposeCumsumPass # noqa
4444
from .decompose_div_pass import DecomposeDivPass # noqa
4545
from .decompose_div_tensor_mode import DecomposeDivTensorModePass # noqa
46+
from .decompose_einsum_pass import DecomposeEinsumPass # noqa
4647
from .decompose_elu_pass import DecomposeEluPass # noqa
4748
from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa
4849
from .decompose_erfinv_pass import DecomposeErfinvPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
DecomposeCumsumPass,
5252
DecomposeDivPass,
5353
DecomposeDivTensorModePass,
54+
DecomposeEinsumPass,
5455
DecomposeEluPass,
5556
DecomposeEmbeddingPass,
5657
DecomposeErfinvPass,
@@ -560,6 +561,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
560561
DecomposeFloorDividePass(tfa_pass=True),
561562
DecomposeDivTensorModePass(tfa_pass=True),
562563
DecomposeWhereScalarOtherPass(tfa_pass=True),
564+
DecomposeEinsumPass(tfa_pass=True),
563565
RewriteInplaceArithmeticPass(tfa_pass=True),
564566
DecomposeAddSubAlphaPass(tfa_pass=True),
565567
DecomposeLeakyReLUPass(tfa_pass=True),
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.backends.arm._passes.get_decomposition_pass import GetDecompositionPass
8+
9+
10+
class DecomposeEinsumPass(GetDecompositionPass):
    """Decompose ``aten.einsum.default`` into more primitive ops.

    Intended to run in transform_for_annotation so the graph is ready for
    quantization: the Arm quantizer does not annotate einsum directly, but
    it does annotate the ops einsum decomposes into.
    """

    targeted_ops = [torch.ops.aten.einsum.default]

    def _get_input_tensors(self, node: torch.fx.Node) -> list:
        """Override the base hook: ``aten.einsum.default`` takes
        ``(equation, [operands])``, which the generic one-arg-per-input
        logic cannot handle.
        """
        equation, operands = node.args  # type: ignore[union-attr]
        fake_operands = [op.meta["val"] for op in operands]  # type: ignore[union-attr]
        return [equation, fake_operands]

    def _get_placeholder_map(
        self,
        node: torch.fx.Node,
        decomposed_module: torch.fx.GraphModule,
    ) -> dict[str, torch.fx.Node]:
        """Override the base hook: einsum placeholders are not one-to-one
        with ``node.args``.

        The traced graph contains ``arg0_1`` for the equation string plus
        ``arg1_i`` for each tensor in the operand list, so the equation
        placeholder (which has no original FX tensor node) is skipped and
        every operand placeholder is mapped back to its original FX node.
        """
        _, operands = node.args
        mapping: dict[str, torch.fx.Node] = {}

        placeholders = (
            n for n in decomposed_module.graph.nodes if n.op == "placeholder"
        )
        for placeholder in placeholders:
            name = placeholder.name
            if name == "arg0_1":
                # Equation-string placeholder: nothing to map back to.
                continue
            if not name.startswith("arg1_"):
                raise RuntimeError(
                    f"Unexpected einsum placeholder name {name!r}."
                )

            # arg1_<i> is 1-based; operand list indexing is 0-based.
            operand_index = int(name.split("_")[1]) - 1
            mapping[name] = operands[operand_index]  # type: ignore[index]

        return mapping  # type: ignore[return-value]

    def _get_output_node(self, output_node: torch.fx.Node) -> torch.fx.Node:
        """Return the traced value node for einsum graphs, whose output is
        emitted as ``output([node])``.
        """
        return output_node.args[0][0]  # type: ignore[index, return-value]

backends/arm/_passes/get_decomposition_pass.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,31 @@ def __init__(self, tfa_pass=False, *args, **kwargs):
3434
def _skip_pass(self, input_tensors: list) -> bool:
3535
return False
3636

37+
def _get_input_tensors(self, node: torch.fx.Node) -> list:
38+
input_tensors = []
39+
for arg in node.args:
40+
if hasattr(arg, "meta"):
41+
input_tensors.append(arg.meta["val"]) # type: ignore[union-attr]
42+
elif isinstance(arg, int):
43+
input_tensors.append(arg)
44+
return input_tensors
45+
46+
def _get_placeholder_map(
47+
self,
48+
node: torch.fx.Node,
49+
decomposed_module: torch.fx.GraphModule,
50+
) -> dict[str, torch.fx.Node]:
51+
# Keep decomposed_module in the hook signature so subclasses can inspect
52+
# traced placeholder structure when the mapping is not one-to-one.
53+
name_to_input_tensor_map = {}
54+
for i, arg in enumerate(node.args):
55+
name_to_input_tensor_map[f"arg{i}_1"] = arg
56+
return name_to_input_tensor_map # type: ignore[return-value]
57+
58+
def _get_output_node(self, output_node: torch.fx.Node) -> torch.fx.Node:
59+
"""Return the traced value node for graphs that emit output(node)."""
60+
return output_node.args[0] # type: ignore[return-value]
61+
3762
def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
3863
modified = False
3964
for node in graph_module.graph.nodes:
@@ -44,13 +69,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
4469
):
4570
continue
4671

47-
input_tensors = []
48-
for arg in node.args:
49-
if hasattr(arg, "meta"):
50-
input_tensors.append(arg.meta["val"])
51-
52-
elif isinstance(arg, int):
53-
input_tensors.append(arg)
72+
input_tensors = self._get_input_tensors(node)
5473

5574
if self._skip_pass(input_tensors):
5675
continue
@@ -70,22 +89,26 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
7089
)(*input_tensors)
7190

7291
with graph_module.graph.inserting_before(node):
73-
name_to_input_tensor_map = {}
74-
for i, arg in enumerate(node.args):
75-
name_to_input_tensor_map[f"arg{i}_1"] = arg
92+
name_to_input_tensor_map = self._get_placeholder_map(
93+
node, decomposed_module
94+
)
7695

7796
decomposed_node_to_subgraph_node = {}
7897
last_decomposed_node = None
7998
# Create a mapping from input nodes in decomposed module to original nodes.
8099
# In decomposed module, there are only input tensors for placeholder op.
81100
for decomposed_node in decomposed_module.graph.nodes:
82101
if decomposed_node.op == "placeholder":
102+
# Some ops, such as einsum, trace extra placeholders that do
103+
# not map back to original graph tensor inputs.
104+
if decomposed_node.name not in name_to_input_tensor_map:
105+
continue
83106
decomposed_node_to_subgraph_node[decomposed_node] = (
84107
name_to_input_tensor_map[decomposed_node.name]
85108
)
86109

87110
if decomposed_node.op == "output":
88-
last_decomposed_node = decomposed_node.args[0]
111+
last_decomposed_node = self._get_output_node(decomposed_node)
89112

90113
# Copy node from decompose graph module
91114
for decomposed_node in decomposed_module.graph.nodes:

backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ class TestSD3Transformer2DModel:
4242
ops_after_partitioner_INT = {
4343
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
4444
"torch.ops.higher_order.executorch_call_delegate": 2,
45-
"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
4645
}
4746

4847
ops_after_partitioner_vgf_quantize = {
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
from executorch.backends.arm._passes import DecomposeEinsumPass
10+
from executorch.backends.arm.quantizer import (
11+
get_symmetric_quantization_config,
12+
TOSAQuantizer,
13+
)
14+
from executorch.backends.arm.test.tester.test_pipeline import QuantizationPipeline
15+
from executorch.backends.arm.tosa import TosaSpecification
16+
from torch.export import export
17+
18+
19+
input_t = Tuple[torch.Tensor]
20+
21+
22+
class EinsumPermuteModule(torch.nn.Module):
    """Single-op module: an einsum that is a pure axis permutation."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pure-permutation einsum: reorders axes nhwpqc -> nchpwq.
        permuted = torch.einsum("nhwpqc->nchpwq", x)
        return permuted

    @staticmethod
    def get_inputs() -> input_t:
        # One random operand shaped (n, h, w, p, q, c).
        sample = torch.randn(2, 4, 16, 1, 16, 1)
        return (sample,)
30+
31+
def _get_int8_quantizer() -> TOSAQuantizer:
    """Build a TOSA INT-profile quantizer with the symmetric int8 config
    applied globally."""
    spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
    int8_quantizer = TOSAQuantizer(spec)
    int8_quantizer.set_global(get_symmetric_quantization_config())
    return int8_quantizer
36+
37+
def test_decompose_einsum_no_target_rewrites_export_graph() -> None:
    """The exported einsum graph is rewritten into a single permute op."""

    def _call_function_targets(graph_module) -> list:
        # Stringified call_function targets, in graph order.
        return [
            str(n.target)
            for n in graph_module.graph.nodes
            if n.op == "call_function"
        ]

    module = EinsumPermuteModule().eval()
    exported_program = export(module, module.get_inputs())

    assert _call_function_targets(exported_program.graph_module) == [
        "aten.einsum.default"
    ]

    pass_result = DecomposeEinsumPass().call(exported_program.graph_module)
    rewritten_targets = _call_function_targets(pass_result.graph_module)

    assert "aten.einsum.default" not in rewritten_targets
    assert rewritten_targets == ["aten.permute.default"]
58+
59+
def test_decompose_einsum_tosa_INT_quantizes_after_transform_for_annotation() -> None:
    """Einsum is decomposed during transform_for_annotation and the
    resulting permute op receives a quantization annotation."""
    module = EinsumPermuteModule().eval()

    output_qspec = get_symmetric_quantization_config().output_activation
    expected_annotations = {"aten.permute.default": {output_qspec: 1}}

    pipeline = QuantizationPipeline[input_t](
        module,
        module.get_inputs(),
        quantizer=_get_int8_quantizer(),
        qspecs=expected_annotations,  # type: ignore[arg-type]
    )
    pipeline.run()

0 commit comments

Comments
 (0)