|
3 | 3 | # This source code is licensed under the BSD-style license found in the |
4 | 4 | # LICENSE file in the root directory of this source tree. |
5 | 5 |
|
6 | | -import json |
7 | 6 | import operator |
8 | 7 | from typing import Set, Type |
9 | 8 |
|
10 | 9 | import torch |
11 | 10 | from executorch.backends.arm._passes import ArmPass |
12 | 11 | from executorch.backends.arm._passes.arm_pass_utils import create_node |
13 | 12 | from executorch.backends.arm.tosa.dialect.ops.custom import register_fake_tosa |
| 13 | +from executorch.backends.arm.vgf.shaders.grid_sampler import ( |
| 14 | + build_grid_sampler_2d_payload, |
| 15 | + CUSTOM_SHADER_DOMAIN_NAME, |
| 16 | + encode_payload, |
| 17 | + GRID_SAMPLER_2D_OPERATOR_NAME, |
| 18 | +) |
14 | 19 | from executorch.exir.dialects._ops import ops as exir_ops |
15 | 20 | from executorch.exir.pass_base import ExportPass, PassResult |
16 | 21 |
|
17 | 22 |
|
18 | | -@register_fake_tosa("grid_sampler_2d") |
| 23 | +@register_fake_tosa(GRID_SAMPLER_2D_OPERATOR_NAME) |
19 | 24 | def _grid_sampler_2d_custom_fake_impl( |
20 | 25 | inputs, operator_name, domain_name, implementation_attrs |
21 | 26 | ) -> list[torch.Tensor]: |
@@ -46,16 +51,12 @@ class RewriteGridSamplerToTosaCustomPass(ArmPass): |
46 | 51 | def _encode_payload( |
47 | 52 | interpolation_mode: int, padding_mode: int, align_corners: bool |
48 | 53 | ) -> list[int]: |
49 | | - payload = { |
50 | | - "version": 1, |
51 | | - "type": "arm_custom_shader", |
52 | | - "op": "grid_sampler_2d", |
53 | | - "interpolation_mode": int(interpolation_mode), |
54 | | - "padding_mode": int(padding_mode), |
55 | | - "align_corners": bool(align_corners), |
56 | | - "shader": {"encoding": "placeholder", "entry_point": "main"}, |
57 | | - } |
58 | | - return list(json.dumps(payload, sort_keys=True).encode("utf-8")) |
| 54 | + payload = build_grid_sampler_2d_payload( |
| 55 | + interpolation_mode=interpolation_mode, |
| 56 | + padding_mode=padding_mode, |
| 57 | + align_corners=align_corners, |
| 58 | + ) |
| 59 | + return encode_payload(payload) |
59 | 60 |
|
60 | 61 | def call(self, graph_module): |
61 | 62 | modified = False |
@@ -83,8 +84,8 @@ def call(self, graph_module): |
83 | 84 | op_target=exir_ops.backend.tosa.CUSTOM.default, |
84 | 85 | args=([input_tensor, grid],), |
85 | 86 | kwargs={ |
86 | | - "operator_name": "grid_sampler_2d", |
87 | | - "domain_name": "arm.custom_shader", |
| 87 | + "operator_name": GRID_SAMPLER_2D_OPERATOR_NAME, |
| 88 | + "domain_name": CUSTOM_SHADER_DOMAIN_NAME, |
88 | 89 | "implementation_attrs": implementation_attrs, |
89 | 90 | }, |
90 | 91 | from_node=node, |
|
0 commit comments