|
5 | 5 |
|
6 | 6 | import executorch.backends.arm.tosa.dialect # noqa: F401 |
7 | 7 | from executorch.backends.arm._passes.aten_to_tosa_activation_functions import ( |
8 | | - rewrite_clamp, |
9 | | - rewrite_erf, |
10 | | - rewrite_sigmoid, |
11 | | - rewrite_tanh, |
| 8 | + get_activation_replacement, |
| 9 | +) |
| 10 | +from executorch.backends.arm._passes.aten_to_tosa_tensor_operators import rewrite_argmax |
| 11 | +from executorch.backends.transforms.aten_to_dialect_pass import ( |
| 12 | + AtenToDialectPass, |
| 13 | + DialectNodeSpec, |
12 | 14 | ) |
13 | | -from executorch.backends.transforms.aten_to_dialect_pass import AtenToDialectPass |
14 | 15 | from executorch.exir.dialects._ops import ops as exir_ops |
| 16 | +from torch.fx import Node |
15 | 17 |
|
16 | 18 |
|
17 | 19 | class ExirToTosaPass(AtenToDialectPass): |
18 | 20 | """Rewrite simple EXIR ops to equivalent backend TOSA dialect ops. |
19 | 21 |
|
20 | | - Rewrite functions are grouped by op category and registered with the shared |
21 | | - ATen-to-dialect pass infrastructure. |
| 22 | + Rewrite functions are registered with the shared ATen-to-dialect pass |
| 23 | + infrastructure. |
22 | 24 |
|
23 | 25 | """ |
24 | 26 |
|
25 | 27 |
|
26 | | -_ACTIVATION_FUNCTION_REWRITES = { |
27 | | - exir_ops.edge.aten.clamp.default: rewrite_clamp, |
28 | | - exir_ops.edge.aten.erf.default: rewrite_erf, |
29 | | - exir_ops.edge.aten.sigmoid.default: rewrite_sigmoid, |
30 | | - exir_ops.edge.aten.tanh.default: rewrite_tanh, |
31 | | -} |
| 28 | +@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.argmax.default) |
| 29 | +def _get_tensor_operators_replacement( |
| 30 | + node: Node, pass_: AtenToDialectPass |
| 31 | +) -> DialectNodeSpec: |
| 32 | + return rewrite_argmax(node, pass_) |
32 | 33 |
|
33 | | -_DIRECT_REWRITE_CATEGORIES = { |
34 | | - "activation_functions": _ACTIVATION_FUNCTION_REWRITES, |
35 | | -} |
36 | 34 |
|
37 | | -# Register each category's ATen targets with the function that builds the |
38 | | -# corresponding TOSA dialect node spec. |
39 | | -for _rewrite_category in _DIRECT_REWRITE_CATEGORIES.values(): |
40 | | - for _edge_target, _rewrite_fn in _rewrite_category.items(): |
41 | | - ExirToTosaPass.register_dialect_substitution(_edge_target)(_rewrite_fn) |
| 35 | +@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.clamp.default) |
| 36 | +@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.erf.default) |
| 37 | +@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.sigmoid.default) |
| 38 | +@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.tanh.default) |
| 39 | +def _get_activation_replacement( |
| 40 | + node: Node, pass_: AtenToDialectPass |
| 41 | +) -> DialectNodeSpec | None: |
| 42 | + return get_activation_replacement(node, pass_) |
0 commit comments