diff --git a/src/layer/vulkan/gemm_vulkan.cpp b/src/layer/vulkan/gemm_vulkan.cpp index d1b9991d5130..682729c95550 100644 --- a/src/layer/vulkan/gemm_vulkan.cpp +++ b/src/layer/vulkan/gemm_vulkan.cpp @@ -61,7 +61,7 @@ int Gemm_vulkan::create_pipeline(const Option& opt) use_cooperative_matrix = vkdev->info.support_cooperative_matrix() && opt.use_cooperative_matrix && (opt.use_fp16_storage || opt.use_fp16_packed); bool use_bf16_cooperative_matrix = false; - if (vkdev->info.support_bf16_cooperative_matrix() && opt.use_cooperative_matrix && (opt.use_bf16_storage || opt.use_bf16_packed)) + if (vkdev->info.support_bf16_cooperative_matrix() && opt.use_cooperative_matrix && opt.use_bf16_storage) { use_cooperative_matrix = true; use_bf16_cooperative_matrix = true; diff --git a/src/layer/vulkan/sdpa_vulkan.cpp b/src/layer/vulkan/sdpa_vulkan.cpp index 6a3bbe4fedfd..337d2de19c0b 100644 --- a/src/layer/vulkan/sdpa_vulkan.cpp +++ b/src/layer/vulkan/sdpa_vulkan.cpp @@ -61,7 +61,7 @@ int SDPA_vulkan::create_pipeline(const Option& opt) use_cooperative_matrix = vkdev->info.support_cooperative_matrix() && opt.use_cooperative_matrix && (opt.use_fp16_storage || opt.use_fp16_packed); bool use_bf16_cooperative_matrix = false; - if (vkdev->info.support_bf16_cooperative_matrix() && opt.use_cooperative_matrix && (opt.use_bf16_storage || opt.use_bf16_packed)) + if (vkdev->info.support_bf16_cooperative_matrix() && opt.use_cooperative_matrix && opt.use_bf16_storage) { use_cooperative_matrix = true; use_bf16_cooperative_matrix = true; diff --git a/src/layer/vulkan/shader/gemm_cm.comp b/src/layer/vulkan/shader/gemm_cm.comp index 8069dc647f29..e214ffa41347 100644 --- a/src/layer/vulkan/shader/gemm_cm.comp +++ b/src/layer/vulkan/shader/gemm_cm.comp @@ -199,7 +199,7 @@ void main() [[unroll]] for (uint zm = 0; zm < UNROLL_SG_M; zm++) { #if ncnn_VK_KHR_cooperative_matrix -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage coopmat bias; #else coopmat bias; @@ -265,7 +265,7 @@ void main() if (gn * 4 + 2 < psc(GN)) vb.r = float(buffer_ld1(C_blob_data, ci4.b)); if (gn * 4 + 3 < psc(GN)) vb.g = float(buffer_ld1(C_blob_data, ci4.a)); -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage uvec2 v = uvec2(packBFloat2x16(va), packBFloat2x16(vb)); #else uvec2 v = uvec2(packHalf2x16(va), packHalf2x16(vb)); @@ -286,7 +286,7 @@ void main() #if NCNN_fp16_arithmetic coopMatLoad(sum[zn][zm], tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Nd4 * M), Nd4, gl_CooperativeMatrixLayoutRowMajor); #else -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage coopmat sum_fp16; #else coopmat sum_fp16; @@ -311,7 +311,7 @@ void main() [[unroll]] for (uint zn = 0; zn < UNROLL_SG_N; zn++) { #if ncnn_VK_KHR_cooperative_matrix -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage coopmat bias; #else coopmat bias; @@ -456,7 +456,7 @@ void main() const uvec4 ai4m4d2 = (ai4 % 4) / 2; const uvec4 ai4m2 = ai4 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gk4.r < psc(GK)) v4a.r = unpackBFloat2x16(A_blob_data[ai4d4.r][ai4m4d2.r])[ai4m2.r]; if (gk4.g < psc(GK)) v4a.g = unpackBFloat2x16(A_blob_data[ai4d4.g][ai4m4d2.g])[ai4m2.g]; if (gk4.b < psc(GK)) v4b.r = unpackBFloat2x16(A_blob_data[ai4d4.b][ai4m4d2.b])[ai4m2.b]; @@ -583,7 +583,7 @@ void main() const uvec4 ai4m4d2 = (ai4 % 4) / 2; const uvec4 ai4m2 = ai4 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gm4.r < psc(GM)) v4a.r = unpackBFloat2x16(A_blob_data[ai4d4.r][ai4m4d2.r])[ai4m2.r]; if (gm4.g < psc(GM)) v4a.g = unpackBFloat2x16(A_blob_data[ai4d4.g][ai4m4d2.g])[ai4m2.g]; if (gm4.b < psc(GM)) v4b.r = unpackBFloat2x16(A_blob_data[ai4d4.b][ai4m4d2.b])[ai4m2.b]; @@ -728,7 +728,7 @@ void main() const uvec4 bi4m4d2 = (bi4 % 4) / 2; const uvec4 bi4m2 = bi4 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gn4.r < psc(GN)) v4a.r = unpackBFloat2x16(B_blob_data[bi4d4.r][bi4m4d2.r])[bi4m2.r]; if (gn4.g < psc(GN)) v4a.g = unpackBFloat2x16(B_blob_data[bi4d4.g][bi4m4d2.g])[bi4m2.g]; if (gn4.b < psc(GN)) v4b.r = unpackBFloat2x16(B_blob_data[bi4d4.b][bi4m4d2.b])[bi4m2.b]; @@ -856,7 +856,7 @@ void main() const uvec4 bi4m4d2 = (bi4 % 4) / 2; const uvec4 bi4m2 = bi4 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gk4.r < psc(GK)) v4a.r = unpackBFloat2x16(B_blob_data[bi4d4.r][bi4m4d2.r])[bi4m2.r]; if (gk4.g < psc(GK)) v4a.g = unpackBFloat2x16(B_blob_data[bi4d4.g][bi4m4d2.g])[bi4m2.g]; if (gk4.b < psc(GK)) v4b.r = unpackBFloat2x16(B_blob_data[bi4d4.b][bi4m4d2.b])[bi4m2.b]; @@ -1171,7 +1171,7 @@ void main() const uvec4 ai4m4d2 = (ai4 % 4) / 2; const uvec4 ai4m2 = ai4 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gk4.r < psc(GK)) v4a.r = unpackBFloat2x16(A_blob_data[ai4d4.r][ai4m4d2.r])[ai4m2.r]; if (gk4.g < psc(GK)) v4a.g = unpackBFloat2x16(A_blob_data[ai4d4.g][ai4m4d2.g])[ai4m2.g]; if (gk4.b < psc(GK)) v4b.r = unpackBFloat2x16(A_blob_data[ai4d4.b][ai4m4d2.b])[ai4m2.b]; @@ -1298,7 +1298,7 @@ void main() const uvec4 ai4m4d2 = (ai4 % 4) / 2; const uvec4 ai4m2 = ai4 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gm4.r < psc(GM)) v4a.r = unpackBFloat2x16(A_blob_data[ai4d4.r][ai4m4d2.r])[ai4m2.r]; if (gm4.g < psc(GM)) v4a.g = unpackBFloat2x16(A_blob_data[ai4d4.g][ai4m4d2.g])[ai4m2.g]; if (gm4.b < psc(GM)) v4b.r = unpackBFloat2x16(A_blob_data[ai4d4.b][ai4m4d2.b])[ai4m2.b]; @@ -1443,7 +1443,7 @@ void main() const uvec4 bi4m4d2 = (bi4 % 4) / 2; const uvec4 bi4m2 = bi4 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gn4.r < psc(GN)) v4a.r = unpackBFloat2x16(B_blob_data[bi4d4.r][bi4m4d2.r])[bi4m2.r]; if (gn4.g < psc(GN)) v4a.g = unpackBFloat2x16(B_blob_data[bi4d4.g][bi4m4d2.g])[bi4m2.g]; if (gn4.b < psc(GN)) v4b.r = unpackBFloat2x16(B_blob_data[bi4d4.b][bi4m4d2.b])[bi4m2.b]; @@ -1570,7 +1570,7 @@ void main() const uvec4 bi4m4d2 = (bi4 % 4) / 2; const uvec4 bi4m2 = bi4 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gk4.r < psc(GK)) v4a.r = unpackBFloat2x16(B_blob_data[bi4d4.r][bi4m4d2.r])[bi4m2.r]; if (gk4.g < psc(GK)) v4a.g = unpackBFloat2x16(B_blob_data[bi4d4.g][bi4m4d2.g])[bi4m2.g]; if (gk4.b < psc(GK)) v4b.r = unpackBFloat2x16(B_blob_data[bi4d4.b][bi4m4d2.b])[bi4m2.b]; @@ -1633,7 +1633,7 @@ void main() } #if ncnn_VK_KHR_cooperative_matrix -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage coopmat A[UNROLL_SG_M]; coopmat B[UNROLL_SG_N]; #else @@ -1923,7 +1923,7 @@ void main() barrier(); #if ncnn_VK_KHR_cooperative_matrix -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage coopmat A[UNROLL_SG_M]; coopmat B[UNROLL_SG_N]; #else @@ -2148,7 +2148,7 @@ void main() const uvec4 ai4m4d2 = (ai4 % 4) / 2; const uvec4 ai4m2 = ai4 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gk4.r < psc(GK)) v4a.r = unpackBFloat2x16(A_blob_data[ai4d4.r][ai4m4d2.r])[ai4m2.r]; if (gk4.g < psc(GK)) v4a.g = unpackBFloat2x16(A_blob_data[ai4d4.g][ai4m4d2.g])[ai4m2.g]; if (gk4.b < psc(GK)) v4b.r = unpackBFloat2x16(A_blob_data[ai4d4.b][ai4m4d2.b])[ai4m2.b]; @@ -2275,7 +2275,7 @@ void main() const uvec4 ai4m4d2 = (ai4 % 4) / 2; const uvec4 ai4m2 = ai4 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gm4.r < psc(GM)) v4a.r = unpackBFloat2x16(A_blob_data[ai4d4.r][ai4m4d2.r])[ai4m2.r]; if (gm4.g < psc(GM)) v4a.g = unpackBFloat2x16(A_blob_data[ai4d4.g][ai4m4d2.g])[ai4m2.g]; if (gm4.b < psc(GM)) v4b.r = unpackBFloat2x16(A_blob_data[ai4d4.b][ai4m4d2.b])[ai4m2.b]; @@ -2427,7 +2427,7 @@ void main() const uvec4 bi4m4d2 = (bi4 % 4) / 2; const uvec4 bi4m2 = bi4 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gn4.r < psc(GN)) v4a.r = unpackBFloat2x16(B_blob_data[bi4d4.r][bi4m4d2.r])[bi4m2.r]; if (gn4.g < psc(GN)) v4a.g = unpackBFloat2x16(B_blob_data[bi4d4.g][bi4m4d2.g])[bi4m2.g]; if (gn4.b < psc(GN)) v4b.r = unpackBFloat2x16(B_blob_data[bi4d4.b][bi4m4d2.b])[bi4m2.b]; @@ -2554,7 +2554,7 @@ void main() const uvec4 bi4m4d2 = (bi4 % 4) / 2; const uvec4 bi4m2 = bi4 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gk4.r < psc(GK)) v4a.r = unpackBFloat2x16(B_blob_data[bi4d4.r][bi4m4d2.r])[bi4m2.r]; if (gk4.g < psc(GK)) v4a.g = unpackBFloat2x16(B_blob_data[bi4d4.g][bi4m4d2.g])[bi4m2.g]; if (gk4.b < psc(GK)) v4b.r = unpackBFloat2x16(B_blob_data[bi4d4.b][bi4m4d2.b])[bi4m2.b]; @@ -2618,7 +2618,7 @@ void main() barrier(); #if ncnn_VK_KHR_cooperative_matrix -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage coopmat A[UNROLL_SG_M]; coopmat B[UNROLL_SG_N]; #else @@ -2833,7 +2833,7 @@ void main() const uvec4 ai4m4d2 = (ai4 % 4) / 2; const uvec4 ai4m2 = ai4 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gk4.r < psc(GK)) v4a.r = unpackBFloat2x16(A_blob_data[ai4d4.r][ai4m4d2.r])[ai4m2.r]; if (gk4.g < psc(GK)) v4a.g = unpackBFloat2x16(A_blob_data[ai4d4.g][ai4m4d2.g])[ai4m2.g]; if (gk4.b < psc(GK)) v4b.r = unpackBFloat2x16(A_blob_data[ai4d4.b][ai4m4d2.b])[ai4m2.b]; @@ -2948,7 +2948,7 @@ void main() const uvec4 ai4m4d2 = (ai4 % 4) / 2; const uvec4 ai4m2 = ai4 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gm4.r < psc(GM)) v4a.r = unpackBFloat2x16(A_blob_data[ai4d4.r][ai4m4d2.r])[ai4m2.r]; if (gm4.g < psc(GM)) v4a.g = unpackBFloat2x16(A_blob_data[ai4d4.g][ai4m4d2.g])[ai4m2.g]; if (gm4.b < psc(GM)) v4b.r = unpackBFloat2x16(A_blob_data[ai4d4.b][ai4m4d2.b])[ai4m2.b]; @@ -3097,7 +3097,7 @@ void main() const uvec4 bi4m4d2 = (bi4 % 4) / 2; const uvec4 bi4m2 = bi4 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gn4.r < psc(GN)) v4a.r = unpackBFloat2x16(B_blob_data[bi4d4.r][bi4m4d2.r])[bi4m2.r]; if (gn4.g < psc(GN)) v4a.g = unpackBFloat2x16(B_blob_data[bi4d4.g][bi4m4d2.g])[bi4m2.g]; if (gn4.b < psc(GN)) v4b.r = unpackBFloat2x16(B_blob_data[bi4d4.b][bi4m4d2.b])[bi4m2.b]; @@ -3212,7 +3212,7 @@ void main() const uvec4 bi4m4d2 = (bi4 % 4) / 2; const uvec4 bi4m2 = bi4 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gk4.r < psc(GK)) v4a.r = unpackBFloat2x16(B_blob_data[bi4d4.r][bi4m4d2.r])[bi4m2.r]; if (gk4.g < psc(GK)) v4a.g = unpackBFloat2x16(B_blob_data[bi4d4.g][bi4m4d2.g])[bi4m2.g]; if (gk4.b < psc(GK)) v4b.r = unpackBFloat2x16(B_blob_data[bi4d4.b][bi4m4d2.b])[bi4m2.b]; @@ -3274,7 +3274,7 @@ void main() barrier(); #if ncnn_VK_KHR_cooperative_matrix -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage coopmat A[UNROLL_SG_M]; coopmat B[UNROLL_SG_N]; #else @@ -3423,7 +3423,7 @@ void main() #if NCNN_fp16_arithmetic coopMatStore(sum[zn][zm], tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Md4 * N), Md4, gl_CooperativeMatrixLayoutColumnMajor); #else -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage coopmat sum_fp16 = coopmat(sum[zn][zm]); #else coopmat sum_fp16 = coopmat(sum[zn][zm]); @@ -3445,7 +3445,7 @@ void main() #if NCNN_fp16_arithmetic coopMatStore(sum[zn][zm], tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Nd4 * M), Nd4, gl_CooperativeMatrixLayoutRowMajor); #else -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage coopmat sum_fp16 = coopmat(sum[zn][zm]); #else coopmat sum_fp16 = coopmat(sum[zn][zm]); @@ -3470,7 +3470,7 @@ void main() #if NCNN_fp16_arithmetic coopMatStore(sum[zn][zm], tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Nd4 * M), Nd4, gl_CooperativeMatrixLayoutRowMajor); #else -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage coopmat sum_fp16 = coopmat(sum[zn][zm]); #else coopmat sum_fp16 = coopmat(sum[zn][zm]); @@ -3492,7 +3492,7 @@ void main() #if NCNN_fp16_arithmetic coopMatStore(sum[zn][zm], tmp_o[sgi], (zn * UNROLL_SG_M + zm) * (Md4 * N), Md4, gl_CooperativeMatrixLayoutColumnMajor); #else -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage coopmat sum_fp16 = coopmat(sum[zn][zm]); #else coopmat sum_fp16 = coopmat(sum[zn][zm]); @@ -3606,7 +3606,7 @@ void main() { uvec2 v = tmp_o[sgi][siq]; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage afpvec4 vab = afpvec4(unpackBFloat2x16(v.r), unpackBFloat2x16(v.g)); #else afpvec4 vab = afpvec4(unpackHalf2x16(v.r), unpackHalf2x16(v.g)); @@ -3714,7 +3714,7 @@ void main() { uvec2 v = tmp_o[sgi][siq]; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage afpvec4 vab = afpvec4(unpackBFloat2x16(v.r), unpackBFloat2x16(v.g)); #else afpvec4 vab = afpvec4(unpackHalf2x16(v.r), unpackHalf2x16(v.g)); diff --git a/src/layer/vulkan/shader/sdpa_cross_cm.comp b/src/layer/vulkan/shader/sdpa_cross_cm.comp index 50524adca9de..f1569f4d90fe 100644 --- a/src/layer/vulkan/shader/sdpa_cross_cm.comp +++ b/src/layer/vulkan/shader/sdpa_cross_cm.comp @@ -253,7 +253,7 @@ void main() const uvec4 ai8m8d2 = (ai8 % 8) / 2; const uvec4 ai8m2 = ai8 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gk4.r < psc(GK)) v4.r = unpackBFloat2x16(A_blob_data[ai4d8.r][ai4m8d2.r])[ai4m2.r]; if (gk4.g < psc(GK)) v4.g = unpackBFloat2x16(A_blob_data[ai4d8.g][ai4m8d2.g])[ai4m2.g]; if (gk4.b < psc(GK)) v4.b = unpackBFloat2x16(A_blob_data[ai4d8.b][ai4m8d2.b])[ai4m2.b]; @@ -352,7 +352,7 @@ void main() const uvec4 bi8m8d2 = (bi8 % 8) / 2; const uvec4 bi8m2 = bi8 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gn4.r < psc(GN)) v4.r = unpackBFloat2x16(B_blob_data[bi4d8.r][bi4m8d2.r])[bi4m2.r]; if (gn4.g < psc(GN)) v4.g = unpackBFloat2x16(B_blob_data[bi4d8.g][bi4m8d2.g])[bi4m2.g]; if (gn4.b < psc(GN)) v4.b = unpackBFloat2x16(B_blob_data[bi4d8.b][bi4m8d2.b])[bi4m2.b]; @@ -457,7 +457,7 @@ void main() const uvec4 bi8m8d2 = (bi8 % 8) / 2; const uvec4 bi8m2 = bi8 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gk4.r < psc(GK)) v4.r = unpackBFloat2x16(B_blob_data[bi4d8.r][bi4m8d2.r])[bi4m2.r]; if (gk4.g < psc(GK)) v4.g = unpackBFloat2x16(B_blob_data[bi4d8.g][bi4m8d2.g])[bi4m2.g]; if (gk4.b < psc(GK)) v4.b = unpackBFloat2x16(B_blob_data[bi4d8.b][bi4m8d2.b])[bi4m2.b]; @@ -633,7 +633,7 @@ void main() const uvec4 ai8m8d2 = (ai8 % 8) / 2; const uvec4 ai8m2 = ai8 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gk4.r < psc(GK)) v4.r = unpackBFloat2x16(A_blob_data[ai4d8.r][ai4m8d2.r])[ai4m2.r]; if (gk4.g < psc(GK)) v4.g = unpackBFloat2x16(A_blob_data[ai4d8.g][ai4m8d2.g])[ai4m2.g]; if (gk4.b < psc(GK)) v4.b = unpackBFloat2x16(A_blob_data[ai4d8.b][ai4m8d2.b])[ai4m2.b]; @@ -732,7 +732,7 @@ void main() const uvec4 bi8m8d2 = (bi8 % 8) / 2; const uvec4 bi8m2 = bi8 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gn4.r < psc(GN)) v4.r = unpackBFloat2x16(B_blob_data[bi4d8.r][bi4m8d2.r])[bi4m2.r]; if (gn4.g < psc(GN)) v4.g = unpackBFloat2x16(B_blob_data[bi4d8.g][bi4m8d2.g])[bi4m2.g]; if (gn4.b < psc(GN)) v4.b = unpackBFloat2x16(B_blob_data[bi4d8.b][bi4m8d2.b])[bi4m2.b]; @@ -837,7 +837,7 @@ void main() const uvec4 bi8m8d2 = (bi8 % 8) / 2; const uvec4 bi8m2 = bi8 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gk4.r < psc(GK)) v4.r = unpackBFloat2x16(B_blob_data[bi4d8.r][bi4m8d2.r])[bi4m2.r]; if (gk4.g < psc(GK)) v4.g = unpackBFloat2x16(B_blob_data[bi4d8.g][bi4m8d2.g])[bi4m2.g]; if (gk4.b < psc(GK)) v4.b = unpackBFloat2x16(B_blob_data[bi4d8.b][bi4m8d2.b])[bi4m2.b]; @@ -872,7 +872,7 @@ void main() } #if ncnn_VK_KHR_cooperative_matrix -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage coopmat A[UNROLL_SG_M]; coopmat B[UNROLL_SG_N]; #else @@ -993,7 +993,7 @@ void main() barrier(); #if ncnn_VK_KHR_cooperative_matrix -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage coopmat A[UNROLL_SG_M]; coopmat B[UNROLL_SG_N]; #else @@ -1131,7 +1131,7 @@ void main() const uvec4 ai8m8d2 = (ai8 % 8) / 2; const uvec4 ai8m2 = ai8 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gk4.r < psc(GK)) v4.r = unpackBFloat2x16(A_blob_data[ai4d8.r][ai4m8d2.r])[ai4m2.r]; if (gk4.g < psc(GK)) v4.g = unpackBFloat2x16(A_blob_data[ai4d8.g][ai4m8d2.g])[ai4m2.g]; if (gk4.b < psc(GK)) v4.b = unpackBFloat2x16(A_blob_data[ai4d8.b][ai4m8d2.b])[ai4m2.b]; @@ -1230,7 +1230,7 @@ void main() const uvec4 bi8m8d2 = (bi8 % 8) / 2; const uvec4 bi8m2 = bi8 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gn4.r < psc(GN)) v4.r = unpackBFloat2x16(B_blob_data[bi4d8.r][bi4m8d2.r])[bi4m2.r]; if (gn4.g < psc(GN)) v4.g = unpackBFloat2x16(B_blob_data[bi4d8.g][bi4m8d2.g])[bi4m2.g]; if (gn4.b < psc(GN)) v4.b = unpackBFloat2x16(B_blob_data[bi4d8.b][bi4m8d2.b])[bi4m2.b]; @@ -1335,7 +1335,7 @@ void main() const uvec4 bi8m8d2 = (bi8 % 8) / 2; const uvec4 bi8m2 = bi8 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gk4.r < psc(GK)) v4.r = unpackBFloat2x16(B_blob_data[bi4d8.r][bi4m8d2.r])[bi4m2.r]; if (gk4.g < psc(GK)) v4.g = unpackBFloat2x16(B_blob_data[bi4d8.g][bi4m8d2.g])[bi4m2.g]; if (gk4.b < psc(GK)) v4.b = unpackBFloat2x16(B_blob_data[bi4d8.b][bi4m8d2.b])[bi4m2.b]; @@ -1371,7 +1371,7 @@ void main() barrier(); #if ncnn_VK_KHR_cooperative_matrix -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage coopmat A[UNROLL_SG_M]; coopmat B[UNROLL_SG_N]; #else @@ -1502,7 +1502,7 @@ void main() const uvec4 ai8m8d2 = (ai8 % 8) / 2; const uvec4 ai8m2 = ai8 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gk4.r < psc(GK)) v4.r = unpackBFloat2x16(A_blob_data[ai4d8.r][ai4m8d2.r])[ai4m2.r]; if (gk4.g < psc(GK)) v4.g = unpackBFloat2x16(A_blob_data[ai4d8.g][ai4m8d2.g])[ai4m2.g]; if (gk4.b < psc(GK)) v4.b = unpackBFloat2x16(A_blob_data[ai4d8.b][ai4m8d2.b])[ai4m2.b]; @@ -1591,7 +1591,7 @@ void main() const uvec4 bi8m8d2 = (bi8 % 8) / 2; const uvec4 bi8m2 = bi8 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gn4.r < psc(GN)) v4.r = unpackBFloat2x16(B_blob_data[bi4d8.r][bi4m8d2.r])[bi4m2.r]; if (gn4.g < psc(GN)) v4.g = unpackBFloat2x16(B_blob_data[bi4d8.g][bi4m8d2.g])[bi4m2.g]; if (gn4.b < psc(GN)) v4.b = unpackBFloat2x16(B_blob_data[bi4d8.b][bi4m8d2.b])[bi4m2.b]; @@ -1686,7 +1686,7 @@ void main() const uvec4 bi8m8d2 = (bi8 % 8) / 2; const uvec4 bi8m2 = bi8 % 2; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage if (gk4.r < psc(GK)) v4.r = unpackBFloat2x16(B_blob_data[bi4d8.r][bi4m8d2.r])[bi4m2.r]; if (gk4.g < psc(GK)) v4.g = unpackBFloat2x16(B_blob_data[bi4d8.g][bi4m8d2.g])[bi4m2.g]; if (gk4.b < psc(GK)) v4.b = unpackBFloat2x16(B_blob_data[bi4d8.b][bi4m8d2.b])[bi4m2.b]; @@ -1722,7 +1722,7 @@ void main() barrier(); #if ncnn_VK_KHR_cooperative_matrix -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage coopmat A[UNROLL_SG_M]; coopmat B[UNROLL_SG_N]; #else @@ -1832,7 +1832,7 @@ void main() if (gn * 8 + 6 < psc(GN)) vd.r = float(buffer_ld1(attn_mask_blob_data, ci8.b)); if (gn * 8 + 7 < psc(GN)) vd.g = float(buffer_ld1(attn_mask_blob_data, ci8.a)); -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage uvec4 v = uvec4(packBFloat2x16(va), packBFloat2x16(vb), packBFloat2x16(vc), packBFloat2x16(vd)); #else uvec4 v = uvec4(packHalf2x16(va), packHalf2x16(vb), packHalf2x16(vc), packHalf2x16(vd)); @@ -1863,7 +1863,7 @@ void main() #if NCNN_fp16_arithmetic coopMatLoad(mask, tmp_o, ((sgi * UNROLL_SG_N + zn) * UNROLL_SG_M + zm) * (Nd8p * M), Nd8p, gl_CooperativeMatrixLayoutRowMajor); #else -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage coopmat mask_fp16; #else coopmat mask_fp16; @@ -1894,7 +1894,7 @@ void main() #if NCNN_fp16_arithmetic coopMatStore(sum[zn][zm], tmp_o, ((sgi * UNROLL_SG_N + zn) * UNROLL_SG_M + zm) * (Nd8p * M), Nd8p, gl_CooperativeMatrixLayoutRowMajor); #else -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage coopmat sum_fp16 = coopmat(sum[zn][zm]); #else coopmat sum_fp16 = coopmat(sum[zn][zm]); @@ -1972,7 +1972,7 @@ void main() { uvec4 v = tmp_o[(((sgi * UNROLL_SG_N + zn) * UNROLL_SG_M + zm) * M + i) * Nd8p + j]; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage afpvec4 vab = afpvec4(unpackBFloat2x16(v.r), unpackBFloat2x16(v.g)); afpvec4 vcd = afpvec4(unpackBFloat2x16(v.b), unpackBFloat2x16(v.a)); #else diff --git a/src/layer/vulkan/shader/sdpa_fa_cm.comp b/src/layer/vulkan/shader/sdpa_fa_cm.comp index 6562fe424730..0841652cf489 100644 --- a/src/layer/vulkan/shader/sdpa_fa_cm.comp +++ b/src/layer/vulkan/shader/sdpa_fa_cm.comp @@ -229,7 +229,7 @@ void main() barrier(); #if ncnn_VK_KHR_cooperative_matrix -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage coopmat qm[UNROLL_SG_M]; coopmat km; #else @@ -491,7 +491,7 @@ void main() coopmat a; coopMatLoad(a, tmp_s, ((sgi * UNROLL_SG_M + zm) * UNROLL_P_N + zp) * M * Np, Np, gl_CooperativeMatrixLayoutRowMajor); -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage coopmat b = coopmat(a); #else coopmat b = coopmat(a); @@ -503,7 +503,7 @@ void main() barrier(); // load P -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage coopmat pm[UNROLL_SG_M][UNROLL_P_N]; #else coopmat pm[UNROLL_SG_M][UNROLL_P_N]; @@ -561,7 +561,7 @@ void main() // load V #if ncnn_VK_KHR_cooperative_matrix -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage coopmat vm; #else coopmat vm; @@ -593,14 +593,14 @@ void main() [[dont_unroll]] for (; j < dst_seqlen_d16; j++) { #if ncnn_VK_KHR_cooperative_matrix - coopmat qkm[UNROLL_SG_M]; + coopmat qkm[UNROLL_SG_M]; #elif ncnn_VK_NV_cooperative_matrix fcoopmatNV<32, gl_ScopeSubgroup, M, N> qkm[UNROLL_SG_M]; #endif [[unroll]] for (uint zm = 0; zm < UNROLL_SG_M; zm++) { #if ncnn_VK_KHR_cooperative_matrix - qkm[zm] = coopmat(0.f); + qkm[zm] = coopmat(0.f); #elif ncnn_VK_NV_cooperative_matrix qkm[zm] = fcoopmatNV<32, gl_ScopeSubgroup, M, N>(0.f); #endif @@ -654,7 +654,7 @@ void main() barrier(); #if ncnn_VK_KHR_cooperative_matrix -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage coopmat qm; coopmat km; #else @@ -841,7 +841,7 @@ void main() #if ncnn_VK_KHR_cooperative_matrix [[unroll]] for (uint zm = 0; zm < UNROLL_SG_M; zm++) { - coopmat cc; + coopmat cc; coopMatLoad(cc, smem_correction, (sgi * UNROLL_SG_M + zm) * M, 0, gl_CooperativeMatrixLayoutColumnMajor); [[unroll]] for (uint c = 0; c < MAX_OUT_CHUNKS; c++) @@ -888,10 +888,10 @@ void main() // convert P from fp32 to fp16 [[unroll]] for (uint zm = 0; zm < UNROLL_SG_M; zm++) { - coopmat a; + coopmat a; coopMatLoad(a, tmp_s, (sgi * UNROLL_SG_M + zm) * M * Np, Np, gl_CooperativeMatrixLayoutRowMajor); -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage coopmat b = coopmat(a); #else coopmat b = coopmat(a); @@ -902,7 +902,7 @@ void main() barrier(); // load P -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage coopmat pm[UNROLL_SG_M]; #else coopmat pm[UNROLL_SG_M]; @@ -951,7 +951,7 @@ void main() // load V #if ncnn_VK_KHR_cooperative_matrix -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage coopmat vm; #else coopmat vm; @@ -1012,7 +1012,7 @@ void main() vec4 v1 = vec4(tmp_o[oi8.r], tmp_o[oi8.g], tmp_o[oi8.b], tmp_o[oi8.a]) * inv_sum; uvec4 out_data; -#if NCNN_bf16_storage || NCNN_bf16_packed +#if NCNN_bf16_storage out_data.x = packBFloat2x16(v0.rg); out_data.y = packBFloat2x16(v0.ba); out_data.z = packBFloat2x16(v1.rg); diff --git a/tests/perf/CMakeLists.txt b/tests/perf/CMakeLists.txt index dea79dcfb566..f2dc897c6180 100644 --- a/tests/perf/CMakeLists.txt +++ b/tests/perf/CMakeLists.txt @@ -37,3 +37,9 @@ ncnn_add_layer_perf(BinaryOp) ncnn_add_layer_perf(Concat) ncnn_add_layer_perf(Sigmoid) ncnn_add_layer_perf(BatchNorm) + +# SDPA perf tests (decode and prefill phases) +if(WITH_LAYER_sdpa) + ncnn_add_perf(sdpa_decode) + ncnn_add_perf(sdpa_prefill) +endif() diff --git a/tests/perf/perf_sdpa_decode.cpp b/tests/perf/perf_sdpa_decode.cpp new file mode 100644 index 000000000000..c670fa53c671 --- /dev/null +++ b/tests/perf/perf_sdpa_decode.cpp @@ -0,0 +1,84 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "perfutil.h" + +// decode phase: src_seqlen=1, with kv_cache and various past_seqlen +static void perf_sdpa_decode(int embed_dim, int num_heads, int num_groups, int past_seqlen) +{ + const int src_seqlen = 1; + const int cur_seqlen = 1; + const int out_embed_dim = embed_dim; + const int dst_seqlen = past_seqlen + cur_seqlen; + + ncnn::ParamDict pd; + pd.set(5, 0); // attn_mask = 0 + pd.set(6, 0.f); // scale = 0 (default 1/sqrt(embed_dim)) + pd.set(7, 1); // kv_cache = 1 + + std::vector weights(0); + + // inputs: q, k, v, past_k, past_v + std::vector inputs(5); + inputs[0] = PerfMat(embed_dim, src_seqlen, num_heads); // q + inputs[1] = PerfMat(embed_dim, cur_seqlen, num_groups); // cur_k + inputs[2] = PerfMat(out_embed_dim, cur_seqlen, num_groups); // cur_v + inputs[3] = PerfMat(embed_dim, past_seqlen, num_groups); // past_k + inputs[4] = PerfMat(out_embed_dim, past_seqlen, num_groups); // past_v + + perf_layer("SDPA", pd, weights, inputs, 3, + "embed=%d heads=%d groups=%d past=%d", + embed_dim, num_heads, num_groups, past_seqlen); +} + +int main() +{ + // typical LLM configurations for decode phase + // format: (embed_dim, num_heads, num_groups, past_seqlen) + + // small model, various cache lengths + perf_sdpa_decode(128, 4, 4, 0); + perf_sdpa_decode(128, 4, 4, 128); + perf_sdpa_decode(128, 4, 4, 512); + perf_sdpa_decode(128, 4, 4, 1024); + perf_sdpa_decode(128, 4, 4, 2048); + + // medium model + perf_sdpa_decode(512, 8, 8, 0); + perf_sdpa_decode(512, 8, 8, 128); + perf_sdpa_decode(512, 8, 8, 512); + perf_sdpa_decode(512, 8, 8, 1024); + perf_sdpa_decode(512, 8, 8, 2048); + + // larger model (e.g., 7B scale) + perf_sdpa_decode(4096, 32, 32, 0); + perf_sdpa_decode(4096, 32, 32, 128); + perf_sdpa_decode(4096, 32, 32, 512); + perf_sdpa_decode(4096, 32, 32, 1024); + perf_sdpa_decode(4096, 32, 32, 2048); + perf_sdpa_decode(4096, 32, 32, 4096); + perf_sdpa_decode(4096, 32, 32, 8192); + + // GQA/MQA configurations + // GQA: num_groups < num_heads + perf_sdpa_decode(4096, 32, 4, 128); + perf_sdpa_decode(4096, 32, 4, 512); + perf_sdpa_decode(4096, 32, 4, 1024); + perf_sdpa_decode(4096, 32, 4, 2048); + perf_sdpa_decode(4096, 32, 4, 4096); + + // MQA: num_groups = 1 + perf_sdpa_decode(4096, 32, 1, 128); + perf_sdpa_decode(4096, 32, 1, 512); + perf_sdpa_decode(4096, 32, 1, 1024); + perf_sdpa_decode(4096, 32, 1, 2048); + perf_sdpa_decode(4096, 32, 1, 4096); + + // very large context lengths + perf_sdpa_decode(4096, 32, 32, 16384); + perf_sdpa_decode(4096, 32, 32, 32768); + perf_sdpa_decode(4096, 32, 4, 16384); + perf_sdpa_decode(4096, 32, 4, 32768); + + return 0; +} diff --git a/tests/perf/perf_sdpa_prefill.cpp b/tests/perf/perf_sdpa_prefill.cpp new file mode 100644 index 000000000000..6b5c5b08e306 --- /dev/null +++ b/tests/perf/perf_sdpa_prefill.cpp @@ -0,0 +1,89 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "perfutil.h" + +// prefill phase: larger src_seqlen, no kv_cache (past_seqlen=0) +static void perf_sdpa_prefill(int embed_dim, int num_heads, int num_groups, int src_seqlen) +{ + const int cur_seqlen = src_seqlen; // in prefill, cur_seqlen == src_seqlen + const int out_embed_dim = embed_dim; + + ncnn::ParamDict pd; + pd.set(5, 0); // attn_mask = 0 + pd.set(6, 0.f); // scale = 0 (default 1/sqrt(embed_dim)) + pd.set(7, 0); // kv_cache = 0 (no cache in prefill) + + std::vector weights(0); + + // inputs: q, k, v + std::vector inputs(3); + inputs[0] = PerfMat(embed_dim, src_seqlen, num_heads); // q + inputs[1] = PerfMat(embed_dim, cur_seqlen, num_groups); // k + inputs[2] = PerfMat(out_embed_dim, cur_seqlen, num_groups); // v + + perf_layer("SDPA", pd, weights, inputs, 1, + "embed=%d heads=%d groups=%d seqlen=%d", + embed_dim, num_heads, num_groups, src_seqlen); +} + +int main() +{ + // typical LLM configurations for prefill phase + // format: (embed_dim, num_heads, num_groups, src_seqlen) + + // small model, various sequence lengths + perf_sdpa_prefill(128, 4, 4, 16); + perf_sdpa_prefill(128, 4, 4, 32); + perf_sdpa_prefill(128, 4, 4, 64); + perf_sdpa_prefill(128, 4, 4, 128); + perf_sdpa_prefill(128, 4, 4, 256); + perf_sdpa_prefill(128, 4, 4, 512); + + // medium model + perf_sdpa_prefill(512, 8, 8, 16); + perf_sdpa_prefill(512, 8, 8, 32); + perf_sdpa_prefill(512, 8, 8, 64); + perf_sdpa_prefill(512, 8, 8, 128); + perf_sdpa_prefill(512, 8, 8, 256); + perf_sdpa_prefill(512, 8, 8, 512); + perf_sdpa_prefill(512, 8, 8, 1024); + + // larger model (e.g., 7B scale) + perf_sdpa_prefill(4096, 32, 32, 16); + perf_sdpa_prefill(4096, 32, 32, 32); + perf_sdpa_prefill(4096, 32, 32, 64); + perf_sdpa_prefill(4096, 32, 32, 128); + perf_sdpa_prefill(4096, 32, 32, 256); + perf_sdpa_prefill(4096, 32, 32, 512); + perf_sdpa_prefill(4096, 32, 32, 1024); + perf_sdpa_prefill(4096, 32, 32, 2048); + perf_sdpa_prefill(4096, 32, 32, 4096); + + // GQA/MQA configurations + // GQA: num_groups < num_heads + perf_sdpa_prefill(4096, 32, 4, 128); + perf_sdpa_prefill(4096, 32, 4, 256); + perf_sdpa_prefill(4096, 32, 4, 512); + perf_sdpa_prefill(4096, 32, 4, 1024); + perf_sdpa_prefill(4096, 32, 4, 2048); + perf_sdpa_prefill(4096, 32, 4, 4096); + + // MQA: num_groups = 1 + perf_sdpa_prefill(4096, 32, 1, 128); + perf_sdpa_prefill(4096, 32, 1, 256); + perf_sdpa_prefill(4096, 32, 1, 512); + perf_sdpa_prefill(4096, 32, 1, 1024); + perf_sdpa_prefill(4096, 32, 1, 2048); + perf_sdpa_prefill(4096, 32, 1, 4096); + + // very long sequences + perf_sdpa_prefill(4096, 32, 32, 8192); + perf_sdpa_prefill(4096, 32, 32, 16384); + perf_sdpa_prefill(4096, 32, 32, 32768); + perf_sdpa_prefill(4096, 32, 4, 8192); + perf_sdpa_prefill(4096, 32, 4, 16384); + perf_sdpa_prefill(4096, 32, 4, 32768); + + return 0; +}