Skip to content

Commit dd9741b

Browse files
author
Baris Demir
committed
Arm backend: Clean up grid_sampler custom payload
Refactor the grid_sampler_2d TOSA CUSTOM payload into a helper module so the custom shader contract, constants, and validation logic live in one place. Emit the ML SDK Vulkan custom-shader attribute format, including the converter-recognized domain, entry point, workgroup sizes, shader language/code, tensor storage-buffer bindings, descriptor sets, and Vulkan format metadata. Add focused tests for payload round-trip behavior and for the grid_sampler_2d rewrite pass. Signed-off-by: Baris Demir <baris.demir@arm.com> Change-Id: I8f13f91b97e86b3b32bf2c25479684bc09faeddc
1 parent dd8d03e commit dd9741b

9 files changed

Lines changed: 341 additions & 23 deletions
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import argparse
7+
import base64
8+
import shutil
9+
import subprocess # nosec B404 - required to invoke the shader compiler.
10+
import tempfile
11+
from pathlib import Path
12+
13+
14+
SHADER_DIR = Path(__file__).resolve().parents[1] / "vgf" / "shaders"
15+
DEFAULT_SOURCE = SHADER_DIR / "grid_sampler.glsl"
16+
DEFAULT_OUTPUT = SHADER_DIR / "grid_sampler.spirv.b64"
17+
18+
19+
def _parse_args() -> argparse.Namespace:
20+
parser = argparse.ArgumentParser(
21+
description=(
22+
"Compile the VGF grid_sampler GLSL shader to SPIR-V and write the "
23+
"base64-encoded payload consumed by the ExecuTorch custom-shader "
24+
"lowering."
25+
)
26+
)
27+
parser.add_argument(
28+
"--source",
29+
type=Path,
30+
default=DEFAULT_SOURCE,
31+
help=f"GLSL source file. Defaults to {DEFAULT_SOURCE}",
32+
)
33+
parser.add_argument(
34+
"--output",
35+
type=Path,
36+
default=DEFAULT_OUTPUT,
37+
help=f"Base64 SPIR-V output file. Defaults to {DEFAULT_OUTPUT}",
38+
)
39+
parser.add_argument(
40+
"--glslc",
41+
default="glslc",
42+
help="Path to glslc. Defaults to resolving glslc from PATH.",
43+
)
44+
return parser.parse_args()
45+
46+
47+
def _resolve_glslc(glslc: str) -> str:
48+
resolved = shutil.which(glslc)
49+
if resolved is None:
50+
raise RuntimeError(
51+
f"Could not find {glslc}. Install the Vulkan SDK or pass --glslc."
52+
)
53+
return resolved
54+
55+
56+
def _write_base64_spirv(spirv_path: Path, output_path: Path) -> None:
57+
encoded = base64.b64encode(spirv_path.read_bytes()).decode("ascii")
58+
output_path.write_text(encoded + "\n", encoding="utf-8")
59+
60+
61+
def main() -> None:
62+
args = _parse_args()
63+
glslc = _resolve_glslc(args.glslc)
64+
65+
with tempfile.TemporaryDirectory() as tmpdir:
66+
spirv_path = Path(tmpdir) / "grid_sampler.spirv"
67+
subprocess.run( # nosec B603 - glslc path is resolved explicitly.
68+
[glslc, str(args.source), "-o", str(spirv_path)],
69+
check=True,
70+
)
71+
_write_base64_spirv(spirv_path, args.output)
72+
73+
74+
if __name__ == "__main__":
75+
main()
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import base64
7+
8+
import pytest
9+
from executorch.backends.arm.vgf.shaders.grid_sampler import (
10+
build_grid_sampler_2d_payload,
11+
decode_payload,
12+
encode_payload,
13+
GRID_SAMPLER_2D_SHADER_BINARY,
14+
GRID_SAMPLER_2D_SHADER_ENTRY_POINT,
15+
GRID_SAMPLER_2D_SHADER_LANGUAGE,
16+
GRID_SAMPLER_2D_SHADER_SOURCE,
17+
GRID_SAMPLER_2D_VK_FORMAT,
18+
GRID_SAMPLER_2D_WORKGROUP_SIZES,
19+
)
20+
21+
22+
def test_grid_sampler_2d_custom_shader_payload_no_target_round_trip():
23+
payload = build_grid_sampler_2d_payload(
24+
interpolation_mode=0,
25+
padding_mode=2,
26+
align_corners=True,
27+
)
28+
decoded = decode_payload(encode_payload(payload))
29+
30+
assert decoded["entry_point"] == GRID_SAMPLER_2D_SHADER_ENTRY_POINT
31+
assert decoded["workgroup_sizes"] == GRID_SAMPLER_2D_WORKGROUP_SIZES
32+
assert decoded["shader_language"] == GRID_SAMPLER_2D_SHADER_LANGUAGE
33+
assert base64.b64decode(decoded["shader_code"])[:4] == b"\x03\x02\x23\x07"
34+
assert decoded["input_0_type"] == "Tensor"
35+
assert decoded["input_0_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT
36+
assert decoded["input_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER"
37+
assert decoded["input_0_binding"] == 0
38+
assert decoded["input_1_type"] == "Tensor"
39+
assert decoded["input_1_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT
40+
assert decoded["input_1_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER"
41+
assert decoded["input_1_binding"] == 1
42+
assert decoded["output_0_type"] == "Tensor"
43+
assert decoded["output_0_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT
44+
assert decoded["output_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER"
45+
assert decoded["output_0_binding"] == 2
46+
47+
48+
def test_grid_sampler_2d_custom_shader_payload_no_target_uses_spirv():
49+
payload = build_grid_sampler_2d_payload(
50+
interpolation_mode=0,
51+
padding_mode=0,
52+
align_corners=False,
53+
)
54+
55+
shader_binary = base64.b64decode(payload["shader_code"])
56+
57+
assert payload["shader_language"] == "SPIR-V"
58+
assert shader_binary[:4] == b"\x03\x02\x23\x07"
59+
60+
61+
def test_grid_sampler_2d_custom_shader_payload_no_target_has_shader_resources():
62+
assert GRID_SAMPLER_2D_SHADER_SOURCE == "grid_sampler.glsl"
63+
assert GRID_SAMPLER_2D_SHADER_BINARY == "grid_sampler.spirv.b64"
64+
65+
66+
def test_grid_sampler_2d_custom_shader_payload_no_target_rejects_bad_modes():
67+
with pytest.raises(ValueError, match="Unsupported interpolation_mode"):
68+
build_grid_sampler_2d_payload(
69+
interpolation_mode=99,
70+
padding_mode=0,
71+
align_corners=False,
72+
)
73+
74+
with pytest.raises(ValueError, match="Unsupported padding_mode"):
75+
build_grid_sampler_2d_payload(
76+
interpolation_mode=0,
77+
padding_mode=99,
78+
align_corners=False,
79+
)

backends/arm/test/passes/test_rewrite_grid_sampler_to_tosa_custom_pass.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
import json
7-
86
import executorch.backends.arm.tosa.dialect # noqa: F401
97
import torch
108
import torch.nn.functional as F
@@ -15,6 +13,15 @@
1513
from executorch.backends.arm.vgf._passes.rewrite_grid_sampler_to_tosa_custom import (
1614
RewriteGridSamplerToTosaCustomPass,
1715
)
16+
from executorch.backends.arm.vgf.shaders.grid_sampler import (
17+
CUSTOM_SHADER_DOMAIN_NAME,
18+
decode_payload,
19+
GRID_SAMPLER_2D_OPERATOR_NAME,
20+
GRID_SAMPLER_2D_SHADER_ENTRY_POINT,
21+
GRID_SAMPLER_2D_SHADER_LANGUAGE,
22+
GRID_SAMPLER_2D_VK_FORMAT,
23+
GRID_SAMPLER_2D_WORKGROUP_SIZES,
24+
)
1825
from executorch.exir import to_edge
1926
from executorch.exir.dialects._ops import ops as exir_ops
2027
from torch.export import export
@@ -62,11 +69,22 @@ def test_rewrite_grid_sampler_to_tosa_custom_no_target():
6269
custom_node = next(
6370
node for node in nodes if node.target == exir_ops.backend.tosa.CUSTOM.default
6471
)
65-
assert custom_node.kwargs["operator_name"] == "grid_sampler_2d"
66-
assert custom_node.kwargs["domain_name"] == "arm.custom_shader"
72+
assert custom_node.kwargs["operator_name"] == GRID_SAMPLER_2D_OPERATOR_NAME
73+
assert custom_node.kwargs["domain_name"] == CUSTOM_SHADER_DOMAIN_NAME
6774

68-
payload = json.loads(
69-
bytes(custom_node.kwargs["implementation_attrs"]).decode("utf-8")
70-
)
71-
assert payload["op"] == "grid_sampler_2d"
72-
assert payload["shader"]["encoding"] == "placeholder"
75+
payload = decode_payload(custom_node.kwargs["implementation_attrs"])
76+
assert payload["entry_point"] == GRID_SAMPLER_2D_SHADER_ENTRY_POINT
77+
assert payload["workgroup_sizes"] == GRID_SAMPLER_2D_WORKGROUP_SIZES
78+
assert payload["shader_language"] == GRID_SAMPLER_2D_SHADER_LANGUAGE
79+
assert payload["input_0_type"] == "Tensor"
80+
assert payload["input_0_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT
81+
assert payload["input_0_binding"] == 0
82+
assert payload["input_0_descriptorset"] == 0
83+
assert payload["input_1_type"] == "Tensor"
84+
assert payload["input_1_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT
85+
assert payload["input_1_binding"] == 1
86+
assert payload["input_1_descriptorset"] == 0
87+
assert payload["output_0_type"] == "Tensor"
88+
assert payload["output_0_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT
89+
assert payload["output_0_binding"] == 2
90+
assert payload["output_0_descriptorset"] == 0

backends/arm/vgf/_passes/rewrite_grid_sampler_to_tosa_custom.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,24 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
import json
76
import operator
87
from typing import Set, Type
98

109
import torch
1110
from executorch.backends.arm._passes import ArmPass
1211
from executorch.backends.arm._passes.arm_pass_utils import create_node
1312
from executorch.backends.arm.tosa.dialect.ops.custom import register_fake_tosa
13+
from executorch.backends.arm.vgf.shaders.grid_sampler import (
14+
build_grid_sampler_2d_payload,
15+
CUSTOM_SHADER_DOMAIN_NAME,
16+
encode_payload,
17+
GRID_SAMPLER_2D_OPERATOR_NAME,
18+
)
1419
from executorch.exir.dialects._ops import ops as exir_ops
1520
from executorch.exir.pass_base import ExportPass, PassResult
1621

1722

18-
@register_fake_tosa("grid_sampler_2d")
23+
@register_fake_tosa(GRID_SAMPLER_2D_OPERATOR_NAME)
1924
def _grid_sampler_2d_custom_fake_impl(
2025
inputs, operator_name, domain_name, implementation_attrs
2126
) -> list[torch.Tensor]:
@@ -46,16 +51,12 @@ class RewriteGridSamplerToTosaCustomPass(ArmPass):
4651
def _encode_payload(
4752
interpolation_mode: int, padding_mode: int, align_corners: bool
4853
) -> list[int]:
49-
payload = {
50-
"version": 1,
51-
"type": "arm_custom_shader",
52-
"op": "grid_sampler_2d",
53-
"interpolation_mode": int(interpolation_mode),
54-
"padding_mode": int(padding_mode),
55-
"align_corners": bool(align_corners),
56-
"shader": {"encoding": "placeholder", "entry_point": "main"},
57-
}
58-
return list(json.dumps(payload, sort_keys=True).encode("utf-8"))
54+
payload = build_grid_sampler_2d_payload(
55+
interpolation_mode=interpolation_mode,
56+
padding_mode=padding_mode,
57+
align_corners=align_corners,
58+
)
59+
return encode_payload(payload)
5960

6061
def call(self, graph_module):
6162
modified = False
@@ -83,8 +84,8 @@ def call(self, graph_module):
8384
op_target=exir_ops.backend.tosa.CUSTOM.default,
8485
args=([input_tensor, grid],),
8586
kwargs={
86-
"operator_name": "grid_sampler_2d",
87-
"domain_name": "arm.custom_shader",
87+
"operator_name": GRID_SAMPLER_2D_OPERATOR_NAME,
88+
"domain_name": CUSTOM_SHADER_DOMAIN_NAME,
8889
"implementation_attrs": implementation_attrs,
8990
},
9091
from_node=node,
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#version 450
2+
3+
layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in;
4+
5+
layout(set = 0, binding = 0) readonly buffer Input0 {
6+
float input0[];
7+
};
8+
9+
layout(set = 0, binding = 1) readonly buffer Input1 {
10+
float input1[];
11+
};
12+
13+
layout(set = 0, binding = 2) writeonly buffer Output0 {
14+
float output0[];
15+
};
16+
17+
void main() {
18+
uint index = gl_GlobalInvocationID.x;
19+
output0[index] = input0[index];
20+
}
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import json
7+
from importlib.resources import files
8+
from typing import Any
9+
10+
CUSTOM_SHADER_DOMAIN_NAME = "com.arm.VulkanCustomShader"
11+
GRID_SAMPLER_2D_OPERATOR_NAME = "torch.nn.functional.grid_sample"
12+
GRID_SAMPLER_2D_WORKGROUP_SIZES = [8, 8, 1]
13+
GRID_SAMPLER_2D_SHADER_ENTRY_POINT = "main"
14+
GRID_SAMPLER_2D_SHADER_LANGUAGE = "SPIR-V"
15+
GRID_SAMPLER_2D_VK_FORMAT = "VK_FORMAT_R32_SFLOAT"
16+
GRID_SAMPLER_2D_SHADER_SOURCE = "grid_sampler.glsl"
17+
GRID_SAMPLER_2D_SHADER_BINARY = "grid_sampler.spirv.b64"
18+
19+
_INTERPOLATION_MODE_NAMES = {
20+
0: "bilinear",
21+
1: "nearest",
22+
2: "bicubic",
23+
}
24+
_PADDING_MODE_NAMES = {
25+
0: "zeros",
26+
1: "border",
27+
2: "reflection",
28+
}
29+
30+
31+
def _mode_name(
32+
mode: int,
33+
names: dict[int, str],
34+
mode_kind: str,
35+
) -> str:
36+
if mode not in names:
37+
raise ValueError(
38+
f"Unsupported {mode_kind} {mode} for {GRID_SAMPLER_2D_OPERATOR_NAME}"
39+
)
40+
return names[mode]
41+
42+
43+
def build_grid_sampler_2d_payload(
44+
interpolation_mode: int,
45+
padding_mode: int,
46+
align_corners: bool,
47+
) -> dict[str, Any]:
48+
_mode_name(
49+
int(interpolation_mode),
50+
_INTERPOLATION_MODE_NAMES,
51+
"interpolation_mode",
52+
)
53+
_mode_name(
54+
int(padding_mode),
55+
_PADDING_MODE_NAMES,
56+
"padding_mode",
57+
)
58+
shader_code = "".join(
59+
files(__package__)
60+
.joinpath(GRID_SAMPLER_2D_SHADER_BINARY)
61+
.read_text(encoding="utf-8")
62+
.split()
63+
)
64+
65+
return {
66+
"entry_point": GRID_SAMPLER_2D_SHADER_ENTRY_POINT,
67+
"workgroup_sizes": GRID_SAMPLER_2D_WORKGROUP_SIZES,
68+
"shader_language": GRID_SAMPLER_2D_SHADER_LANGUAGE,
69+
"shader_code": shader_code,
70+
"input_0_type": "Tensor",
71+
"input_0_vkformat": GRID_SAMPLER_2D_VK_FORMAT,
72+
"input_0_vkdescriptortype": "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER",
73+
"input_0_binding": 0,
74+
"input_0_descriptorset": 0,
75+
"input_1_type": "Tensor",
76+
"input_1_vkformat": GRID_SAMPLER_2D_VK_FORMAT,
77+
"input_1_vkdescriptortype": "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER",
78+
"input_1_binding": 1,
79+
"input_1_descriptorset": 0,
80+
"output_0_type": "Tensor",
81+
"output_0_vkformat": GRID_SAMPLER_2D_VK_FORMAT,
82+
"output_0_vkdescriptortype": "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER",
83+
"output_0_binding": 2,
84+
"output_0_descriptorset": 0,
85+
}
86+
87+
88+
def encode_payload(payload: dict[str, Any]) -> list[int]:
89+
return list(json.dumps(payload, sort_keys=True).encode("utf-8"))
90+
91+
92+
def decode_payload(implementation_attrs: list[int]) -> dict[str, Any]:
93+
return json.loads(bytes(implementation_attrs).decode("utf-8"))

0 commit comments

Comments
 (0)