|
13 | 13 | import json |
14 | 14 |
|
15 | 15 | import logging |
16 | | -from typing import Any, Dict, List, Optional |
| 16 | +from typing import Any, Callable, Dict, List, Optional |
17 | 17 |
|
18 | 18 | import torch |
19 | 19 | import tosa_serializer as ts |
20 | 20 |
|
21 | 21 | from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec |
22 | 22 | from executorch.backends.arm.debug.schema import DebugHook |
| 23 | +from executorch.backends.arm.operators.operator_validation_utils import ( |
| 24 | + validate_num_inputs, |
| 25 | + validate_same_dtype, |
| 26 | + validate_valid_dtype, |
| 27 | +) |
23 | 28 | from executorch.backends.arm.tosa.mapping import TosaArg |
24 | 29 | from executorch.backends.arm.tosa.specification import ( |
25 | 30 | TosaSpecification, |
@@ -102,6 +107,77 @@ def _serialize_operator( |
102 | 107 | location=op_location, |
103 | 108 | ) |
104 | 109 |
|
| 110 | + def validate( |
| 111 | + self, |
| 112 | + *, |
| 113 | + target: str, |
| 114 | + inputs: List[TosaArg], |
| 115 | + output: TosaArg, |
| 116 | + num_inputs: int | List[int], |
| 117 | + input_dtypes: List[Any], |
| 118 | + output_dtypes: Optional[List[Any]] = None, |
| 119 | + same_dtype_with_output: bool = True, |
| 120 | + dtype_check_inputs_only: bool = False, |
| 121 | + ) -> None: |
| 122 | + validate_num_inputs(target, inputs, num_inputs) |
| 123 | + if same_dtype_with_output: |
| 124 | + validate_same_dtype(target, [*inputs, output], ts) |
| 125 | + else: |
| 126 | + validate_same_dtype(target, inputs, ts) |
| 127 | + |
| 128 | + dtype_check_tensors = inputs if dtype_check_inputs_only else [*inputs, output] |
| 129 | + validate_valid_dtype( |
| 130 | + target, |
| 131 | + dtype_check_tensors, |
| 132 | + input_dtypes, |
| 133 | + self.tosa_spec, |
| 134 | + ) |
| 135 | + if output_dtypes is not None: |
| 136 | + validate_valid_dtype( |
| 137 | + target, |
| 138 | + output, |
| 139 | + output_dtypes, |
| 140 | + self.tosa_spec, |
| 141 | + ) |
| 142 | + |
| 143 | + def serialize( |
| 144 | + self, |
| 145 | + node: torch.fx.Node, |
| 146 | + tosa_graph: Any, |
| 147 | + *, |
| 148 | + tosa_op: ts.Op, |
| 149 | + inputs: List[TosaArg], |
| 150 | + output: TosaArg, |
| 151 | + attr_method: Optional[str] = None, |
| 152 | + attr_kwargs: Optional[dict[str, Any]] = None, |
| 153 | + attr_builder: Optional[Callable[[ts.TosaSerializerAttribute], None]] = None, |
| 154 | + extra_input_builders: Optional[ |
| 155 | + List[Callable[[torch.fx.Node, Any, List[TosaArg], TosaArg, Any], str]] |
| 156 | + ] = None, |
| 157 | + ) -> None: |
| 158 | + attr = ts.TosaSerializerAttribute() |
| 159 | + if attr_method is not None: |
| 160 | + getattr(attr, attr_method)(**(attr_kwargs or {})) |
| 161 | + elif attr_builder is not None: |
| 162 | + attr_builder(attr) |
| 163 | + else: |
| 164 | + raise NotImplementedError( |
| 165 | + f"{self.__class__.__name__} must define attr_method or attr_builder." |
| 166 | + ) |
| 167 | + input_names = [arg.name for arg in inputs] |
| 168 | + for builder in extra_input_builders or []: |
| 169 | + input_names.append( |
| 170 | + builder(node, tosa_graph, inputs, output, self.tosa_spec) |
| 171 | + ) |
| 172 | + self._serialize_operator( |
| 173 | + node, |
| 174 | + tosa_graph, |
| 175 | + tosa_op, |
| 176 | + input_names, |
| 177 | + [output.name], |
| 178 | + attr, |
| 179 | + ) |
| 180 | + |
105 | 181 | def define_node( |
106 | 182 | self, |
107 | 183 | node: torch.fx.Node, |
@@ -151,6 +227,9 @@ def register_node_visitor(visitor): |
151 | 227 |
|
152 | 228 | def get_node_visitors(*args) -> Dict[str, NodeVisitor]: |
153 | 229 | """Return a mapping from target names to visitor instances for a spec.""" |
| 230 | + # Ensure all operator modules are imported so visitors are registered. |
| 231 | + import executorch.backends.arm.operators # noqa: F401 |
| 232 | + |
154 | 233 | node_visitors: Dict[str, NodeVisitor] = {} |
155 | 234 | tosa_spec: TosaSpecification | None = None |
156 | 235 | for arg in args: |
|
0 commit comments