@@ -281,8 +281,9 @@ def layout_declare_buffer(
281281 dtype : str ,
282282 precision : str = "PRECISION" ,
283283 is_scalar_array : bool = True ,
284+ vec_size : int = 4 ,
284285) -> str :
285- array_type = buffer_gvec_type (dtype , 4 )
286+ array_type = buffer_gvec_type (dtype , vec_size )
286287 if is_scalar_array :
287288 array_type = buffer_scalar_type (dtype )
288289
@@ -341,6 +342,7 @@ def layout_declare_tensor(
341342 storage_type : str ,
342343 is_scalar_array : bool = True ,
343344 precision : str = "PRECISION" ,
345+ vec_size : int = 4 ,
344346) -> str :
345347 assert storage_type .lower () in ["buffer" , "texture3d" , "texture2d" ]
346348
@@ -357,6 +359,7 @@ def layout_declare_tensor(
357359 dtype ,
358360 precision ,
359361 is_scalar_array = is_scalar_array ,
362+ vec_size = vec_size ,
360363 )
361364
362365 # Create image/sampler binding
@@ -785,6 +788,10 @@ def parseTemplateYaml(self, yaml_file: str) -> None: # noqa: C901
785788 "generate_variant_forall" , None
786789 )
787790
791+ reserved_yaml_keys = {
792+ "generate_variant_forall" ,
793+ }
794+
788795 for variant in params_dict ["shader_variants" ]:
789796 default_iterated_params_names = set (
790797 default_iterated_params .keys ()
@@ -797,7 +804,7 @@ def parseTemplateYaml(self, yaml_file: str) -> None: # noqa: C901
797804 variant_params_names
798805 - default_iterated_params_names
799806 - params_names
800- - { "generate_variant_forall" }
807+ - reserved_yaml_keys
801808 )
802809 assert len (invalid_keys ) == 0
803810
@@ -813,7 +820,7 @@ def parseTemplateYaml(self, yaml_file: str) -> None: # noqa: C901
813820 for combination in variant_combinations :
814821 default_params_copy = copy .deepcopy (default_params )
815822 for key in variant :
816- if key != "generate_variant_forall" :
823+ if key not in reserved_yaml_keys :
817824 default_params_copy [key ] = variant [key ]
818825
819826 variant_name = variant ["NAME" ]
@@ -842,7 +849,8 @@ def parseTemplateYaml(self, yaml_file: str) -> None: # noqa: C901
842849 else :
843850 default_params_copy = copy .deepcopy (default_params )
844851 for key in variant :
845- default_params_copy [key ] = variant [key ]
852+ if key not in reserved_yaml_keys :
853+ default_params_copy [key ] = variant [key ]
846854
847855 self .shader_template_params [template_name ].append (
848856 default_params_copy
@@ -1026,6 +1034,27 @@ def generate_src_file(shader_paths_pair) -> Tuple[bool, List[str]]:
10261034 print (f"template_file_path: { template_file_path } " )
10271035 output_text = preprocess (input_text , codegen_params )
10281036
1037+ # If the shader yaml declared a SUBGROUP_SIZE template parameter,
1038+ # embed it into the generated GLSL as a comment. getShaderInfo()
1039+ # parses it back out alongside TILE_SIZE, WEIGHT_STORAGE, etc.,
1040+ # avoiding a side-channel name -> value map.
1041+ subgroup_size = codegen_params .get ("SUBGROUP_SIZE" )
1042+ if subgroup_size is not None :
1043+ try :
1044+ subgroup_size_int = int (subgroup_size )
1045+ except (TypeError , ValueError ) as e :
1046+ raise RuntimeError (
1047+ f"Shader variant { src_file_name !r} declared "
1048+ f"SUBGROUP_SIZE={ subgroup_size !r} , which is not "
1049+ f"parseable as an integer. Fix the SUBGROUP_SIZE "
1050+ f"value in the shader's yaml."
1051+ ) from e
1052+ if subgroup_size_int > 0 :
1053+ output_text = (
1054+ f"// REQUIRED_SUBGROUP_SIZE = { subgroup_size_int } \n "
1055+ + output_text
1056+ )
1057+
10291058 included_files = get_glsl_includes (output_text )
10301059
10311060 with codecs .open (gen_out_path , "w" , encoding = "utf-8" ) as output_file :
@@ -1184,6 +1213,12 @@ class ShaderInfo:
11841213 requires_integer_dot_product_ext : bool = False
11851214 requires_shader_int64_ext : bool = False
11861215 requires_shader_float64_ext : bool = False
1216+ # Subgroup size requirement (matches the C++ ShaderInfo encoding):
1217+ # 0 = no requirement
1218+ # >0 = literal fixed size; sourced from the shader yaml's
1219+ # `SUBGROUP_SIZE` template parameter (single source of truth for
1220+ # both GLSL substitution and the Vulkan pipeline pin).
1221+ required_subgroup_size : int = 0
11871222
11881223
11891224def getName (filePath : str ) -> str :
@@ -1208,6 +1243,17 @@ def findTileSizes(lineStr: str) -> List[int]:
12081243 return [int (matches .group (1 )), int (matches .group (2 )), int (matches .group (3 ))]
12091244
12101245
def isRequiredSubgroupSizeLine(lineStr: str) -> bool:
    """Return True when ``lineStr`` begins with the required-subgroup-size marker.

    Only the ``// REQUIRED_SUBGROUP_SIZE = `` prefix at the very start of the
    string is checked; whatever follows the prefix is not validated here.
    """
    required_subgroup_size_id = r"^// REQUIRED_SUBGROUP_SIZE = "
    # A successful search yields a (always truthy) match object, else None.
    return bool(re.search(required_subgroup_size_id, lineStr))
1248+
1249+
def findRequiredSubgroupSize(lineStr: str) -> int:
    """Parse the integer value from a ``// REQUIRED_SUBGROUP_SIZE = N`` line.

    Args:
        lineStr: A line that starts with the ``// REQUIRED_SUBGROUP_SIZE = ``
            marker, immediately followed by one or more decimal digits.

    Returns:
        The decimal digits that follow the marker, converted to ``int``.

    Raises:
        AssertionError: If ``lineStr`` does not start with the marker followed
            by at least one digit. The message includes the offending line so
            that the malformed shader source can be located.
    """
    matches = re.search(r"^// REQUIRED_SUBGROUP_SIZE = ([0-9]+)", lineStr)
    if matches is None:
        # The marker prefix may be present without a numeric value; report the
        # actual line rather than a generic "matches is None" message.
        raise AssertionError(
            "findRequiredSubgroupSize: no integer subgroup size found in "
            f"line {lineStr!r}"
        )
    return int(matches.group(1))
1255+
1256+
12111257def isWeightStorageTypeLine (lineStr : str ) -> bool :
12121258 weight_storage_id = r"^ \* WEIGHT_STORAGE = "
12131259 return re .search (weight_storage_id , lineStr ) is not None
@@ -1281,6 +1327,8 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo: # noqa: C901
12811327 shader_info .layouts .append (determineDescriptorType (line ))
12821328 if isTileSizeLine (line ):
12831329 shader_info .tile_size = findTileSizes (line )
1330+ if isRequiredSubgroupSizeLine (line ):
1331+ shader_info .required_subgroup_size = findRequiredSubgroupSize (line )
12841332 if isWeightStorageTypeLine (line ):
12851333 shader_info .weight_storage_type = getWeightStorageType (line )
12861334 if isBiasStorageTypeLine (line ):
@@ -1378,6 +1426,7 @@ def to_cpp_str(val: bool):
13781426 to_cpp_str (shader_info .requires_integer_dot_product_ext ),
13791427 to_cpp_str (shader_info .requires_shader_int64_ext ),
13801428 to_cpp_str (shader_info .requires_shader_float64_ext ),
1429+ str (shader_info .required_subgroup_size ),
13811430 ]
13821431
13831432 shader_info_str = textwrap .indent (
@@ -1406,7 +1455,9 @@ def generateShaderDispatchStr(shader_info: ShaderInfo, name: str) -> str:
14061455
14071456
14081457def genCppFiles (
1409- spv_files : Dict [str , str ], cpp_header_path : str , cpp_src_file_path : str
1458+ spv_files : Dict [str , str ],
1459+ cpp_header_path : str ,
1460+ cpp_src_file_path : str ,
14101461) -> None :
14111462 spv_bin_strs = []
14121463 register_shader_info_strs = []
0 commit comments