Skip to content

Commit 0f471a6

Browse files
committed
Merge branch 'jgibson/vulkan-fp16-precision-option' into polycam
2 parents 5223ba6 + 7423e42 commit 0f471a6

4 files changed

Lines changed: 39 additions & 0 deletions

File tree

backends/vulkan/cmake/ShaderLibrary.cmake

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,14 @@ function(gen_vulkan_shader_lib_cpp shaders_path)
6060
)
6161
endif()
6262

63+
# Allow overriding GLSL PRECISION for fp16 shader variants. Empty / unset
64+
# keeps upstream default (`highp`). Accepted values: highp, mediump, lowp.
65+
if(EXECUTORCH_VULKAN_FP16_PRECISION)
66+
list(APPEND GEN_SPV_ARGS "--fp16-precision"
67+
"${EXECUTORCH_VULKAN_FP16_PRECISION}"
68+
)
69+
endif()
70+
6371
# Ninja cannot expand wildcards (*) in DEPENDS lists.
6472
file(GLOB VULKAN_SHADERS "${shaders_path}/*.glsl" "${shaders_path}/*.glslh"
6573
"${shaders_path}/*.yaml" "${shaders_path}/*.h"

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,7 @@ def __init__(
663663
glslc_path: Optional[str],
664664
glslc_flags: str = "",
665665
replace_u16vecn: bool = False,
666+
fp16_precision: str = "highp",
666667
) -> None:
667668
if isinstance(src_dir_paths, str):
668669
self.src_dir_paths = [src_dir_paths]
@@ -678,6 +679,7 @@ def __init__(
678679
if "-Os" in self.glslc_flags_no_opt:
679680
self.glslc_flags_no_opt.remove("-Os")
680681
self.replace_u16vecn = replace_u16vecn
682+
self.fp16_precision = fp16_precision
681683

682684
self.src_files: Dict[str, str] = {}
683685
self.template_yaml_files: List[str] = []
@@ -857,6 +859,17 @@ def create_shader_params(
857859
for key, value in variant_params.items():
858860
shader_params[key] = value
859861

862+
# Optionally override PRECISION for half-precision variants. GLSL
863+
# `mediump` is a hint the driver may use fp16 ALUs for arithmetic.
864+
# On Mali GPUs it's typically honored; on Adreno it's typically
865+
# ignored (harmless). Default is `highp` to match upstream behavior.
866+
if (
867+
self.fp16_precision != "highp"
868+
and shader_params.get("DTYPE") == "half"
869+
and shader_params.get("PRECISION") == "highp"
870+
):
871+
shader_params["PRECISION"] = self.fp16_precision
872+
860873
return shader_params
861874

862875
def constructOutputMap(self) -> None:
@@ -1488,6 +1501,16 @@ def main(argv: List[str]) -> int:
14881501
default=-1,
14891502
help="Number of threads for shader compilation. -1 (default) uses all available CPU cores, 1 uses sequential compilation.",
14901503
)
1504+
parser.add_argument(
1505+
"--fp16-precision",
1506+
choices=["highp", "mediump", "lowp"],
1507+
default="highp",
1508+
help=(
1509+
"GLSL PRECISION qualifier for DTYPE=half shader variants. "
1510+
"`mediump` lets drivers (notably Mali) use fp16 ALUs for arithmetic. "
1511+
"Default `highp` matches upstream behavior. Ignored on fp32 variants."
1512+
),
1513+
)
14911514
options = parser.parse_args()
14921515

14931516
env = DEFAULT_ENV
@@ -1520,6 +1543,7 @@ def main(argv: List[str]) -> int:
15201543
options.glslc_path,
15211544
glslc_flags=glslc_flags_str,
15221545
replace_u16vecn=options.replace_u16vecn,
1546+
fp16_precision=options.fp16_precision,
15231547
)
15241548
output_spv_files = shader_generator.generateSPV(
15251549
options.output_path,

scripts/build_android_library.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ build_android_native_library() {
5151
-DEXECUTORCH_BUILD_QNN="${EXECUTORCH_BUILD_QNN}" \
5252
-DQNN_SDK_ROOT="${QNN_SDK_ROOT}" \
5353
-DEXECUTORCH_BUILD_VULKAN="${EXECUTORCH_BUILD_VULKAN}" \
54+
-DEXECUTORCH_VULKAN_FP16_PRECISION="${EXECUTORCH_VULKAN_FP16_PRECISION:-highp}" \
5455
-DXNNPACK_ENABLE_ARM_SME2="${XNNPACK_ENABLE_ARM_SME2}" \
5556
-DFLATCC_ALLOW_WERROR=OFF \
5657
-DSUPPORT_REGEX_LOOKAHEAD=ON \

tools/cmake/preset/default.cmake

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,12 @@ define_overridable_option(
168168
define_overridable_option(
169169
EXECUTORCH_BUILD_VULKAN "Build the Vulkan backend" BOOL OFF
170170
)
171+
define_overridable_option(
172+
EXECUTORCH_VULKAN_FP16_PRECISION
173+
"GLSL PRECISION for Vulkan half-precision shader variants. Accepted values: highp, mediump, lowp. `mediump` lets Mali drivers use fp16 ALUs; ignored on Adreno. Default `highp` matches upstream."
174+
STRING
175+
highp
176+
)
171177
define_overridable_option(
172178
EXECUTORCH_BUILD_PORTABLE_OPS "Build portable_ops library" BOOL ON
173179
)

0 commit comments

Comments
 (0)