@@ -663,6 +663,7 @@ def __init__(
663663 glslc_path : Optional [str ],
664664 glslc_flags : str = "" ,
665665 replace_u16vecn : bool = False ,
666+ fp16_precision : str = "highp" ,
666667 ) -> None :
667668 if isinstance (src_dir_paths , str ):
668669 self .src_dir_paths = [src_dir_paths ]
@@ -678,6 +679,7 @@ def __init__(
678679 if "-Os" in self .glslc_flags_no_opt :
679680 self .glslc_flags_no_opt .remove ("-Os" )
680681 self .replace_u16vecn = replace_u16vecn
682+ self .fp16_precision = fp16_precision
681683
682684 self .src_files : Dict [str , str ] = {}
683685 self .template_yaml_files : List [str ] = []
@@ -857,6 +859,17 @@ def create_shader_params(
857859 for key , value in variant_params .items ():
858860 shader_params [key ] = value
859861
862+ # Optionally override PRECISION for half-precision variants. In GLSL,
863+ # `mediump` is a hint that the driver may use fp16 ALUs for arithmetic.
864+ # On Mali GPUs the hint is typically honored; on Adreno it is typically
865+ # ignored (harmless). The default, `highp`, matches upstream behavior.
866+ if (
867+ self .fp16_precision != "highp"
868+ and shader_params .get ("DTYPE" ) == "half"
869+ and shader_params .get ("PRECISION" ) == "highp"
870+ ):
871+ shader_params ["PRECISION" ] = self .fp16_precision
872+
860873 return shader_params
861874
862875 def constructOutputMap (self ) -> None :
@@ -1488,6 +1501,16 @@ def main(argv: List[str]) -> int:
14881501 default = - 1 ,
14891502 help = "Number of threads for shader compilation. -1 (default) uses all available CPU cores, 1 uses sequential compilation." ,
14901503 )
1504+ parser .add_argument (
1505+ "--fp16-precision" ,
1506+ choices = ["highp" , "mediump" , "lowp" ],
1507+ default = "highp" ,
1508+ help = (
1509+ "GLSL PRECISION qualifier for DTYPE=half shader variants. "
1510+ "`mediump` lets drivers (notably Mali) use fp16 ALUs for arithmetic. "
1511+ "Default `highp` matches upstream behavior. Ignored on fp32 variants."
1512+ ),
1513+ )
14911514 options = parser .parse_args ()
14921515
14931516 env = DEFAULT_ENV
@@ -1520,6 +1543,7 @@ def main(argv: List[str]) -> int:
15201543 options .glslc_path ,
15211544 glslc_flags = glslc_flags_str ,
15221545 replace_u16vecn = options .replace_u16vecn ,
1546+ fp16_precision = options .fp16_precision ,
15231547 )
15241548 output_spv_files = shader_generator .generateSPV (
15251549 options .output_path ,
0 commit comments