Skip to content

Commit 1f8a6e6

Browse files
authored
[ET-VK] Plumb subgroup property queries + VK_EXT_subgroup_size_control
Differential Revision: D104456803 Pull Request resolved: #19403
1 parent 7debf5c commit 1f8a6e6

13 files changed

Lines changed: 405 additions & 21 deletions

File tree

backends/vulkan/runtime/api/Context.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,14 @@ vkapi::DescriptorSet Context::get_descriptor_set(
149149

150150
spec_constants.append(additional_constants);
151151

152+
const uint32_t resolved_required_subgroup_size =
153+
vkapi::resolve_required_subgroup_size(shader_descriptor, adapter_p_);
154+
152155
VkPipeline pipeline = pipeline_cache().retrieve(
153156
{pipeline_layout_cache().retrieve(shader_layout, push_constants_size),
154157
shader_cache().retrieve(shader_descriptor),
155-
spec_constants});
158+
spec_constants,
159+
resolved_required_subgroup_size});
156160

157161
cmd_.bind_pipeline(pipeline, pipeline_layout, local_workgroup_size);
158162

@@ -315,8 +319,14 @@ VkPipeline Context::get_shader_pipeline(
315319

316320
spec_constants.append(additional_constants);
317321

322+
const uint32_t resolved_required_subgroup_size =
323+
vkapi::resolve_required_subgroup_size(shader, adapter_p_);
324+
318325
VkPipeline pipeline = pipeline_cache().retrieve(
319-
{pipeline_layout, shader_cache().retrieve(shader), spec_constants});
326+
{pipeline_layout,
327+
shader_cache().retrieve(shader),
328+
spec_constants,
329+
resolved_required_subgroup_size});
320330

321331
return pipeline;
322332
}

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,9 @@ def layout_declare_buffer(
281281
dtype: str,
282282
precision: str = "PRECISION",
283283
is_scalar_array: bool = True,
284+
vec_size: int = 4,
284285
) -> str:
285-
array_type = buffer_gvec_type(dtype, 4)
286+
array_type = buffer_gvec_type(dtype, vec_size)
286287
if is_scalar_array:
287288
array_type = buffer_scalar_type(dtype)
288289

@@ -341,6 +342,7 @@ def layout_declare_tensor(
341342
storage_type: str,
342343
is_scalar_array: bool = True,
343344
precision: str = "PRECISION",
345+
vec_size: int = 4,
344346
) -> str:
345347
assert storage_type.lower() in ["buffer", "texture3d", "texture2d"]
346348

@@ -357,6 +359,7 @@ def layout_declare_tensor(
357359
dtype,
358360
precision,
359361
is_scalar_array=is_scalar_array,
362+
vec_size=vec_size,
360363
)
361364

362365
# Create image/sampler binding
@@ -785,6 +788,10 @@ def parseTemplateYaml(self, yaml_file: str) -> None: # noqa: C901
785788
"generate_variant_forall", None
786789
)
787790

791+
reserved_yaml_keys = {
792+
"generate_variant_forall",
793+
}
794+
788795
for variant in params_dict["shader_variants"]:
789796
default_iterated_params_names = set(
790797
default_iterated_params.keys()
@@ -797,7 +804,7 @@ def parseTemplateYaml(self, yaml_file: str) -> None: # noqa: C901
797804
variant_params_names
798805
- default_iterated_params_names
799806
- params_names
800-
- {"generate_variant_forall"}
807+
- reserved_yaml_keys
801808
)
802809
assert len(invalid_keys) == 0
803810

@@ -813,7 +820,7 @@ def parseTemplateYaml(self, yaml_file: str) -> None: # noqa: C901
813820
for combination in variant_combinations:
814821
default_params_copy = copy.deepcopy(default_params)
815822
for key in variant:
816-
if key != "generate_variant_forall":
823+
if key not in reserved_yaml_keys:
817824
default_params_copy[key] = variant[key]
818825

819826
variant_name = variant["NAME"]
@@ -842,7 +849,8 @@ def parseTemplateYaml(self, yaml_file: str) -> None: # noqa: C901
842849
else:
843850
default_params_copy = copy.deepcopy(default_params)
844851
for key in variant:
845-
default_params_copy[key] = variant[key]
852+
if key not in reserved_yaml_keys:
853+
default_params_copy[key] = variant[key]
846854

847855
self.shader_template_params[template_name].append(
848856
default_params_copy
@@ -1026,6 +1034,27 @@ def generate_src_file(shader_paths_pair) -> Tuple[bool, List[str]]:
10261034
print(f"template_file_path: {template_file_path}")
10271035
output_text = preprocess(input_text, codegen_params)
10281036

1037+
# If the shader yaml declared a SUBGROUP_SIZE template parameter,
1038+
# embed it into the generated GLSL as a comment. getShaderInfo()
1039+
# parses it back out alongside TILE_SIZE, WEIGHT_STORAGE, etc.,
1040+
# avoiding a side-channel name -> value map.
1041+
subgroup_size = codegen_params.get("SUBGROUP_SIZE")
1042+
if subgroup_size is not None:
1043+
try:
1044+
subgroup_size_int = int(subgroup_size)
1045+
except (TypeError, ValueError) as e:
1046+
raise RuntimeError(
1047+
f"Shader variant {src_file_name!r} declared "
1048+
f"SUBGROUP_SIZE={subgroup_size!r}, which is not "
1049+
f"parseable as an integer. Fix the SUBGROUP_SIZE "
1050+
f"value in the shader's yaml."
1051+
) from e
1052+
if subgroup_size_int > 0:
1053+
output_text = (
1054+
f"// REQUIRED_SUBGROUP_SIZE = {subgroup_size_int}\n"
1055+
+ output_text
1056+
)
1057+
10291058
included_files = get_glsl_includes(output_text)
10301059

10311060
with codecs.open(gen_out_path, "w", encoding="utf-8") as output_file:
@@ -1184,6 +1213,12 @@ class ShaderInfo:
11841213
requires_integer_dot_product_ext: bool = False
11851214
requires_shader_int64_ext: bool = False
11861215
requires_shader_float64_ext: bool = False
1216+
# Subgroup size requirement (matches the C++ ShaderInfo encoding):
1217+
# 0 = no requirement
1218+
# >0 = literal fixed size; sourced from the shader yaml's
1219+
# `SUBGROUP_SIZE` template parameter (single source of truth for
1220+
# both GLSL substitution and the Vulkan pipeline pin).
1221+
required_subgroup_size: int = 0
11871222

11881223

11891224
def getName(filePath: str) -> str:
@@ -1208,6 +1243,17 @@ def findTileSizes(lineStr: str) -> List[int]:
12081243
return [int(matches.group(1)), int(matches.group(2)), int(matches.group(3))]
12091244

12101245

1246+
def isRequiredSubgroupSizeLine(lineStr: str) -> bool:
    """Return True when ``lineStr`` begins with the required-subgroup-size marker.

    The marker is the comment prefix ``// REQUIRED_SUBGROUP_SIZE = `` anchored
    at the very start of the line. Only the prefix is checked here; whether a
    numeric value follows it is not validated by this predicate.
    """
    required_subgroup_size_id = r"^// REQUIRED_SUBGROUP_SIZE = "
    found = re.search(required_subgroup_size_id, lineStr)
    return found is not None
1248+
1249+
1250+
def findRequiredSubgroupSize(lineStr: str) -> int:
    """Extract the integer value from a required-subgroup-size marker line.

    Args:
        lineStr: A line expected to start with
            ``// REQUIRED_SUBGROUP_SIZE = <digits>``.

    Returns:
        The leading run of decimal digits after the marker, as an int.

    Raises:
        AssertionError: If the line does not start with the marker followed
            by at least one decimal digit.
    """
    size_match = re.search(r"^// REQUIRED_SUBGROUP_SIZE = ([0-9]+)", lineStr)
    if size_match is not None:
        return int(size_match.group(1))
    raise AssertionError("matches is None in findRequiredSubgroupSize")
1255+
1256+
12111257
def isWeightStorageTypeLine(lineStr: str) -> bool:
12121258
weight_storage_id = r"^ \* WEIGHT_STORAGE = "
12131259
return re.search(weight_storage_id, lineStr) is not None
@@ -1281,6 +1327,8 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo: # noqa: C901
12811327
shader_info.layouts.append(determineDescriptorType(line))
12821328
if isTileSizeLine(line):
12831329
shader_info.tile_size = findTileSizes(line)
1330+
if isRequiredSubgroupSizeLine(line):
1331+
shader_info.required_subgroup_size = findRequiredSubgroupSize(line)
12841332
if isWeightStorageTypeLine(line):
12851333
shader_info.weight_storage_type = getWeightStorageType(line)
12861334
if isBiasStorageTypeLine(line):
@@ -1378,6 +1426,7 @@ def to_cpp_str(val: bool):
13781426
to_cpp_str(shader_info.requires_integer_dot_product_ext),
13791427
to_cpp_str(shader_info.requires_shader_int64_ext),
13801428
to_cpp_str(shader_info.requires_shader_float64_ext),
1429+
str(shader_info.required_subgroup_size),
13811430
]
13821431

13831432
shader_info_str = textwrap.indent(
@@ -1406,7 +1455,9 @@ def generateShaderDispatchStr(shader_info: ShaderInfo, name: str) -> str:
14061455

14071456

14081457
def genCppFiles(
1409-
spv_files: Dict[str, str], cpp_header_path: str, cpp_src_file_path: str
1458+
spv_files: Dict[str, str],
1459+
cpp_header_path: str,
1460+
cpp_src_file_path: str,
14101461
) -> None:
14111462
spv_bin_strs = []
14121463
register_shader_info_strs = []

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -825,10 +825,20 @@ void ComputeGraph::register_pipeline_to_create(
825825

826826
spec_constants.append(spec_vars);
827827

828+
// Resolve any shader-declared required subgroup size into a concrete value
829+
// so the pre-built pipeline matches the one created at dispatch time. The
830+
// shared helper throws ShaderNotSupportedError when the adapter cannot honor
831+
// the requirement; let it propagate so a stale unused pipeline doesn't sit
832+
// in the cache while dispatch later throws on the same shader.
833+
const uint32_t resolved_required_subgroup_size =
834+
vkapi::resolve_required_subgroup_size(
835+
shader_info, context()->adapter_ptr());
836+
828837
const vkapi::ComputePipelineCache::Key desc = {
829838
context()->pipeline_layout_cache().retrieve(shader_layout, pc_offset),
830839
context()->shader_cache().retrieve(shader_info),
831-
spec_constants};
840+
spec_constants,
841+
resolved_required_subgroup_size};
832842

833843
if (context_->pipeline_cache().contains(desc)) {
834844
return;

backends/vulkan/runtime/vk_api/Adapter.cpp

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ VkDevice create_logical_device(
129129
#ifdef VK_NV_cooperative_matrix2
130130
VK_NV_COOPERATIVE_MATRIX_2_EXTENSION_NAME,
131131
#endif /* VK_NV_cooperative_matrix2 */
132+
#ifdef VK_EXT_subgroup_size_control
133+
VK_EXT_SUBGROUP_SIZE_CONTROL_EXTENSION_NAME,
134+
#endif /* VK_EXT_subgroup_size_control */
132135
};
133136

134137
std::vector<const char*> enabled_device_extensions;
@@ -199,6 +202,19 @@ VkDevice create_logical_device(
199202
extension_list_top = &cooperative_matrix2_features;
200203
#endif /* VK_NV_cooperative_matrix2 */
201204

205+
#ifdef VK_EXT_subgroup_size_control
206+
// Only enable the feature struct if the extension was actually requested
207+
// and the feature flag is set on the physical device. The extension itself
208+
// is filtered into enabled_device_extensions by
209+
// find_requested_device_extensions.
210+
VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features{
211+
physical_device.subgroup_size_control_features};
212+
if (physical_device.supports_subgroup_size_control) {
213+
subgroup_size_control_features.pNext = extension_list_top;
214+
extension_list_top = &subgroup_size_control_features;
215+
}
216+
#endif /* VK_EXT_subgroup_size_control */
217+
202218
device_create_info.pNext = extension_list_top;
203219

204220
VkDevice handle = nullptr;
@@ -405,7 +421,7 @@ std::string Adapter::stringize() const {
405421
ss << " deviceType: " << device_type << std::endl;
406422
ss << " deviceName: " << properties.deviceName << std::endl;
407423

408-
#define PRINT_BOOL(value, name) \
424+
#define PRINT_VALUE(value, name) \
409425
ss << " " << std::left << std::setw(36) << #name << value << std::endl;
410426

411427
#define PRINT_PROP(struct, name) \
@@ -452,16 +468,37 @@ std::string Adapter::stringize() const {
452468
#endif /* VK_KHR_8bit_storage */
453469

454470
ss << " Shader 16bit and 8bit Features {" << std::endl;
455-
PRINT_BOOL(physical_device_.supports_int16_shader_types, shaderInt16)
471+
PRINT_VALUE(physical_device_.supports_int16_shader_types, shaderInt16)
456472
#ifdef VK_KHR_shader_float16_int8
457473
PRINT_PROP(physical_device_.shader_float16_int8_types, shaderFloat16);
458474
PRINT_PROP(physical_device_.shader_float16_int8_types, shaderInt8);
459475
#endif /* VK_KHR_shader_float16_int8 */
460476
ss << " }" << std::endl;
461477

462478
ss << " Shader 64bit Features {" << std::endl;
463-
PRINT_BOOL(physical_device_.supports_int64_shader_types, shaderInt64)
464-
PRINT_BOOL(physical_device_.supports_float64_shader_types, shaderFloat64)
479+
PRINT_VALUE(physical_device_.supports_int64_shader_types, shaderInt64)
480+
PRINT_VALUE(physical_device_.supports_float64_shader_types, shaderFloat64)
481+
ss << " }" << std::endl;
482+
483+
ss << " Subgroup Properties {" << std::endl;
484+
PRINT_VALUE(subgroup_size(), subgroupSize)
485+
PRINT_VALUE(supports_subgroup_compute_basic(), computeSubgroupBasic)
486+
PRINT_VALUE(supports_subgroup_compute_shuffle(), computeSubgroupShuffle)
487+
PRINT_VALUE(supports_subgroup_compute_ballot(), computeSubgroupBallot)
488+
PRINT_VALUE(supports_subgroup_compute_vote(), computeSubgroupVote)
489+
PRINT_VALUE(supports_subgroup_compute_arithmetic(), computeSubgroupArithmetic)
490+
PRINT_VALUE(
491+
supports_subgroup_compute_shuffle_relative(),
492+
computeSubgroupShuffleRelative)
493+
PRINT_VALUE(supports_subgroup_compute_clustered(), computeSubgroupClustered)
494+
PRINT_VALUE(supports_subgroup_compute_quad(), computeSubgroupQuad)
495+
PRINT_VALUE(min_subgroup_size(), minSubgroupSize)
496+
PRINT_VALUE(max_subgroup_size(), maxSubgroupSize)
497+
PRINT_VALUE(supports_subgroup_size_control(), subgroupSizeControl)
498+
PRINT_VALUE(supports_compute_full_subgroups(), computeFullSubgroups)
499+
PRINT_VALUE(
500+
supports_required_subgroup_size_for_compute(),
501+
requiredSubgroupSizeStages_compute)
465502
ss << " }" << std::endl;
466503
467504
#ifdef VK_KHR_shader_integer_dot_product
@@ -614,5 +651,24 @@ std::ostream& operator<<(std::ostream& os, const Adapter& adapter) {
614651
return os;
615652
}
616653
654+
uint32_t resolve_required_subgroup_size(
655+
const ShaderInfo& shader,
656+
Adapter* adapter) {
657+
if (shader.required_subgroup_size == 0u) {
658+
return 0u;
659+
}
660+
if (!adapter->supports_required_subgroup_size_for_compute()) {
661+
throw ShaderNotSupportedError(
662+
shader.kernel_name, VulkanExtension::SUBGROUP_SIZE_CONTROL);
663+
}
664+
const uint32_t resolved = shader.required_subgroup_size;
665+
if (resolved < adapter->min_subgroup_size() ||
666+
resolved > adapter->max_subgroup_size()) {
667+
throw ShaderNotSupportedError(
668+
shader.kernel_name, VulkanExtension::SUBGROUP_SIZE_CONTROL);
669+
}
670+
return resolved;
671+
}
672+
617673
} // namespace vkapi
618674
} // namespace vkcompute

0 commit comments

Comments
 (0)