Skip to content

Commit e44d2b1

Browse files
Arm backend: Consolidate simple operator visitors
Signed-off-by: Sebastian Larsson <sebastian.larsson@arm.com> Change-Id: I23339f808f1074adea1fafddf90110c04fc5695f
1 parent 7edb46d commit e44d2b1

23 files changed

Lines changed: 432 additions & 790 deletions

backends/arm/operators/TARGETS

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
33

44
runtime.python_library(
55
name = "node_visitor",
6-
srcs = ["node_visitor.py"],
6+
srcs = [
7+
"node_visitor.py",
8+
"simple_node_visitor.py",
9+
],
710
deps = [
11+
":operator_validation_utils",
812
"//executorch/backends/arm/debug:schema",
913
"//executorch/backends/arm/tosa:mapping",
1014
"//executorch/backends/arm/tosa:tosa",

backends/arm/operators/node_visitor.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,18 @@
1313
import json
1414

1515
import logging
16-
from typing import Any, Dict, List, Optional
16+
from typing import Any, Callable, Dict, List, Optional
1717

1818
import torch
1919
import tosa_serializer as ts
2020

2121
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
2222
from executorch.backends.arm.debug.schema import DebugHook
23+
from executorch.backends.arm.operators.operator_validation_utils import (
24+
validate_num_inputs,
25+
validate_same_dtype,
26+
validate_valid_dtype,
27+
)
2328
from executorch.backends.arm.tosa.mapping import TosaArg
2429
from executorch.backends.arm.tosa.specification import (
2530
TosaSpecification,
@@ -102,6 +107,77 @@ def _serialize_operator(
102107
location=op_location,
103108
)
104109

110+
def validate(
111+
self,
112+
*,
113+
target: str,
114+
inputs: List[TosaArg],
115+
output: TosaArg,
116+
num_inputs: int | List[int],
117+
input_dtypes: List[Any],
118+
output_dtypes: Optional[List[Any]] = None,
119+
same_dtype_with_output: bool = True,
120+
dtype_check_inputs_only: bool = False,
121+
) -> None:
122+
validate_num_inputs(target, inputs, num_inputs)
123+
if same_dtype_with_output:
124+
validate_same_dtype(target, [*inputs, output], ts)
125+
else:
126+
validate_same_dtype(target, inputs, ts)
127+
128+
dtype_check_tensors = inputs if dtype_check_inputs_only else [*inputs, output]
129+
validate_valid_dtype(
130+
target,
131+
dtype_check_tensors,
132+
input_dtypes,
133+
self.tosa_spec,
134+
)
135+
if output_dtypes is not None:
136+
validate_valid_dtype(
137+
target,
138+
output,
139+
output_dtypes,
140+
self.tosa_spec,
141+
)
142+
143+
def serialize(
144+
self,
145+
node: torch.fx.Node,
146+
tosa_graph: Any,
147+
*,
148+
tosa_op: ts.Op,
149+
inputs: List[TosaArg],
150+
output: TosaArg,
151+
attr_method: Optional[str] = None,
152+
attr_kwargs: Optional[dict[str, Any]] = None,
153+
attr_builder: Optional[Callable[[ts.TosaSerializerAttribute], None]] = None,
154+
extra_input_builders: Optional[
155+
List[Callable[[torch.fx.Node, Any, List[TosaArg], TosaArg, Any], str]]
156+
] = None,
157+
) -> None:
158+
attr = ts.TosaSerializerAttribute()
159+
if attr_method is not None:
160+
getattr(attr, attr_method)(**(attr_kwargs or {}))
161+
elif attr_builder is not None:
162+
attr_builder(attr)
163+
else:
164+
raise NotImplementedError(
165+
f"{self.__class__.__name__} must define attr_method or attr_builder."
166+
)
167+
input_names = [arg.name for arg in inputs]
168+
for builder in extra_input_builders or []:
169+
input_names.append(
170+
builder(node, tosa_graph, inputs, output, self.tosa_spec)
171+
)
172+
self._serialize_operator(
173+
node,
174+
tosa_graph,
175+
tosa_op,
176+
input_names,
177+
[output.name],
178+
attr,
179+
)
180+
105181
def define_node(
106182
self,
107183
node: torch.fx.Node,
@@ -151,6 +227,9 @@ def register_node_visitor(visitor):
151227

152228
def get_node_visitors(*args) -> Dict[str, NodeVisitor]:
153229
"""Return a mapping from target names to visitor instances for a spec."""
230+
# Ensure all operator modules are imported so visitors are registered.
231+
import executorch.backends.arm.operators # noqa: F401
232+
154233
node_visitors: Dict[str, NodeVisitor] = {}
155234
tosa_spec: TosaSpecification | None = None
156235
for arg in args:

backends/arm/operators/op_abs.py

Lines changed: 12 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,51 +3,24 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from typing import Any, List
7-
86
import tosa_serializer as ts
97

10-
from executorch.backends.arm.operators.node_visitor import (
11-
NodeVisitor,
12-
register_node_visitor,
13-
)
14-
from executorch.backends.arm.operators.operator_validation_utils import (
15-
validate_num_inputs,
16-
validate_same_dtype,
17-
validate_valid_dtype,
8+
from executorch.backends.arm.operators.node_visitor import register_node_visitor
9+
from executorch.backends.arm.operators.simple_node_visitor import (
10+
SimpleNodeVisitor,
11+
SimpleNodeVisitorConfig,
1812
)
19-
from executorch.backends.arm.tosa.mapping import TosaArg
20-
from torch.fx import Node
2113

2214

2315
@register_node_visitor
24-
class AbsVisitor(NodeVisitor):
16+
class AbsVisitor(SimpleNodeVisitor):
2517
target = "aten.abs.default"
2618

27-
def define_node(
28-
self,
29-
node: Node,
30-
tosa_graph: Any,
31-
inputs: List[TosaArg],
32-
output: TosaArg,
33-
) -> None:
34-
validate_num_inputs(self.target, inputs, 1)
35-
validate_same_dtype(self.target, [*inputs, output], ts)
36-
37-
validate_valid_dtype(
38-
self.target,
39-
[*inputs, output],
40-
[ts.DType.INT32, ts.DType.FP16, ts.DType.FP32, ts.DType.BF16],
41-
self.tosa_spec,
42-
)
43-
44-
attr = ts.TosaSerializerAttribute()
45-
attr.AbsAttribute()
46-
self._serialize_operator(
47-
node,
48-
tosa_graph,
49-
ts.Op.ABS,
50-
[inputs[0].name],
51-
[output.name],
52-
attr,
19+
@classmethod
20+
def get_config(cls) -> SimpleNodeVisitorConfig:
21+
return SimpleNodeVisitorConfig(
22+
tosa_op=ts.Op.ABS,
23+
attr_method="AbsAttribute",
24+
num_inputs=1,
25+
input_dtypes=[ts.DType.INT32, ts.DType.FP16, ts.DType.FP32, ts.DType.BF16],
5326
)

backends/arm/operators/op_add.py

Lines changed: 12 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,52 +3,24 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
7-
from typing import Any, List
8-
96
import tosa_serializer as ts
107

11-
from executorch.backends.arm.operators.node_visitor import (
12-
NodeVisitor,
13-
register_node_visitor,
14-
)
15-
from executorch.backends.arm.operators.operator_validation_utils import (
16-
validate_num_inputs,
17-
validate_same_dtype,
18-
validate_valid_dtype,
8+
from executorch.backends.arm.operators.node_visitor import register_node_visitor
9+
from executorch.backends.arm.operators.simple_node_visitor import (
10+
SimpleNodeVisitor,
11+
SimpleNodeVisitorConfig,
1912
)
20-
from executorch.backends.arm.tosa.mapping import TosaArg
21-
from torch.fx import Node
2213

2314

2415
@register_node_visitor
25-
class AddVisitor(NodeVisitor):
16+
class AddVisitor(SimpleNodeVisitor):
2617
target = "aten.add.Tensor"
2718

28-
def define_node(
29-
self,
30-
node: Node,
31-
tosa_graph: Any,
32-
inputs: List[TosaArg],
33-
output: TosaArg,
34-
) -> None:
35-
validate_num_inputs(self.target, inputs, 2)
36-
validate_same_dtype(self.target, [*inputs, output], ts)
37-
validate_valid_dtype(
38-
self.target,
39-
[*inputs, output],
40-
[ts.DType.INT32, ts.DType.FP16, ts.DType.FP32, ts.DType.BF16],
41-
self.tosa_spec,
42-
)
43-
44-
attr = ts.TosaSerializerAttribute()
45-
attr.AddAttribute()
46-
47-
self._serialize_operator(
48-
node,
49-
tosa_graph,
50-
ts.Op.ADD,
51-
[inputs[0].name, inputs[1].name],
52-
[output.name],
53-
attr,
19+
@classmethod
20+
def get_config(cls) -> SimpleNodeVisitorConfig:
21+
return SimpleNodeVisitorConfig(
22+
tosa_op=ts.Op.ADD,
23+
attr_method="AddAttribute",
24+
num_inputs=2,
25+
input_dtypes=[ts.DType.INT32, ts.DType.FP16, ts.DType.FP32, ts.DType.BF16],
5426
)

backends/arm/operators/op_bitwise_not.py

Lines changed: 17 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -3,58 +3,28 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from typing import Any, List
7-
86
import tosa_serializer as ts
97

10-
from executorch.backends.arm.operators.node_visitor import (
11-
NodeVisitor,
12-
register_node_visitor,
13-
)
14-
from executorch.backends.arm.operators.operator_validation_utils import (
15-
validate_num_inputs,
16-
validate_same_dtype,
17-
validate_valid_dtype,
8+
from executorch.backends.arm.operators.node_visitor import register_node_visitor
9+
from executorch.backends.arm.operators.simple_node_visitor import (
10+
SimpleNodeVisitor,
11+
SimpleNodeVisitorConfig,
1812
)
19-
from executorch.backends.arm.tosa.mapping import TosaArg
20-
from executorch.backends.arm.tosa.specification import TosaSpecification
21-
from torch.fx import Node
13+
from executorch.backends.arm.tosa import TosaSpecification
14+
15+
INT_SPECS = TosaSpecification.all_versions_for_profile("INT")
2216

2317

2418
@register_node_visitor
25-
class BitwiseNotVisitor(NodeVisitor):
19+
class BitwiseNotVisitor(SimpleNodeVisitor):
2620
target = "aten.bitwise_not.default"
27-
28-
# bitwise_not is not supported on the FP profile
29-
tosa_specs = TosaSpecification.all_versions_for_profile("INT")
30-
31-
def __init__(self, *args):
32-
super().__init__(*args)
33-
34-
def define_node(
35-
self,
36-
node: Node,
37-
tosa_graph: Any,
38-
inputs: List[TosaArg],
39-
output: TosaArg,
40-
) -> None:
41-
validate_num_inputs(self.target, inputs, 1)
42-
validate_same_dtype(self.target, [*inputs, output], ts)
43-
validate_valid_dtype(
44-
self.target,
45-
[*inputs, output],
46-
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
47-
self.tosa_spec,
48-
)
49-
50-
attr = ts.TosaSerializerAttribute()
51-
attr.BitwiseNotAttribute()
52-
53-
self._serialize_operator(
54-
node,
55-
tosa_graph,
56-
ts.Op.BITWISE_NOT,
57-
[inputs[0].name],
58-
[output.name],
59-
attr,
21+
tosa_specs = INT_SPECS
22+
23+
@classmethod
24+
def get_config(cls) -> SimpleNodeVisitorConfig:
25+
return SimpleNodeVisitorConfig(
26+
tosa_op=ts.Op.BITWISE_NOT,
27+
attr_method="BitwiseNotAttribute",
28+
num_inputs=1,
29+
input_dtypes=[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
6030
)

backends/arm/operators/op_ceil.py

Lines changed: 15 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,54 +3,28 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from typing import Any, List
7-
8-
import torch.fx
9-
106
import tosa_serializer as ts
117

12-
from executorch.backends.arm.operators.node_visitor import (
13-
NodeVisitor,
14-
register_node_visitor,
15-
)
16-
from executorch.backends.arm.operators.operator_validation_utils import (
17-
validate_num_inputs,
18-
validate_same_dtype,
19-
validate_valid_dtype,
8+
from executorch.backends.arm.operators.node_visitor import register_node_visitor
9+
from executorch.backends.arm.operators.simple_node_visitor import (
10+
SimpleNodeVisitor,
11+
SimpleNodeVisitorConfig,
2012
)
2113
from executorch.backends.arm.tosa import TosaSpecification
2214

23-
from executorch.backends.arm.tosa.mapping import TosaArg
15+
FP_SPECS = TosaSpecification.all_versions_for_profile("FP")
2416

2517

2618
@register_node_visitor
27-
class CeilVisitor(NodeVisitor):
19+
class CeilVisitor(SimpleNodeVisitor):
2820
target = "aten.ceil.default"
29-
30-
# INT case should be handled by op_table
31-
tosa_specs = TosaSpecification.all_versions_for_profile("FP")
32-
33-
def __init__(self, *args):
34-
super().__init__(*args)
35-
36-
def define_node(
37-
self,
38-
node: torch.fx.Node,
39-
tosa_graph: Any,
40-
inputs: List[TosaArg],
41-
output: TosaArg,
42-
) -> None:
43-
validate_num_inputs(self.target, inputs, 1)
44-
validate_same_dtype(self.target, [*inputs, output], ts)
45-
validate_valid_dtype(
46-
self.target,
47-
inputs[0],
48-
[ts.DType.FP16, ts.DType.FP32, ts.DType.BF16],
49-
self.tosa_spec,
50-
)
51-
52-
attr = ts.TosaSerializerAttribute()
53-
attr.CeilAttribute()
54-
self._serialize_operator(
55-
node, tosa_graph, ts.Op.CEIL, [inputs[0].name], [output.name], attr
21+
tosa_specs = FP_SPECS
22+
23+
@classmethod
24+
def get_config(cls) -> SimpleNodeVisitorConfig:
25+
return SimpleNodeVisitorConfig(
26+
tosa_op=ts.Op.CEIL,
27+
attr_method="CeilAttribute",
28+
num_inputs=1,
29+
input_dtypes=[ts.DType.FP16, ts.DType.FP32, ts.DType.BF16],
5630
)

0 commit comments

Comments
 (0)