Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions backends/vulkan/runtime/api/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,14 @@ vkapi::DescriptorSet Context::get_descriptor_set(

spec_constants.append(additional_constants);

const uint32_t resolved_required_subgroup_size =
vkapi::resolve_required_subgroup_size(shader_descriptor, adapter_p_);

VkPipeline pipeline = pipeline_cache().retrieve(
{pipeline_layout_cache().retrieve(shader_layout, push_constants_size),
shader_cache().retrieve(shader_descriptor),
spec_constants});
spec_constants,
resolved_required_subgroup_size});

cmd_.bind_pipeline(pipeline, pipeline_layout, local_workgroup_size);

Expand Down Expand Up @@ -315,8 +319,14 @@ VkPipeline Context::get_shader_pipeline(

spec_constants.append(additional_constants);

const uint32_t resolved_required_subgroup_size =
vkapi::resolve_required_subgroup_size(shader, adapter_p_);

VkPipeline pipeline = pipeline_cache().retrieve(
{pipeline_layout, shader_cache().retrieve(shader), spec_constants});
{pipeline_layout,
shader_cache().retrieve(shader),
spec_constants,
resolved_required_subgroup_size});

return pipeline;
}
Expand Down
61 changes: 56 additions & 5 deletions backends/vulkan/runtime/gen_vulkan_spv.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,9 @@ def layout_declare_buffer(
dtype: str,
precision: str = "PRECISION",
is_scalar_array: bool = True,
vec_size: int = 4,
) -> str:
array_type = buffer_gvec_type(dtype, 4)
array_type = buffer_gvec_type(dtype, vec_size)
if is_scalar_array:
array_type = buffer_scalar_type(dtype)

Expand Down Expand Up @@ -341,6 +342,7 @@ def layout_declare_tensor(
storage_type: str,
is_scalar_array: bool = True,
precision: str = "PRECISION",
vec_size: int = 4,
) -> str:
assert storage_type.lower() in ["buffer", "texture3d", "texture2d"]

Expand All @@ -357,6 +359,7 @@ def layout_declare_tensor(
dtype,
precision,
is_scalar_array=is_scalar_array,
vec_size=vec_size,
)

# Create image/sampler binding
Expand Down Expand Up @@ -785,6 +788,10 @@ def parseTemplateYaml(self, yaml_file: str) -> None: # noqa: C901
"generate_variant_forall", None
)

reserved_yaml_keys = {
"generate_variant_forall",
}

for variant in params_dict["shader_variants"]:
default_iterated_params_names = set(
default_iterated_params.keys()
Expand All @@ -797,7 +804,7 @@ def parseTemplateYaml(self, yaml_file: str) -> None: # noqa: C901
variant_params_names
- default_iterated_params_names
- params_names
- {"generate_variant_forall"}
- reserved_yaml_keys
)
assert len(invalid_keys) == 0

Expand All @@ -813,7 +820,7 @@ def parseTemplateYaml(self, yaml_file: str) -> None: # noqa: C901
for combination in variant_combinations:
default_params_copy = copy.deepcopy(default_params)
for key in variant:
if key != "generate_variant_forall":
if key not in reserved_yaml_keys:
default_params_copy[key] = variant[key]

variant_name = variant["NAME"]
Expand Down Expand Up @@ -842,7 +849,8 @@ def parseTemplateYaml(self, yaml_file: str) -> None: # noqa: C901
else:
default_params_copy = copy.deepcopy(default_params)
for key in variant:
default_params_copy[key] = variant[key]
if key not in reserved_yaml_keys:
default_params_copy[key] = variant[key]

self.shader_template_params[template_name].append(
default_params_copy
Expand Down Expand Up @@ -1026,6 +1034,27 @@ def generate_src_file(shader_paths_pair) -> Tuple[bool, List[str]]:
print(f"template_file_path: {template_file_path}")
output_text = preprocess(input_text, codegen_params)

# If the shader yaml declared a SUBGROUP_SIZE template parameter,
# embed it into the generated GLSL as a comment. getShaderInfo()
# parses it back out alongside TILE_SIZE, WEIGHT_STORAGE, etc.,
# avoiding a side-channel name -> value map.
subgroup_size = codegen_params.get("SUBGROUP_SIZE")
if subgroup_size is not None:
try:
subgroup_size_int = int(subgroup_size)
except (TypeError, ValueError) as e:
raise RuntimeError(
f"Shader variant {src_file_name!r} declared "
f"SUBGROUP_SIZE={subgroup_size!r}, which is not "
f"parseable as an integer. Fix the SUBGROUP_SIZE "
f"value in the shader's yaml."
) from e
if subgroup_size_int > 0:
output_text = (
f"// REQUIRED_SUBGROUP_SIZE = {subgroup_size_int}\n"
+ output_text
)

included_files = get_glsl_includes(output_text)

with codecs.open(gen_out_path, "w", encoding="utf-8") as output_file:
Expand Down Expand Up @@ -1184,6 +1213,12 @@ class ShaderInfo:
requires_integer_dot_product_ext: bool = False
requires_shader_int64_ext: bool = False
requires_shader_float64_ext: bool = False
# Subgroup size requirement (matches the C++ ShaderInfo encoding):
# 0 = no requirement
# >0 = literal fixed size; sourced from the shader yaml's
# `SUBGROUP_SIZE` template parameter (single source of truth for
# both GLSL substitution and the Vulkan pipeline pin).
required_subgroup_size: int = 0


def getName(filePath: str) -> str:
Expand All @@ -1208,6 +1243,17 @@ def findTileSizes(lineStr: str) -> List[int]:
return [int(matches.group(1)), int(matches.group(2)), int(matches.group(3))]


def isRequiredSubgroupSizeLine(lineStr: str) -> bool:
return re.search(r"^// REQUIRED_SUBGROUP_SIZE = ", lineStr) is not None


def findRequiredSubgroupSize(lineStr: str) -> int:
matches = re.search(r"^// REQUIRED_SUBGROUP_SIZE = ([0-9]+)", lineStr)
if matches is None:
raise AssertionError("matches is None in findRequiredSubgroupSize")
return int(matches.group(1))


def isWeightStorageTypeLine(lineStr: str) -> bool:
weight_storage_id = r"^ \* WEIGHT_STORAGE = "
return re.search(weight_storage_id, lineStr) is not None
Expand Down Expand Up @@ -1281,6 +1327,8 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo: # noqa: C901
shader_info.layouts.append(determineDescriptorType(line))
if isTileSizeLine(line):
shader_info.tile_size = findTileSizes(line)
if isRequiredSubgroupSizeLine(line):
shader_info.required_subgroup_size = findRequiredSubgroupSize(line)
if isWeightStorageTypeLine(line):
shader_info.weight_storage_type = getWeightStorageType(line)
if isBiasStorageTypeLine(line):
Expand Down Expand Up @@ -1378,6 +1426,7 @@ def to_cpp_str(val: bool):
to_cpp_str(shader_info.requires_integer_dot_product_ext),
to_cpp_str(shader_info.requires_shader_int64_ext),
to_cpp_str(shader_info.requires_shader_float64_ext),
str(shader_info.required_subgroup_size),
]

shader_info_str = textwrap.indent(
Expand Down Expand Up @@ -1406,7 +1455,9 @@ def generateShaderDispatchStr(shader_info: ShaderInfo, name: str) -> str:


def genCppFiles(
spv_files: Dict[str, str], cpp_header_path: str, cpp_src_file_path: str
spv_files: Dict[str, str],
cpp_header_path: str,
cpp_src_file_path: str,
) -> None:
spv_bin_strs = []
register_shader_info_strs = []
Expand Down
12 changes: 11 additions & 1 deletion backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -825,10 +825,20 @@ void ComputeGraph::register_pipeline_to_create(

spec_constants.append(spec_vars);

// Resolve any shader-declared required subgroup size into a concrete value
// so the pre-built pipeline matches the one created at dispatch time. The
// shared helper throws ShaderNotSupportedError when the adapter cannot honor
// the requirement; let it propagate so a stale unused pipeline doesn't sit
// in the cache while dispatch later throws on the same shader.
const uint32_t resolved_required_subgroup_size =
vkapi::resolve_required_subgroup_size(
shader_info, context()->adapter_ptr());

const vkapi::ComputePipelineCache::Key desc = {
context()->pipeline_layout_cache().retrieve(shader_layout, pc_offset),
context()->shader_cache().retrieve(shader_info),
spec_constants};
spec_constants,
resolved_required_subgroup_size};

if (context_->pipeline_cache().contains(desc)) {
return;
Expand Down
64 changes: 60 additions & 4 deletions backends/vulkan/runtime/vk_api/Adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ VkDevice create_logical_device(
#ifdef VK_NV_cooperative_matrix2
VK_NV_COOPERATIVE_MATRIX_2_EXTENSION_NAME,
#endif /* VK_NV_cooperative_matrix2 */
#ifdef VK_EXT_subgroup_size_control
VK_EXT_SUBGROUP_SIZE_CONTROL_EXTENSION_NAME,
#endif /* VK_EXT_subgroup_size_control */
};

std::vector<const char*> enabled_device_extensions;
Expand Down Expand Up @@ -199,6 +202,19 @@ VkDevice create_logical_device(
extension_list_top = &cooperative_matrix2_features;
#endif /* VK_NV_cooperative_matrix2 */

#ifdef VK_EXT_subgroup_size_control
// Only enable the feature struct if the extension was actually requested
// and the feature flag is set on the physical device. The extension itself
// is filtered into enabled_device_extensions by
// find_requested_device_extensions.
VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features{
physical_device.subgroup_size_control_features};
if (physical_device.supports_subgroup_size_control) {
subgroup_size_control_features.pNext = extension_list_top;
extension_list_top = &subgroup_size_control_features;
}
#endif /* VK_EXT_subgroup_size_control */

device_create_info.pNext = extension_list_top;

VkDevice handle = nullptr;
Expand Down Expand Up @@ -405,7 +421,7 @@ std::string Adapter::stringize() const {
ss << " deviceType: " << device_type << std::endl;
ss << " deviceName: " << properties.deviceName << std::endl;

#define PRINT_BOOL(value, name) \
#define PRINT_VALUE(value, name) \
ss << " " << std::left << std::setw(36) << #name << value << std::endl;

#define PRINT_PROP(struct, name) \
Expand Down Expand Up @@ -452,16 +468,37 @@ std::string Adapter::stringize() const {
#endif /* VK_KHR_8bit_storage */

ss << " Shader 16bit and 8bit Features {" << std::endl;
PRINT_BOOL(physical_device_.supports_int16_shader_types, shaderInt16)
PRINT_VALUE(physical_device_.supports_int16_shader_types, shaderInt16)
#ifdef VK_KHR_shader_float16_int8
PRINT_PROP(physical_device_.shader_float16_int8_types, shaderFloat16);
PRINT_PROP(physical_device_.shader_float16_int8_types, shaderInt8);
#endif /* VK_KHR_shader_float16_int8 */
ss << " }" << std::endl;

ss << " Shader 64bit Features {" << std::endl;
PRINT_BOOL(physical_device_.supports_int64_shader_types, shaderInt64)
PRINT_BOOL(physical_device_.supports_float64_shader_types, shaderFloat64)
PRINT_VALUE(physical_device_.supports_int64_shader_types, shaderInt64)
PRINT_VALUE(physical_device_.supports_float64_shader_types, shaderFloat64)
ss << " }" << std::endl;

ss << " Subgroup Properties {" << std::endl;
PRINT_VALUE(subgroup_size(), subgroupSize)
PRINT_VALUE(supports_subgroup_compute_basic(), computeSubgroupBasic)
PRINT_VALUE(supports_subgroup_compute_shuffle(), computeSubgroupShuffle)
PRINT_VALUE(supports_subgroup_compute_ballot(), computeSubgroupBallot)
PRINT_VALUE(supports_subgroup_compute_vote(), computeSubgroupVote)
PRINT_VALUE(supports_subgroup_compute_arithmetic(), computeSubgroupArithmetic)
PRINT_VALUE(
supports_subgroup_compute_shuffle_relative(),
computeSubgroupShuffleRelative)
PRINT_VALUE(supports_subgroup_compute_clustered(), computeSubgroupClustered)
PRINT_VALUE(supports_subgroup_compute_quad(), computeSubgroupQuad)
PRINT_VALUE(min_subgroup_size(), minSubgroupSize)
PRINT_VALUE(max_subgroup_size(), maxSubgroupSize)
PRINT_VALUE(supports_subgroup_size_control(), subgroupSizeControl)
PRINT_VALUE(supports_compute_full_subgroups(), computeFullSubgroups)
PRINT_VALUE(
supports_required_subgroup_size_for_compute(),
requiredSubgroupSizeStages_compute)
ss << " }" << std::endl;

#ifdef VK_KHR_shader_integer_dot_product
Expand Down Expand Up @@ -614,5 +651,24 @@ std::ostream& operator<<(std::ostream& os, const Adapter& adapter) {
return os;
}

uint32_t resolve_required_subgroup_size(
const ShaderInfo& shader,
Adapter* adapter) {
if (shader.required_subgroup_size == 0u) {
return 0u;
}
if (!adapter->supports_required_subgroup_size_for_compute()) {
throw ShaderNotSupportedError(
shader.kernel_name, VulkanExtension::SUBGROUP_SIZE_CONTROL);
}
const uint32_t resolved = shader.required_subgroup_size;
if (resolved < adapter->min_subgroup_size() ||
resolved > adapter->max_subgroup_size()) {
throw ShaderNotSupportedError(
shader.kernel_name, VulkanExtension::SUBGROUP_SIZE_CONTROL);
}
return resolved;
}

} // namespace vkapi
} // namespace vkcompute
Loading
Loading