@@ -8582,6 +8582,136 @@ OpCooperativeVectorReduceSumAccumulateNV %array_ptr %offset %f16c
85828582 HasSubstr (" OpCooperativeVectorReduceSumAccumulateNV V type <id> "
85838583 " '28[%v4half]' is not a cooperative vector type." ));
85848584}
8585+
8586+ TEST_F (ValidateMemory, CoopMatMatrixBFloatFAdd) {
8587+ const std::string body =
8588+ R"(
8589+ OpCapability Shader
8590+ OpCapability Float16
8591+ OpCapability BFloat16TypeKHR
8592+ OpCapability BFloat16CooperativeMatrixKHR
8593+ OpCapability VulkanMemoryModel
8594+ OpCapability CooperativeMatrixKHR
8595+ OpExtension "SPV_KHR_bfloat16"
8596+ OpExtension "SPV_KHR_vulkan_memory_model"
8597+ OpExtension "SPV_KHR_cooperative_matrix"
8598+ OpMemoryModel Logical Vulkan
8599+ OpEntryPoint GLCompute %main "main" %_ %__0 %__1
8600+ OpExecutionMode %main LocalSize 32 1 1
8601+ OpDecorate %_arr_bfloat16_uint_64 ArrayStride 2
8602+ OpDecorate %A Block
8603+ OpMemberDecorate %A 0 Offset 0
8604+ OpDecorate %_ Binding 0
8605+ OpDecorate %_ DescriptorSet 0
8606+ OpDecorate %_arr_bfloat16_uint_64_0 ArrayStride 2
8607+ OpDecorate %B Block
8608+ OpMemberDecorate %B 0 Offset 0
8609+ OpDecorate %__0 Binding 1
8610+ OpDecorate %__0 DescriptorSet 0
8611+ OpDecorate %_arr_bfloat16_uint_64_1 ArrayStride 2
8612+ OpDecorate %R Block
8613+ OpMemberDecorate %R 0 Offset 0
8614+ OpDecorate %__1 Binding 2
8615+ OpDecorate %__1 DescriptorSet 0
8616+ %void = OpTypeVoid
8617+ %4 = OpTypeFunction %void
8618+ %bfloat16 = OpTypeFloat 16 BFloat16KHR
8619+ %uint = OpTypeInt 32 0
8620+ %uint_3 = OpConstant %uint 3
8621+ %uint_8 = OpConstant %uint 8
8622+ %uint_0 = OpConstant %uint 0
8623+ %12 = OpTypeCooperativeMatrixKHR %bfloat16 %uint_3 %uint_8 %uint_8 %uint_0
8624+ %_ptr_Function_12 = OpTypePointer Function %12
8625+ %uint_64 = OpConstant %uint 64
8626+ %_arr_bfloat16_uint_64 = OpTypeArray %bfloat16 %uint_64
8627+ %A = OpTypeStruct %_arr_bfloat16_uint_64
8628+ %_ptr_StorageBuffer_A = OpTypePointer StorageBuffer %A
8629+ %_ = OpVariable %_ptr_StorageBuffer_A StorageBuffer
8630+ %int = OpTypeInt 32 1
8631+ %int_0 = OpConstant %int 0
8632+ %_ptr_StorageBuffer_bfloat16 = OpTypePointer StorageBuffer %bfloat16
8633+ %_arr_bfloat16_uint_64_0 = OpTypeArray %bfloat16 %uint_64
8634+ %B = OpTypeStruct %_arr_bfloat16_uint_64_0
8635+ %_ptr_StorageBuffer_B = OpTypePointer StorageBuffer %B
8636+ %__0 = OpVariable %_ptr_StorageBuffer_B StorageBuffer
8637+ %v3uint = OpTypeVector %uint 3
8638+ %uint_32 = OpConstant %uint 32
8639+ %uint_1 = OpConstant %uint 1
8640+ %35 = OpConstantComposite %v3uint %uint_32 %uint_1 %uint_1
8641+ %_arr_bfloat16_uint_64_1 = OpTypeArray %bfloat16 %uint_64
8642+ %R = OpTypeStruct %_arr_bfloat16_uint_64_1
8643+ %_ptr_StorageBuffer_R = OpTypePointer StorageBuffer %R
8644+ %__1 = OpVariable %_ptr_StorageBuffer_R StorageBuffer
8645+ %main = OpFunction %void None %4
8646+ %6 = OpLabel
8647+ %matX = OpVariable %_ptr_Function_12 Function
8648+ %matY = OpVariable %_ptr_Function_12 Function
8649+ %23 = OpAccessChain %_ptr_StorageBuffer_bfloat16 %_ %int_0 %uint_0
8650+ %24 = OpCooperativeMatrixLoadKHR %12 %23 %int_0 %uint_8 None
8651+ OpStore %matX %24
8652+ %30 = OpAccessChain %_ptr_StorageBuffer_bfloat16 %__0 %int_0 %uint_0
8653+ %31 = OpCooperativeMatrixLoadKHR %12 %30 %int_0 %uint_8 None
8654+ OpStore %matY %31
8655+ %32 = OpLoad %12 %matX
8656+ %33 = OpLoad %12 %matY
8657+ %34 = OpFAdd %12 %32 %33
8658+ OpReturn
8659+ OpFunctionEnd
8660+ )" ;
8661+
8662+ CompileSuccessfully (body.c_str (), SPV_ENV_VULKAN_1_3);
8663+ ASSERT_EQ (SPV_ERROR_INVALID_DATA, ValidateInstructions (SPV_ENV_VULKAN_1_3));
8664+ EXPECT_THAT (getDiagnosticString (),
8665+ HasSubstr (" FAdd doesn't support BFloat16 type" ));
8666+ }
8667+
8668+ TEST_F (ValidateMemory, CoopMatMatrixFloat8FAdd) {
8669+ const std::string body =
8670+ R"(
8671+ OpCapability Shader
8672+ OpCapability Float8EXT
8673+ OpCapability Float8CooperativeMatrixEXT
8674+ OpCapability VulkanMemoryModel
8675+ OpCapability CooperativeMatrixKHR
8676+ OpExtension "SPV_EXT_float8"
8677+ OpExtension "SPV_KHR_cooperative_matrix"
8678+ OpExtension "SPV_KHR_vulkan_memory_model"
8679+ OpMemoryModel Logical Vulkan
8680+ OpEntryPoint GLCompute %main "main"
8681+ OpExecutionMode %main LocalSize 32 1 1
8682+ OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize
8683+ %void = OpTypeVoid
8684+ %4 = OpTypeFunction %void
8685+ %fp8e4m3 = OpTypeFloat 8 Float8E4M3EXT
8686+ %uint = OpTypeInt 32 0
8687+ %uint_3 = OpConstant %uint 3
8688+ %uint_16 = OpConstant %uint 16
8689+ %uint_0 = OpConstant %uint 0
8690+ %12 = OpTypeCooperativeMatrixKHR %fp8e4m3 %uint_3 %uint_16 %uint_16 %uint_0
8691+ %_ptr_Function_12 = OpTypePointer Function %12
8692+ %v3uint = OpTypeVector %uint 3
8693+ %uint_32 = OpConstant %uint 32
8694+ %uint_1 = OpConstant %uint 1
8695+ %gl_WorkGroupSize = OpConstantComposite %v3uint %uint_32 %uint_1 %uint_1
8696+ %main = OpFunction %void None %4
8697+ %6 = OpLabel
8698+ %matR = OpVariable %_ptr_Function_12 Function
8699+ %matX = OpVariable %_ptr_Function_12 Function
8700+ %matY = OpVariable %_ptr_Function_12 Function
8701+ %16 = OpLoad %12 %matX
8702+ %18 = OpLoad %12 %matY
8703+ %19 = OpFAdd %12 %16 %18
8704+ OpStore %matR %19
8705+ OpReturn
8706+ OpFunctionEnd
8707+ )" ;
8708+
8709+ CompileSuccessfully (body.c_str (), SPV_ENV_VULKAN_1_3);
8710+ ASSERT_EQ (SPV_ERROR_INVALID_DATA, ValidateInstructions (SPV_ENV_VULKAN_1_3));
8711+ EXPECT_THAT (getDiagnosticString (),
8712+ HasSubstr (" FAdd doesn't support FP8 E4M3/E5M2 types" ));
8713+ }
8714+
85858715} // namespace
85868716} // namespace val
85878717} // namespace spvtools
0 commit comments