diff --git a/src/layer/vulkan/matmul_vulkan.cpp b/src/layer/vulkan/matmul_vulkan.cpp
new file mode 100644
index 000000000000..b870de76d489
--- /dev/null
+++ b/src/layer/vulkan/matmul_vulkan.cpp
@@ -0,0 +1,394 @@
+// Copyright 2026 MYQ
+// SPDX-License-Identifier: BSD-3-Clause
+
+#include "matmul_vulkan.h"
+
+#include <algorithm>
+#include <vector>
+
+#include "layer_shader_type.h"
+
+namespace ncnn {
+
+// Enables the Vulkan path and resets every pipeline pointer and tuning value.
+MatMul_vulkan::MatMul_vulkan()
+{
+    support_vulkan = true;
+    support_vulkan_packing = true;
+    support_vulkan_any_packing = true;
+
+    pipeline_matmul = 0;
+    pipeline_matmul_sg = 0;
+    pipeline_matmul_cm = 0;
+
+    use_subgroup_ops = false;
+
+    use_cooperative_matrix = false;
+    coopmat_M = 0;
+    coopmat_N = 0;
+    coopmat_K = 0;
+    coopmat_subgroup_size = 0;
+    UNROLL_SG_M = 1;
+    UNROLL_SG_N = 1;
+    UNROLL_SG_K = 1;
+    UNROLL_WG_M = 1;
+    UNROLL_WG_N = 1;
+}
+
+// Builds the generic matmul pipeline, then tries the optional subgroup and
+// cooperative-matrix variants. A failure of an optional variant only disables
+// that variant; only a failure of the generic pipeline is reported to the caller.
+int MatMul_vulkan::create_pipeline(const Option& opt)
+{
+    std::vector<vk_specialization_type> matmul_specializations(1);
+    matmul_specializations[0].i = transB;
+
+    pipeline_matmul = new Pipeline(vkdev);
+    if (opt.use_shader_local_memory)
+    {
+        pipeline_matmul->set_local_size_xyz(8, 8, 1);
+    }
+    else
+    {
+        Mat local_size_xyz;
+        pipeline_matmul->set_optimal_local_size_xyz(local_size_xyz);
+    }
+
+    int ret = pipeline_matmul->create(LayerShaderType::matmul, opt, matmul_specializations);
+    if (ret != 0)
+    {
+        destroy_pipeline(opt);
+        return ret;
+    }
+
+    // subgroup variant: requires basic + shuffle subgroup features
+    const int subgroup_size = vkdev->info.subgroup_size();
+    const uint32_t subgroup_features = vkdev->info.support_subgroup_ops();
+    const bool support_subgroup_shuffle = (subgroup_features & (VK_SUBGROUP_FEATURE_BASIC_BIT | VK_SUBGROUP_FEATURE_SHUFFLE_BIT)) == (VK_SUBGROUP_FEATURE_BASIC_BIT | VK_SUBGROUP_FEATURE_SHUFFLE_BIT);
+
+    use_subgroup_ops = opt.use_subgroup_ops && support_subgroup_shuffle;
+    if (subgroup_size < 4 || subgroup_size > 128)
+    {
+        // sanitize weird subgroup_size
+        use_subgroup_ops = false;
+    }
+
+    if (use_subgroup_ops)
+    {
+        std::vector<vk_specialization_type> sg_specializations(2);
+        sg_specializations[0].i = transB;
+        sg_specializations[1].u32 = subgroup_size;
+
+        pipeline_matmul_sg = new Pipeline(vkdev);
+        pipeline_matmul_sg->set_subgroup_size(subgroup_size);
+        pipeline_matmul_sg->set_local_size_xyz(subgroup_size, 1, 1);
+        ret = pipeline_matmul_sg->create(LayerShaderType::matmul_sg, opt, sg_specializations);
+        if (ret != 0)
+        {
+            delete pipeline_matmul_sg;
+            pipeline_matmul_sg = 0;
+            use_subgroup_ops = false;
+        }
+    }
+
+    // cooperative-matrix variant: only attempted with fp16 storage/packing
+    if (vkdev->info.support_cooperative_matrix() && opt.use_cooperative_matrix && (opt.use_fp16_storage || opt.use_fp16_packed))
+    {
+        int M = 1024;
+        int N = 1024;
+        int K = 1024;
+        vkdev->info.get_optimal_cooperative_matrix_mnk(M, N, K, VK_COMPONENT_TYPE_FLOAT16_KHR, opt.use_fp16_arithmetic ? VK_COMPONENT_TYPE_FLOAT16_KHR : VK_COMPONENT_TYPE_FLOAT32_KHR, VK_SCOPE_SUBGROUP_KHR, coopmat_M, coopmat_N, coopmat_K, coopmat_subgroup_size);
+        if (coopmat_M > 0 && coopmat_N > 0 && coopmat_K > 0 && coopmat_subgroup_size >= 4 && coopmat_subgroup_size <= 128)
+        {
+            use_cooperative_matrix = true;
+
+            UNROLL_SG_M = std::min((M + coopmat_M - 1) / coopmat_M, 2);
+            UNROLL_SG_N = std::min((N + coopmat_N - 1) / coopmat_N, 2);
+            UNROLL_SG_K = std::min((K + coopmat_K - 1) / coopmat_K, 2);
+            UNROLL_WG_M = std::min((M + coopmat_M * UNROLL_SG_M - 1) / (coopmat_M * UNROLL_SG_M), 2);
+            UNROLL_WG_N = std::min((N + coopmat_N * UNROLL_SG_N - 1) / (coopmat_N * UNROLL_SG_N), 2);
+
+            std::vector<vk_specialization_type> cm_specializations(5);
+            cm_specializations[0].i = transB;
+            cm_specializations[1].u32 = coopmat_M;
+            cm_specializations[2].u32 = coopmat_N;
+            cm_specializations[3].u32 = coopmat_K;
+            cm_specializations[4].u32 = coopmat_subgroup_size;
+
+            pipeline_matmul_cm = new Pipeline(vkdev);
+            pipeline_matmul_cm->set_subgroup_size(coopmat_subgroup_size);
+            pipeline_matmul_cm->set_local_size_xyz(coopmat_subgroup_size, 1, 1);
+            ret = pipeline_matmul_cm->create(LayerShaderType::matmul_cm, opt, cm_specializations);
+            if (ret != 0)
+            {
+                delete pipeline_matmul_cm;
+                pipeline_matmul_cm = 0;
+
+                use_cooperative_matrix = false;
+                coopmat_M = 0;
+                coopmat_N = 0;
+                coopmat_K = 0;
+                coopmat_subgroup_size = 0;
+                UNROLL_SG_M = 1;
+                UNROLL_SG_N = 1;
+                UNROLL_SG_K = 1;
+                UNROLL_WG_M = 1;
+                UNROLL_WG_N = 1;
+            }
+        }
+    }
+
+    return 0;
+}
+
+// Deletes all pipelines and restores the constructor defaults.
+int MatMul_vulkan::destroy_pipeline(const Option& /*opt*/)
+{
+    delete pipeline_matmul;
+    pipeline_matmul = 0;
+
+    delete pipeline_matmul_sg;
+    pipeline_matmul_sg = 0;
+
+    delete pipeline_matmul_cm;
+    pipeline_matmul_cm = 0;
+
+    use_subgroup_ops = false;
+
+    use_cooperative_matrix = false;
+    coopmat_M = 0;
+    coopmat_N = 0;
+    coopmat_K = 0;
+    coopmat_subgroup_size = 0;
+    UNROLL_SG_M = 1;
+    UNROLL_SG_N = 1;
+    UNROLL_SG_K = 1;
+    UNROLL_WG_M = 1;
+    UNROLL_WG_N = 1;
+
+    return 0;
+}
+
+// Records one matmul dispatch for bottom_blobs[0] x bottom_blobs[1] into top_blobs[0].
+// Both inputs first go through convert_packing with a target of 1. Returns -1 on
+// unsupported or mismatching shapes and -100 when the output allocation fails.
+int MatMul_vulkan::forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkMat>& top_blobs, VkCompute& cmd, const Option& opt) const
+{
+    const VkMat& A0 = bottom_blobs[0];
+    const VkMat& B0 = bottom_blobs[1];
+
+    VkMat A;
+    VkMat B;
+    vkdev->convert_packing(A0, A, 1, cmd, opt);
+    vkdev->convert_packing(B0, B, 1, cmd, opt);
+
+    const int Adims = A.dims;
+    const int Bdims = B.dims;
+
+    if (Adims < 1 || Adims > 4 || Bdims < 1 || Bdims > 4)
+    {
+        NCNN_LOGE("unsupported matmul dims A=%d B=%d", Adims, Bdims);
+        return -1;
+    }
+
+    const int max_ABdims = std::max(Adims, Bdims);
+
+    // For max rank 4, MatMul cpu semantics reshape 3d tensor (w,h,c) to (w,h,d=c,c=1).
+    const bool A_reshape_3d_to_4d = max_ABdims == 4 && Adims == 3;
+    const bool B_reshape_3d_to_4d = max_ABdims == 4 && Bdims == 3;
+
+    const int A_batch_c = A_reshape_3d_to_4d ? 1 : (Adims >= 3 ? A.c : 1);
+    const int A_batch_d = A_reshape_3d_to_4d ? A.c : (Adims == 4 ? A.d : 1);
+    const int B_batch_c = B_reshape_3d_to_4d ? 1 : (Bdims >= 3 ? B.c : 1);
+    const int B_batch_d = B_reshape_3d_to_4d ? B.c : (Bdims == 4 ? B.d : 1);
+
+    const int A_dstep = A_reshape_3d_to_4d ? (int)A.cstep : A.w * A.h;
+    const int B_dstep = B_reshape_3d_to_4d ? (int)B.cstep : B.w * B.h;
+
+    const int M = Adims == 1 ? 1 : A.h;
+    const int K = A.w;
+    const int N = Bdims == 1 ? 1 : (transB ? B.h : B.w);
+
+    const int BK = Bdims == 1 ? B.w : (transB ? B.w : B.h);
+    if (K != BK)
+    {
+        NCNN_LOGE("matmul K mismatch A=%d B=%d transB=%d", K, BK, transB);
+        return -1;
+    }
+
+    const bool batch_c_compatible = A_batch_c == B_batch_c || A_batch_c == 1 || B_batch_c == 1;
+    const bool batch_d_compatible = A_batch_d == B_batch_d || A_batch_d == 1 || B_batch_d == 1;
+    if (!batch_c_compatible || !batch_d_compatible)
+    {
+        NCNN_LOGE("matmul batch mismatch A(c=%d d=%d) B(c=%d d=%d)", A_batch_c, A_batch_d, B_batch_c, B_batch_d);
+        return -1;
+    }
+
+    // layout codes consumed by the shaders: 0 = 2-d, 1 = batched (3-d/4-d), 2 = 1-d vector
+    const int A_layout = Adims == 1 ? 2 : (Adims <= 2 ? 0 : 1);
+    const int B_layout = Bdims == 1 ? 2 : (Bdims <= 2 ? 0 : 1);
+
+    int out_layout = 0;
+    int out_batch_c = 1;
+    int out_batch_d = 1;
+
+    VkMat& top_blob = top_blobs[0];
+    const size_t elemsize = A.elemsize;
+
+    if (Adims == 1 && Bdims == 1)
+    {
+        out_layout = 1;
+        top_blob.create(1, elemsize, opt.blob_vkallocator);
+    }
+    else if (Adims == 2 && Bdims == 2)
+    {
+        out_layout = 0;
+        top_blob.create(N, M, elemsize, opt.blob_vkallocator);
+    }
+    else if (Adims == 1 && Bdims == 2)
+    {
+        out_layout = 1;
+        top_blob.create(N, elemsize, opt.blob_vkallocator);
+    }
+    else if (Adims == 2 && Bdims == 1)
+    {
+        out_layout = 2;
+        top_blob.create(M, elemsize, opt.blob_vkallocator);
+    }
+    else if (Adims == 1 && Bdims > 2)
+    {
+        out_layout = 4;
+
+        if (Bdims == 3)
+        {
+            out_batch_d = B_batch_d * B_batch_c;
+            out_batch_c = 1;
+            top_blob.create(N, out_batch_d, elemsize, opt.blob_vkallocator);
+        }
+        else
+        {
+            out_batch_d = B_batch_d;
+            out_batch_c = B_batch_c;
+            top_blob.create(N, out_batch_d, out_batch_c, elemsize, opt.blob_vkallocator);
+        }
+    }
+    else if (Adims > 2 && Bdims == 1)
+    {
+        out_layout = 5;
+
+        if (Adims == 3)
+        {
+            out_batch_d = A_batch_d * A_batch_c;
+            out_batch_c = 1;
+            top_blob.create(M, out_batch_d, elemsize, opt.blob_vkallocator);
+        }
+        else
+        {
+            out_batch_d = A_batch_d;
+            out_batch_c = A_batch_c;
+            top_blob.create(M, out_batch_d, out_batch_c, elemsize, opt.blob_vkallocator);
+        }
+    }
+    else if (max_ABdims == 3)
+    {
+        out_layout = 3;
+        out_batch_d = 1;
+        out_batch_c = std::max(A_batch_c, B_batch_c);
+        top_blob.create(N, M, out_batch_c, elemsize, opt.blob_vkallocator);
+    }
+    else if (max_ABdims == 4)
+    {
+        out_layout = 3;
+        out_batch_d = std::max(A_batch_d, B_batch_d);
+        out_batch_c = std::max(A_batch_c, B_batch_c);
+        top_blob.create(N, M, out_batch_d, out_batch_c, elemsize, opt.blob_vkallocator);
+    }
+    else
+    {
+        NCNN_LOGE("impossible matmul %d %d", Adims, Bdims);
+        return -1;
+    }
+
+    if (top_blob.empty())
+        return -100;
+
+    std::vector<VkMat> bindings(3);
+    bindings[0] = top_blob;
+    bindings[1] = A;
+    bindings[2] = B;
+
+    // push constants, in the order of the "parameter" block of the matmul shaders
+    std::vector<vk_constant_type> constants(23);
+    constants[0].i = M;
+    constants[1].i = N;
+    constants[2].i = K;
+    constants[3].i = out_batch_c * out_batch_d;
+    constants[4].i = A_layout;
+    constants[5].i = B_layout;
+    constants[6].i = out_layout;
+
+    constants[7].i = A.w;
+    constants[8].i = (int)A.cstep;
+    constants[9].i = A_dstep;
+    constants[10].i = A_batch_c;
+    constants[11].i = A_batch_d;
+
+    // B_hstep is the physical row stride (w) of B in memory.
+    // Indexing formula in shader switches on transB to access B[k, n] or B[n, k].
+    constants[12].i = B.w;
+    constants[13].i = (int)B.cstep;
+    constants[14].i = B_dstep;
+    constants[15].i = B_batch_c;
+    constants[16].i = B_batch_d;
+
+    constants[17].i = top_blob.dims >= 2 ? top_blob.w : 0;
+    constants[18].i = top_blob.dims >= 3 ? (int)top_blob.cstep : 0;
+    constants[19].i = top_blob.w * (top_blob.dims >= 2 ? top_blob.h : 1);
+    constants[20].i = out_batch_c;
+    constants[21].i = out_batch_d;
+    constants[22].i = transB;
+
+    const int batch = out_batch_c * out_batch_d;
+
+    Pipeline* selected_pipeline = pipeline_matmul;
+
+    VkMat dispatcher;
+    dispatcher.w = N;
+    dispatcher.h = M;
+    dispatcher.c = batch;
+
+    // prefer cooperative matrix, then subgroup reduction, else the generic shader
+    if (pipeline_matmul_cm && use_cooperative_matrix)
+    {
+        const bool fp16_input = A.elemsize == 2u && B.elemsize == 2u;
+        const int64_t work = (int64_t)M * N * K;
+        const int64_t tile = (int64_t)coopmat_M * coopmat_N * coopmat_K;
+
+        if (fp16_input && work >= tile)
+        {
+            // NCNN_LOGE("pipeline_matmul_cm");
+            selected_pipeline = pipeline_matmul_cm;
+            dispatcher.w = ((N + coopmat_N - 1) / coopmat_N) * coopmat_subgroup_size;
+            dispatcher.h = (M + coopmat_M - 1) / coopmat_M;
+            dispatcher.c = batch;
+        }
+    }
+
+    if (selected_pipeline == pipeline_matmul && pipeline_matmul_sg && use_subgroup_ops)
+    {
+        const int64_t work = (int64_t)M * N * K;
+        if (work >= 512)
+        {
+            // NCNN_LOGE("pipeline_matmul_sg");
+            selected_pipeline = pipeline_matmul_sg;
+            dispatcher.w = N * vkdev->info.subgroup_size();
+            dispatcher.h = M;
+            dispatcher.c = batch;
+        }
+    }
+    cmd.record_pipeline(selected_pipeline, bindings, constants, dispatcher);
+
+    return 0;
+}
+
+} // namespace ncnn
diff --git a/src/layer/vulkan/matmul_vulkan.h b/src/layer/vulkan/matmul_vulkan.h
new file mode 100644
index 000000000000..d0dfaeb5567d
--- /dev/null
+++ b/src/layer/vulkan/matmul_vulkan.h
@@ -0,0 +1,47 @@
+// Copyright 2026 MYQ
+// SPDX-License-Identifier: BSD-3-Clause
+
+#ifndef LAYER_MATMUL_VULKAN_H
+#define LAYER_MATMUL_VULKAN_H
+
+#include "matmul.h"
+
+namespace ncnn {
+
+// Vulkan implementation of MatMul with generic, subgroup and cooperative-matrix pipelines.
+class MatMul_vulkan : public MatMul
+{
+public:
+    MatMul_vulkan();
+
+    virtual int create_pipeline(const Option& opt);
+    virtual int destroy_pipeline(const Option& opt);
+
+    using MatMul::forward;
+    virtual int forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkMat>& top_blobs, VkCompute& cmd, const Option& opt) const;
+
+public:
+    // generic / subgroup-reduction / cooperative-matrix pipelines (0 when not created)
+    Pipeline* pipeline_matmul;
+    Pipeline* pipeline_matmul_sg;
+    Pipeline* pipeline_matmul_cm;
+
+    // subgroup
+    bool use_subgroup_ops;
+
+    // cooperative matrix
+    bool use_cooperative_matrix;
+    int coopmat_M;
+    int coopmat_N;
+    int coopmat_K;
+    int coopmat_subgroup_size;
+    int UNROLL_SG_M;
+    int UNROLL_SG_N;
+    int UNROLL_SG_K;
+    int UNROLL_WG_M;
+    int UNROLL_WG_N;
+};
+
+} // namespace ncnn
+
+#endif // LAYER_MATMUL_VULKAN_H
diff --git a/src/layer/vulkan/shader/matmul.comp b/src/layer/vulkan/shader/matmul.comp
new file mode 100644
index 000000000000..47f997e2e0dc
--- /dev/null
+++ b/src/layer/vulkan/shader/matmul.comp
@@ -0,0 +1,241 @@
+// Copyright 2026 MYQ
+// SPDX-License-Identifier: BSD-3-Clause
+
+#version 450
+
+layout(constant_id = 0) const int transB = 0;
+
+layout(binding = 0) writeonly buffer top_blob { sfp top_blob_data[]; };
+layout(binding = 1) readonly buffer A_blob { sfp A_blob_data[]; };
+layout(binding = 2) readonly buffer B_blob { sfp B_blob_data[]; };
+
+layout(push_constant) uniform parameter
+{
+    int M;
+    int N;
+    int K;
+    int batch;
+
+    int A_layout;
+    int B_layout;
+    int out_layout;
+
+    int A_hstep;
+    int A_cstep;
+    int A_dstep;
+    int A_batch_c;
+    int A_batch_d;
+
+    int B_hstep;
+    int B_cstep;
+    int B_dstep;
+    int B_batch_c;
+    int B_batch_d;
+
+    int out_hstep;
+    int out_cstep;
+    int out_dstep;
+    int out_batch_c;
+    int out_batch_d;
+
+    int transB_dynamic;
+} p;
+
+#if NCNN_shader_local_memory
+#define TILE_K 8
+shared lfp shA[8][TILE_K];
+shared lfp shB[TILE_K][8];
+#endif
+
+// One invocation computes one output element (gx = column n, gy = row m, gz = batch).
+// With NCNN_shader_local_memory the K loop is tiled through shared memory.
+void main()
+{
+    const uint gx = gl_GlobalInvocationID.x;
+    const uint gy = gl_GlobalInvocationID.y;
+    const uint gz = gl_GlobalInvocationID.z;
+    const bool valid = gx < uint(p.N) && gy < uint(p.M) && gz < uint(p.batch);
+
+    // out-of-range invocations may only leave early when no barrier() follows
+#if !NCNN_shader_local_memory
+    if (!valid)
+        return;
+#endif
+
+    uint out_c_idx = 0;
+    uint out_d_idx = 0;
+    if (p.out_batch_c > 1 || p.out_batch_d > 1)
+    {
+        out_c_idx = p.out_batch_c == 1 ? 0 : gz / uint(p.out_batch_d);
+        out_d_idx = p.out_batch_d == 1 ? 0 : gz % uint(p.out_batch_d);
+    }
+
+    uint a_c_idx = 0;
+    uint a_d_idx = 0;
+    if (p.A_batch_c > 1 || p.A_batch_d > 1)
+    {
+        if (p.out_batch_c == 1 && p.A_batch_c > 1 && p.A_batch_d == 1)
+            a_c_idx = gz;
+        else
+            a_c_idx = p.A_batch_c == 1 ? 0 : out_c_idx;
+
+        a_d_idx = p.A_batch_d == 1 ? 0 : out_d_idx;
+    }
+
+    uint b_c_idx = 0;
+    uint b_d_idx = 0;
+    if (p.B_batch_c > 1 || p.B_batch_d > 1)
+    {
+        if (p.out_batch_c == 1 && p.B_batch_c > 1 && p.B_batch_d == 1)
+            b_c_idx = gz;
+        else
+            b_c_idx = p.B_batch_c == 1 ? 0 : out_c_idx;
+
+        b_d_idx = p.B_batch_d == 1 ? 0 : out_d_idx;
+    }
+
+    const bool tb = transB == 1 || p.transB_dynamic == 1;
+    afp sum = afp(0.f);
+
+#if NCNN_shader_local_memory
+    const uint lx = gl_LocalInvocationID.x;
+    const uint ly = gl_LocalInvocationID.y;
+    const bool valid_batch = gz < uint(p.batch);
+    const bool valid_a = gy < uint(p.M) && valid_batch;
+    const bool valid_b = gx < uint(p.N) && valid_batch;
+    const uint baseA = p.A_layout == 1 ? (a_c_idx * uint(p.A_cstep) + a_d_idx * uint(p.A_dstep) + gy * uint(p.A_hstep))
+                       : (p.A_layout == 0 ? gy * uint(p.A_hstep) : 0u);
+    const uint baseB = p.B_layout == 1 ? (b_c_idx * uint(p.B_cstep) + b_d_idx * uint(p.B_dstep))
+                       : 0u;
+    const uint gx_B_hstep = gx * uint(p.B_hstep);
+
+    for (uint k0 = 0; k0 < uint(p.K); k0 += TILE_K)
+    {
+        afp a = afp(0.f);
+        afp b = afp(0.f);
+
+        const uint ka = k0 + lx;
+        const uint kb = k0 + ly;
+
+        if (valid_a && ka < uint(p.K))
+        {
+            uint ai = 0;
+            if (p.A_layout == 2)
+            {
+                ai = ka;
+            }
+            else
+            {
+                ai = baseA + ka;
+            }
+
+            a = afp(buffer_ld1(A_blob_data, ai));
+        }
+
+        if (valid_b && kb < uint(p.K))
+        {
+            const uint kb_B_hstep = kb * uint(p.B_hstep);
+            uint bi = 0;
+            if (p.B_layout == 2)
+            {
+                bi = kb;
+            }
+            else if (p.B_layout == 1)
+            {
+                if (tb)
+                    bi = baseB + gx_B_hstep + kb;
+                else
+                    bi = baseB + kb_B_hstep + gx;
+            }
+            else
+            {
+                if (tb)
+                    bi = gx_B_hstep + kb;
+                else
+                    bi = kb_B_hstep + gx;
+            }
+
+            b = afp(buffer_ld1(B_blob_data, bi));
+        }
+
+        shA[ly][lx] = afp2lfp(a);
+        shB[ly][lx] = afp2lfp(b);
+
+        barrier();
+
+        const uint kend = min(uint(TILE_K), uint(p.K) - k0);
+        for (uint kk = 0; kk < kend; kk++)
+        {
+            sum += lfp2afp(shA[ly][kk]) * lfp2afp(shB[kk][lx]);
+        }
+
+        barrier();
+    }
+#else
+    const uint baseA = p.A_layout == 1 ? (a_c_idx * uint(p.A_cstep) + a_d_idx * uint(p.A_dstep) + gy * uint(p.A_hstep))
+                       : (p.A_layout == 0 ? gy * uint(p.A_hstep) : 0u);
+    const uint baseB = p.B_layout == 1 ? (b_c_idx * uint(p.B_cstep) + b_d_idx * uint(p.B_dstep))
+                       : 0u;
+    const uint gx_B_hstep = gx * uint(p.B_hstep);
+
+    for (uint k = 0; k < uint(p.K); k++)
+    {
+        uint ai = 0;
+        ai = p.A_layout == 2 ? k : (baseA + k);
+
+        const uint k_B_hstep = k * uint(p.B_hstep);
+        uint bi = 0;
+        if (p.B_layout == 2)
+        {
+            bi = k;
+        }
+        else if (p.B_layout == 1)
+        {
+            if (tb)
+                bi = baseB + gx_B_hstep + k;
+            else
+                bi = baseB + k_B_hstep + gx;
+        }
+        else
+        {
+            if (tb)
+                bi = gx_B_hstep + k;
+            else
+                bi = k_B_hstep + gx;
+        }
+
+        sum += afp(buffer_ld1(A_blob_data, ai)) * afp(buffer_ld1(B_blob_data, bi));
+    }
+#endif
+
+    if (!valid)
+        return;
+
+    uint oi = 0;
+    if (p.out_layout == 1)
+    {
+        oi = gx;
+    }
+    else if (p.out_layout == 2)
+    {
+        oi = gy;
+    }
+    else if (p.out_layout == 3)
+    {
+        oi = out_c_idx * uint(p.out_cstep) + out_d_idx * uint(p.out_dstep) + gy * uint(p.out_hstep) + gx;
+    }
+    else if (p.out_layout == 4)
+    {
+        oi = out_c_idx * uint(p.out_cstep) + out_d_idx * uint(p.out_hstep) + gx;
+    }
+    else if (p.out_layout == 5)
+    {
+        oi = out_c_idx * uint(p.out_cstep) + out_d_idx * uint(p.out_hstep) + gy;
+    }
+    else
+    {
+        oi = gy * uint(p.out_hstep) + gx;
+    }
+
+    buffer_st1(top_blob_data, oi, sum);
+}
diff --git a/src/layer/vulkan/shader/matmul_cm.comp b/src/layer/vulkan/shader/matmul_cm.comp
new file mode 100644
index 000000000000..9cddab716a80
--- /dev/null
+++ b/src/layer/vulkan/shader/matmul_cm.comp
@@ -0,0 +1,249 @@
+// Copyright 2026 MYQ
+// SPDX-License-Identifier: BSD-3-Clause
+
+#version 450
+
+#extension GL_KHR_shader_subgroup_basic : require
+#extension GL_KHR_memory_scope_semantics : require
+#extension GL_EXT_shader_explicit_arithmetic_types : require
+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+#if ncnn_VK_KHR_cooperative_matrix
+#extension GL_KHR_cooperative_matrix : require
+#elif ncnn_VK_NV_cooperative_matrix
+#extension GL_NV_cooperative_matrix : require
+#endif
+
+layout(constant_id = 0) const int transB = 0;
+layout(constant_id = 1) const uint CM_M = 8;
+layout(constant_id = 2) const uint CM_N = 8;
+layout(constant_id = 3) const uint CM_K = 8;
+layout(constant_id = 4) const uint subgroup_size = 32;
+
+layout(binding = 0) writeonly buffer top_blob { sfp top_blob_data[]; };
+layout(binding = 1) readonly buffer A_blob { sfp A_blob_data[]; };
+layout(binding = 2) readonly buffer B_blob { sfp B_blob_data[]; };
+
+layout(push_constant) uniform parameter
+{
+    int M;
+    int N;
+    int K;
+    int batch;
+
+    int A_layout;
+    int B_layout;
+    int out_layout;
+
+    int A_hstep;
+    int A_cstep;
+    int A_dstep;
+    int A_batch_c;
+    int A_batch_d;
+
+    int B_hstep;
+    int B_cstep;
+    int B_dstep;
+    int B_batch_c;
+    int B_batch_d;
+
+    int out_hstep;
+    int out_cstep;
+    int out_dstep;
+    int out_batch_c;
+    int out_batch_d;
+
+    int transB_dynamic;
+} p;
+
+// staging tiles for one CM_M x CM_K slice of A, one CM_K x CM_N slice of B and the result tile
+shared float16_t shA[CM_M * CM_K];
+shared float16_t shB[CM_K * CM_N];
+shared afp shO[CM_M * CM_N];
+
+// flat buffer offset of A[m, k] for the given batch indices
+uint get_a_index(uint m, uint k, uint a_c_idx, uint a_d_idx)
+{
+    if (p.A_layout == 2)
+        return k;
+
+    if (p.A_layout == 1)
+        return a_c_idx * uint(p.A_cstep) + a_d_idx * uint(p.A_dstep) + m * uint(p.A_hstep) + k;
+
+    return m * uint(p.A_hstep) + k;
+}
+
+// flat buffer offset of B[k, n]; with transB the element is read from row n, column k
+uint get_b_index(uint k, uint n, uint b_c_idx, uint b_d_idx)
+{
+    if (p.B_layout == 2)
+        return k;
+
+    const bool tb = transB == 1 || p.transB_dynamic == 1;
+
+    if (p.B_layout == 1)
+    {
+        const uint base = b_c_idx * uint(p.B_cstep) + b_d_idx * uint(p.B_dstep);
+        return tb ? (base + n * uint(p.B_hstep) + k) : (base + k * uint(p.B_hstep) + n);
+    }
+
+    return tb ? (n * uint(p.B_hstep) + k) : (k * uint(p.B_hstep) + n);
+}
+
+// flat buffer offset of output element (m, n) according to out_layout
+uint get_o_index(uint m, uint n, uint out_c_idx, uint out_d_idx)
+{
+    if (p.out_layout == 1)
+        return n;
+
+    if (p.out_layout == 2)
+        return m;
+
+    if (p.out_layout == 3)
+        return out_c_idx * uint(p.out_cstep) + out_d_idx * uint(p.out_dstep) + m * uint(p.out_hstep) + n;
+
+    if (p.out_layout == 4)
+        return out_c_idx * uint(p.out_cstep) + out_d_idx * uint(p.out_hstep) + n;
+
+    if (p.out_layout == 5)
+        return out_c_idx * uint(p.out_cstep) + out_d_idx * uint(p.out_hstep) + m;
+
+    return m * uint(p.out_hstep) + n;
+}
+
+// One workgroup of subgroup_size invocations produces one CM_M x CM_N output tile.
+void main()
+{
+    const uint lane = gl_SubgroupInvocationID;
+
+    const uint tile_n = gl_GlobalInvocationID.x / subgroup_size;
+    const uint tile_m = gl_GlobalInvocationID.y;
+    const uint batch_idx = gl_GlobalInvocationID.z;
+
+    const uint m0 = tile_m * CM_M;
+    const uint n0 = tile_n * CM_N;
+
+    if (m0 >= uint(p.M) || n0 >= uint(p.N) || batch_idx >= uint(p.batch))
+        return;
+
+    uint out_c_idx = 0;
+    uint out_d_idx = 0;
+    if (p.out_batch_c > 1 || p.out_batch_d > 1)
+    {
+        out_c_idx = p.out_batch_c == 1 ? 0 : batch_idx / uint(p.out_batch_d);
+        out_d_idx = p.out_batch_d == 1 ? 0 : batch_idx % uint(p.out_batch_d);
+    }
+
+    uint a_c_idx = 0;
+    uint a_d_idx = 0;
+    if (p.A_batch_c > 1 || p.A_batch_d > 1)
+    {
+        if (p.out_batch_c == 1 && p.A_batch_c > 1 && p.A_batch_d == 1)
+            a_c_idx = batch_idx;
+        else
+            a_c_idx = p.A_batch_c == 1 ? 0 : out_c_idx;
+
+        a_d_idx = p.A_batch_d == 1 ? 0 : out_d_idx;
+    }
+
+    uint b_c_idx = 0;
+    uint b_d_idx = 0;
+    if (p.B_batch_c > 1 || p.B_batch_d > 1)
+    {
+        if (p.out_batch_c == 1 && p.B_batch_c > 1 && p.B_batch_d == 1)
+            b_c_idx = batch_idx;
+        else
+            b_c_idx = p.B_batch_c == 1 ? 0 : out_c_idx;
+
+        b_d_idx = p.B_batch_d == 1 ? 0 : out_d_idx;
+    }
+
+#if ncnn_VK_KHR_cooperative_matrix
+    coopmat<afp, gl_ScopeSubgroup, CM_M, CM_N, gl_MatrixUseAccumulator> sum = coopmat<afp, gl_ScopeSubgroup, CM_M, CM_N, gl_MatrixUseAccumulator>(0.f);
+#elif ncnn_VK_NV_cooperative_matrix
+#if NCNN_fp16_arithmetic
+    fcoopmatNV<16, gl_ScopeSubgroup, CM_M, CM_N> sum = fcoopmatNV<16, gl_ScopeSubgroup, CM_M, CM_N>(0.f);
+#else
+    fcoopmatNV<32, gl_ScopeSubgroup, CM_M, CM_N> sum = fcoopmatNV<32, gl_ScopeSubgroup, CM_M, CM_N>(0.f);
+#endif
+#endif
+
+    for (uint k0 = 0; k0 < uint(p.K); k0 += CM_K)
+    {
+        for (uint idx = lane; idx < CM_M * CM_K; idx += subgroup_size)
+        {
+            const uint lm = idx / CM_K;
+            const uint lk = idx % CM_K;
+
+            const uint gm = m0 + lm;
+            const uint gk = k0 + lk;
+
+            afp v = afp(0.f);
+            if (gm < uint(p.M) && gk < uint(p.K))
+            {
+                const uint ai = get_a_index(gm, gk, a_c_idx, a_d_idx);
+                v = afp(buffer_ld1(A_blob_data, ai));
+            }
+
+            shA[idx] = float16_t(v);
+        }
+
+        for (uint idx = lane; idx < CM_K * CM_N; idx += subgroup_size)
+        {
+            const uint lk = idx / CM_N;
+            const uint ln = idx % CM_N;
+
+            const uint gk = k0 + lk;
+            const uint gn = n0 + ln;
+
+            afp v = afp(0.f);
+            if (gk < uint(p.K) && gn < uint(p.N))
+            {
+                const uint bi = get_b_index(gk, gn, b_c_idx, b_d_idx);
+                v = afp(buffer_ld1(B_blob_data, bi));
+            }
+
+            shB[idx] = float16_t(v);
+        }
+
+        barrier();
+
+#if ncnn_VK_KHR_cooperative_matrix
+        coopmat<float16_t, gl_ScopeSubgroup, CM_M, CM_K, gl_MatrixUseA> a;
+        coopmat<float16_t, gl_ScopeSubgroup, CM_K, CM_N, gl_MatrixUseB> b;
+        coopMatLoad(a, shA, 0, CM_K, gl_CooperativeMatrixLayoutRowMajor);
+        coopMatLoad(b, shB, 0, CM_N, gl_CooperativeMatrixLayoutRowMajor);
+        sum = coopMatMulAdd(a, b, sum);
+#elif ncnn_VK_NV_cooperative_matrix
+        fcoopmatNV<16, gl_ScopeSubgroup, CM_M, CM_K> a;
+        fcoopmatNV<16, gl_ScopeSubgroup, CM_K, CM_N> b;
+        coopMatLoadNV(a, shA, 0, CM_K, false);
+        coopMatLoadNV(b, shB, 0, CM_N, false);
+        sum = coopMatMulAddNV(a, b, sum);
+#endif
+
+        barrier();
+    }
+
+#if ncnn_VK_KHR_cooperative_matrix
+    coopMatStore(sum, shO, 0, CM_N, gl_CooperativeMatrixLayoutRowMajor);
+#elif ncnn_VK_NV_cooperative_matrix
+    coopMatStoreNV(sum, shO, 0, CM_N, false);
+#endif
+
+    barrier();
+
+    for (uint idx = lane; idx < CM_M * CM_N; idx += subgroup_size)
+    {
+        const uint lm = idx / CM_N;
+        const uint ln = idx % CM_N;
+
+        const uint gm = m0 + lm;
+        const uint gn = n0 + ln;
+
+        if (gm < uint(p.M) && gn < uint(p.N))
+        {
+            const uint oi = get_o_index(gm, gn, out_c_idx, out_d_idx);
+            buffer_st1(top_blob_data, oi, shO[idx]);
+        }
+    }
+}
diff --git a/src/layer/vulkan/shader/matmul_sg.comp b/src/layer/vulkan/shader/matmul_sg.comp
new file mode 100644
index 000000000000..b8de4ae928b5
--- /dev/null
+++ b/src/layer/vulkan/shader/matmul_sg.comp
@@ -0,0 +1,205 @@
+// Copyright 2026 MYQ
+// SPDX-License-Identifier: BSD-3-Clause
+
+#version 450
+
+#extension GL_KHR_shader_subgroup_basic : require
+#extension GL_KHR_shader_subgroup_shuffle : require
+// the shuffled type is afp, so fp16 arithmetic needs the extended types as well
+#if NCNN_fp16_storage || NCNN_fp16_arithmetic
+#extension GL_EXT_shader_subgroup_extended_types_float16 : require
+#endif
+
+layout(constant_id = 0) const int transB = 0;
+layout(constant_id = 1) const uint subgroup_size = 32;
+
+layout(binding = 0) writeonly buffer top_blob { sfp top_blob_data[]; };
+layout(binding = 1) readonly buffer A_blob { sfp A_blob_data[]; };
+layout(binding = 2) readonly buffer B_blob { sfp B_blob_data[]; };
+
+layout(push_constant) uniform parameter
+{
+    int M;
+    int N;
+    int K;
+    int batch;
+
+    int A_layout;
+    int B_layout;
+    int out_layout;
+
+    int A_hstep;
+    int A_cstep;
+    int A_dstep;
+    int A_batch_c;
+    int A_batch_d;
+
+    int B_hstep;
+    int B_cstep;
+    int B_dstep;
+    int B_batch_c;
+    int B_batch_d;
+
+    int out_hstep;
+    int out_cstep;
+    int out_dstep;
+    int out_batch_c;
+    int out_batch_d;
+
+    int transB_dynamic;
+} p;
+
+// flat buffer offset of A[m, k] for the given batch indices
+uint get_a_index(uint m, uint k, uint a_c_idx, uint a_d_idx)
+{
+    if (p.A_layout == 2)
+        return k;
+
+    if (p.A_layout == 1)
+        return a_c_idx * uint(p.A_cstep) + a_d_idx * uint(p.A_dstep) + m * uint(p.A_hstep) + k;
+
+    return m * uint(p.A_hstep) + k;
+}
+
+// flat buffer offset of B[k, n]; with transB the element is read from row n, column k
+uint get_b_index(uint k, uint n, uint b_c_idx, uint b_d_idx)
+{
+    if (p.B_layout == 2)
+        return k;
+
+    const bool tb = transB == 1 || p.transB_dynamic == 1;
+
+    if (p.B_layout == 1)
+    {
+        const uint base = b_c_idx * uint(p.B_cstep) + b_d_idx * uint(p.B_dstep);
+        return tb ? (base + n * uint(p.B_hstep) + k) : (base + k * uint(p.B_hstep) + n);
+    }
+
+    return tb ? (n * uint(p.B_hstep) + k) : (k * uint(p.B_hstep) + n);
+}
+
+// flat buffer offset of output element (m, n) according to out_layout
+uint get_o_index(uint m, uint n, uint out_c_idx, uint out_d_idx)
+{
+    if (p.out_layout == 1)
+        return n;
+
+    if (p.out_layout == 2)
+        return m;
+
+    if (p.out_layout == 3)
+        return out_c_idx * uint(p.out_cstep) + out_d_idx * uint(p.out_dstep) + m * uint(p.out_hstep) + n;
+
+    if (p.out_layout == 4)
+        return out_c_idx * uint(p.out_cstep) + out_d_idx * uint(p.out_hstep) + n;
+
+    if (p.out_layout == 5)
+        return out_c_idx * uint(p.out_cstep) + out_d_idx * uint(p.out_hstep) + m;
+
+    return m * uint(p.out_hstep) + n;
+}
+
+// sums v over all subgroup_size lanes; every lane receives the total
+afp subgroup_reduce_add(afp v)
+{
+    // Fast path for power-of-two subgroup size.
+    if ((subgroup_size & (subgroup_size - 1u)) == 0u)
+    {
+        for (uint offset = subgroup_size >> 1; offset > 0; offset >>= 1)
+        {
+            v += subgroupShuffleXor(v, offset);
+        }
+        return v;
+    }
+
+    afp total = afp(0.f);
+    for (uint i = 0; i < subgroup_size; i++)
+    {
+        total += subgroupShuffle(v, i);
+    }
+    return total;
+}
+
+// One subgroup computes one output element: lanes stride over K, then reduce.
+void main()
+{
+    const uint lane = gl_SubgroupInvocationID;
+
+    const uint gx = gl_GlobalInvocationID.x / subgroup_size;
+    const uint gy = gl_GlobalInvocationID.y;
+    const uint gz = gl_GlobalInvocationID.z;
+
+    if (gx >= uint(p.N) || gy >= uint(p.M) || gz >= uint(p.batch))
+        return;
+
+    uint out_c_idx = 0;
+    uint out_d_idx = 0;
+    if (p.out_batch_c > 1 || p.out_batch_d > 1)
+    {
+        out_c_idx = p.out_batch_c == 1 ? 0 : gz / uint(p.out_batch_d);
+        out_d_idx = p.out_batch_d == 1 ? 0 : gz % uint(p.out_batch_d);
+    }
+
+    uint a_c_idx = 0;
+    uint a_d_idx = 0;
+    if (p.A_batch_c > 1 || p.A_batch_d > 1)
+    {
+        if (p.out_batch_c == 1 && p.A_batch_c > 1 && p.A_batch_d == 1)
+            a_c_idx = gz;
+        else
+            a_c_idx = p.A_batch_c == 1 ? 0 : out_c_idx;
+
+        a_d_idx = p.A_batch_d == 1 ? 0 : out_d_idx;
+    }
+
+    uint b_c_idx = 0;
+    uint b_d_idx = 0;
+    if (p.B_batch_c > 1 || p.B_batch_d > 1)
+    {
+        if (p.out_batch_c == 1 && p.B_batch_c > 1 && p.B_batch_d == 1)
+            b_c_idx = gz;
+        else
+            b_c_idx = p.B_batch_c == 1 ? 0 : out_c_idx;
+
+        b_d_idx = p.B_batch_d == 1 ? 0 : out_d_idx;
+    }
+
+    afp sum = afp(0.f);
+
+    uint k = lane;
+    for (; k + subgroup_size * 3 < uint(p.K); k += subgroup_size * 4)
+    {
+        const uint ai0 = get_a_index(gy, k, a_c_idx, a_d_idx);
+        const uint bi0 = get_b_index(k, gx, b_c_idx, b_d_idx);
+        const uint k1 = k + subgroup_size;
+        const uint ai1 = get_a_index(gy, k1, a_c_idx, a_d_idx);
+        const uint bi1 = get_b_index(k1, gx, b_c_idx, b_d_idx);
+        const uint k2 = k1 + subgroup_size;
+        const uint ai2 = get_a_index(gy, k2, a_c_idx, a_d_idx);
+        const uint bi2 = get_b_index(k2, gx, b_c_idx, b_d_idx);
+        const uint k3 = k2 + subgroup_size;
+        const uint ai3 = get_a_index(gy, k3, a_c_idx, a_d_idx);
+        const uint bi3 = get_b_index(k3, gx, b_c_idx, b_d_idx);
+
+        sum += afp(buffer_ld1(A_blob_data, ai0)) * afp(buffer_ld1(B_blob_data, bi0));
+        sum += afp(buffer_ld1(A_blob_data, ai1)) * afp(buffer_ld1(B_blob_data, bi1));
+        sum += afp(buffer_ld1(A_blob_data, ai2)) * afp(buffer_ld1(B_blob_data, bi2));
+        sum += afp(buffer_ld1(A_blob_data, ai3)) * afp(buffer_ld1(B_blob_data, bi3));
+    }
+
+    for (; k < uint(p.K); k += subgroup_size)
+    {
+        const uint ai = get_a_index(gy, k, a_c_idx, a_d_idx);
+        const uint bi = get_b_index(k, gx, b_c_idx, b_d_idx);
+
+        sum += afp(buffer_ld1(A_blob_data, ai)) * afp(buffer_ld1(B_blob_data, bi));
+    }
+
+    const afp total = subgroup_reduce_add(sum);
+
+    if (lane == 0)
+    {
+        const uint oi = get_o_index(gy, gx, out_c_idx, out_d_idx);
+        buffer_st1(top_blob_data, oi, total);
+    }
+}