@@ -493,6 +493,20 @@ struct vk_conv2d_pipeline_state {
493493 }
494494};
495495
496+ struct vk_conv3d_pipeline_state {
497+ vk_conv3d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t s2, uint32_t p0, uint32_t p1, uint32_t p2,
498+ uint32_t d0, uint32_t d1, uint32_t d2, uint32_t KW, uint32_t KH, uint32_t KD, uint32_t aligned)
499+ : s0(s0), s1(s1), s2(s2), p0(p0), p1(p1), p2(p2), d0(d0), d1(d1), d2(d2), KW(KW), KH(KH), KD(KD), aligned(aligned) {}
500+
501+ uint32_t s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD;
502+ uint32_t aligned;
503+
504+ bool operator<(const vk_conv3d_pipeline_state &b) const {
505+ return std::tie(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned) <
506+ std::tie(b.s0, b.s1, b.s2, b.p0, b.p1, b.p2, b.d0, b.d1, b.d2, b.KW, b.KH, b.KD, b.aligned);
507+ }
508+ };
509+
496510struct vk_solve_tri_pipeline_state {
497511 vk_solve_tri_pipeline_state(uint32_t N, uint32_t K)
498512 : N(N), K(K) {}
@@ -924,6 +938,8 @@ struct vk_device_struct {
924938 std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
925939 std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
926940 std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
941+ std::map<vk_conv3d_pipeline_state, vk_pipeline> pipeline_conv3d_f32[CONV_SHAPE_COUNT];
942+ std::map<vk_conv3d_pipeline_state, vk_pipeline> pipeline_conv3d_f16_f32[CONV_SHAPE_COUNT];
927943 vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
928944 vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
929945
@@ -1669,6 +1685,41 @@ template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
16691685 init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
16701686}
16711687
1688+ struct vk_op_conv3d_push_constants {
1689+ uint32_t OC;
1690+ uint32_t IC;
1691+ uint32_t N;
1692+
1693+ uint32_t IW;
1694+ uint32_t IH;
1695+ uint32_t ID;
1696+ uint32_t OW;
1697+ uint32_t OH;
1698+ uint32_t OD;
1699+
1700+ uint32_t nb01;
1701+ uint32_t nb02;
1702+ uint32_t nb03;
1703+
1704+ uint32_t nb11;
1705+ uint32_t nb12;
1706+ uint32_t nb13;
1707+
1708+ uint32_t nb1;
1709+ uint32_t nb2;
1710+ uint32_t nb3;
1711+
1712+ uint32_t OWmp; uint32_t OWL;
1713+ uint32_t OWOHmp; uint32_t OWOHL;
1714+ uint32_t OWOHODmp; uint32_t OWOHODL;
1715+ };
1716+
1717+ template <> void init_pushconst_fastdiv(vk_op_conv3d_push_constants &p) {
1718+ init_fastdiv_values(p.OW, p.OWmp, p.OWL);
1719+ init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
1720+ init_fastdiv_values(p.OW*p.OH*p.OD, p.OWOHODmp, p.OWOHODL);
1721+ }
1722+
16721723struct vk_op_conv2d_dw_push_constants {
16731724 uint32_t ne;
16741725 uint32_t batches;
@@ -5330,7 +5381,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
53305381
53315382 ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
53325383
5333- // conv2d, conv_transpose_2d
5384+ // conv2d, conv_transpose_2d, conv3d
53345385 for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
53355386 // smaller WG for the small-tile fallback gives more concurrent WGs per SM
53365387 uint32_t conv2d_WG_SIZE = (s == CONV_SHAPE_64x32) ? 128 : 256;
@@ -5393,8 +5444,8 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
53935444 return (conv2d_BS.K * (conv2d_BS.CRS + pad) + conv2d_BS.CRS * (conv2d_BS.NPQ + pad) + csh_elems) * elem_size;
53945445 };
53955446
5396- // coopmat1 needs to store the output through shared memory, so check up front
5397- // whether it'll fit and disable it before applying coopmat1 parameters .
5447+ // 2D, transpose-2D, and 3D conv use the same KxCRS @ CRSxNPQ shmem
5448+ // layout. cm1 needs Csh for output, so check before applying cm1 params .
53985449 if (conv2d_use_cm1 && device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_cm1_shmem_pad, true, true)) {
53995450 conv2d_use_cm1 = false;
54005451 }
@@ -5486,6 +5537,53 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
54865537 }
54875538#undef CREATE_CONV
54885539#undef CREATE_CONVS
5540+
5541+ std::vector<uint32_t> conv3d_spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, conv2d_SHMEM_PAD };
5542+ #define CREATE_CONV3D(type_suffix, spv_suffix) \
5543+ for (auto &c : device->pipeline_conv3d##type_suffix[s]) { \
5544+ const vk_conv3d_pipeline_state &state = c.first; \
5545+ std::vector<uint32_t> spec_constants_cpy = conv3d_spec_constants; \
5546+ spec_constants_cpy.push_back(state.s0); \
5547+ spec_constants_cpy.push_back(state.s1); \
5548+ spec_constants_cpy.push_back(state.s2); \
5549+ spec_constants_cpy.push_back(state.p0); \
5550+ spec_constants_cpy.push_back(state.p1); \
5551+ spec_constants_cpy.push_back(state.p2); \
5552+ spec_constants_cpy.push_back(state.d0); \
5553+ spec_constants_cpy.push_back(state.d1); \
5554+ spec_constants_cpy.push_back(state.d2); \
5555+ spec_constants_cpy.push_back(state.KW); \
5556+ spec_constants_cpy.push_back(state.KH); \
5557+ spec_constants_cpy.push_back(state.KD); \
5558+ spec_constants_cpy.push_back(state.aligned); \
5559+ spec_constants_cpy.push_back(conv2d_csh_store); \
5560+ spec_constants_cpy.push_back(conv2d_WM); \
5561+ spec_constants_cpy.push_back(conv2d_WN); \
5562+ ggml_vk_create_pipeline( \
5563+ device, c.second, "conv3d" #type_suffix, \
5564+ conv3d##type_suffix##spv_suffix##_len, conv3d##type_suffix##spv_suffix##_data, "main", 3, \
5565+ sizeof(vk_op_conv3d_push_constants), wg_denoms, spec_constants_cpy, 1, true, conv2d_required_subgroup_size != 0, conv2d_required_subgroup_size); \
5566+ }
5567+ #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
5568+ if (device->coopmat2) {
5569+ CREATE_CONV3D(_f32, _cm2)
5570+ CREATE_CONV3D(_f16_f32, _cm2)
5571+ } else
5572+ #endif
5573+ #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
5574+ if (conv2d_use_cm1) {
5575+ CREATE_CONV3D(_f32, _cm1)
5576+ CREATE_CONV3D(_f16_f32, _cm1)
5577+ } else
5578+ #endif
5579+ if (conv2d_UNROLL) {
5580+ CREATE_CONV3D(_f32, _unroll)
5581+ CREATE_CONV3D(_f16_f32, _unroll)
5582+ } else {
5583+ CREATE_CONV3D(_f32, )
5584+ CREATE_CONV3D(_f16_f32, )
5585+ }
5586+ #undef CREATE_CONV3D
54895587 }
54905588
54915589 ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
@@ -10901,6 +10999,61 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
1090110999 }
1090211000 }
1090311001 return nullptr;
11002+ case GGML_OP_CONV_3D:
11003+ if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
11004+ const uint32_t OC = (uint32_t)ggml_get_op_params_i32(dst, 11);
11005+ const uint32_t IC = (uint32_t)ggml_get_op_params_i32(dst, 9);
11006+ const uint32_t N = (uint32_t)ggml_get_op_params_i32(dst, 10);
11007+ const uint32_t NPQ = N * dst->ne[2] * dst->ne[1] * dst->ne[0];
11008+ const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, OC, NPQ);
11009+
11010+ const uint32_t KW = (uint32_t)src0->ne[0];
11011+ const uint32_t KH = (uint32_t)src0->ne[1];
11012+ const uint32_t KD = (uint32_t)src0->ne[2];
11013+ const uint32_t s0 = (uint32_t)ggml_get_op_params_i32(dst, 0);
11014+ const uint32_t s1 = (uint32_t)ggml_get_op_params_i32(dst, 1);
11015+ const uint32_t s2 = (uint32_t)ggml_get_op_params_i32(dst, 2);
11016+ const uint32_t p0 = (uint32_t)ggml_get_op_params_i32(dst, 3);
11017+ const uint32_t p1 = (uint32_t)ggml_get_op_params_i32(dst, 4);
11018+ const uint32_t p2 = (uint32_t)ggml_get_op_params_i32(dst, 5);
11019+ const uint32_t d0 = (uint32_t)ggml_get_op_params_i32(dst, 6);
11020+ const uint32_t d1 = (uint32_t)ggml_get_op_params_i32(dst, 7);
11021+ const uint32_t d2 = (uint32_t)ggml_get_op_params_i32(dst, 8);
11022+
11023+ const uint32_t CRS = IC * KW * KH * KD;
11024+ const uint32_t BS_K = vk_conv_block_sizes[shape].K;
11025+ const uint32_t BS_CRS = vk_conv_block_sizes[shape].CRS;
11026+ const uint32_t BS_NPQ = vk_conv_block_sizes[shape].NPQ;
11027+ const uint32_t aligned = ((OC % BS_K == 0) &&
11028+ (CRS % BS_CRS == 0) &&
11029+ (NPQ % BS_NPQ == 0)) ? 1u : 0u;
11030+
11031+ vk_conv3d_pipeline_state conv3d_pipeline_state(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned);
11032+
11033+ std::map<vk_conv3d_pipeline_state, vk_pipeline> *pipelines = nullptr;
11034+ if (src0->type == GGML_TYPE_F32) {
11035+ pipelines = &ctx->device->pipeline_conv3d_f32[shape];
11036+ } else if (src0->type == GGML_TYPE_F16) {
11037+ pipelines = &ctx->device->pipeline_conv3d_f16_f32[shape];
11038+ } else {
11039+ return nullptr;
11040+ }
11041+
11042+ vk_pipeline pipeline = nullptr;
11043+
11044+ {
11045+ std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
11046+ auto it = pipelines->find(conv3d_pipeline_state);
11047+ if (it != pipelines->end()) {
11048+ pipeline = it->second;
11049+ } else {
11050+ (*pipelines)[conv3d_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
11051+ }
11052+ }
11053+
11054+ return pipeline;
11055+ }
11056+ return nullptr;
1090411057 case GGML_OP_ADD1:
1090511058 if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
1090611059 return ctx->device->pipeline_add1_f16_f16;
@@ -11236,6 +11389,21 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
1123611389 GGML_ABORT("invalid push constant type for CONV_2D");
1123711390 }
1123811391 break;
11392+ case GGML_OP_CONV_3D:
11393+ if constexpr (std::is_same_v<PC, vk_op_conv3d_push_constants>) {
11394+ const uint32_t NPQ = pc.N * pc.OD * pc.OH * pc.OW;
11395+ const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, pc.OC, NPQ);
11396+ const uint32_t NPQ_blocks = CEIL_DIV(NPQ, vk_conv_block_sizes[shape].NPQ);
11397+
11398+ elements = { pc.OC, NPQ_blocks, 1 };
11399+ if (elements[1] > 512) {
11400+ elements[2] = CEIL_DIV(elements[1], 512);
11401+ elements[1] = 512;
11402+ }
11403+ } else {
11404+ GGML_ABORT("invalid push constant type for CONV_3D");
11405+ }
11406+ break;
1123911407 case GGML_OP_ADD:
1124011408 case GGML_OP_SUB:
1124111409 case GGML_OP_DIV:
@@ -13134,6 +13302,51 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx,
1313413302 ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, std::move(p));
1313513303}
1313613304
13305+ static void ggml_vk_conv_3d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
13306+ const ggml_tensor * src1, ggml_tensor * dst) {
13307+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
13308+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
13309+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
13310+
13311+ GGML_TENSOR_BINARY_OP_LOCALS
13312+ GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
13313+ GGML_ASSERT(nb10 == sizeof(float));
13314+ GGML_ASSERT(nb0 == sizeof(float));
13315+
13316+ vk_op_conv3d_push_constants p{};
13317+ p.IC = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 9));
13318+ p.N = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 10));
13319+ p.OC = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 11));
13320+ GGML_ASSERT(src0->ne[3] == (int64_t)p.IC * p.OC);
13321+ GGML_ASSERT(src1->ne[3] == (int64_t)p.IC * p.N);
13322+ GGML_ASSERT(dst->ne[3] == (int64_t)p.OC * p.N);
13323+
13324+ p.IW = static_cast<uint32_t>(ne10);
13325+ p.IH = static_cast<uint32_t>(ne11);
13326+ p.ID = static_cast<uint32_t>(ne12);
13327+ p.OW = static_cast<uint32_t>(ne0);
13328+ p.OH = static_cast<uint32_t>(ne1);
13329+ p.OD = static_cast<uint32_t>(ne2);
13330+
13331+ // the shader clamps src addresses to p.IC * p.N * p.IW * p.IH * p.ID - 1 in uint32, so the
13332+ // total input element count must fit in a uint32.
13333+ GGML_ASSERT((uint64_t)p.IC * p.N * p.IW * p.IH * p.ID <= 0xFFFFFFFFull);
13334+
13335+ p.nb01 = static_cast<uint32_t>(nb01 / nb00);
13336+ p.nb02 = static_cast<uint32_t>(nb02 / nb00);
13337+ p.nb03 = static_cast<uint32_t>(nb03 / nb00);
13338+
13339+ p.nb11 = static_cast<uint32_t>(nb11 / nb10);
13340+ p.nb12 = static_cast<uint32_t>(nb12 / nb10);
13341+ p.nb13 = static_cast<uint32_t>(nb13 / nb10);
13342+
13343+ p.nb1 = static_cast<uint32_t>(nb1 / nb0);
13344+ p.nb2 = static_cast<uint32_t>(nb2 / nb0);
13345+ p.nb3 = static_cast<uint32_t>(nb3 / nb0);
13346+
13347+ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_3D, std::move(p));
13348+ }
13349+
1313713350static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1313813351 vk_op_conv2d_dw_push_constants p{};
1313913352 p.ne = ggml_nelements(dst);
@@ -14531,6 +14744,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1453114744 case GGML_OP_CONV_TRANSPOSE_2D:
1453214745 ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node);
1453314746
14747+ break;
14748+ case GGML_OP_CONV_3D:
14749+ ggml_vk_conv_3d(ctx, compute_ctx, src0, src1, node);
14750+
1453414751 break;
1453514752 case GGML_OP_CONV_2D_DW:
1453614753 ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node);
@@ -17301,6 +17518,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1730117518 ggml_is_contiguous(op->src[1]) &&
1730217519 ggml_is_contiguous(op));
1730317520 }
17521+ case GGML_OP_CONV_3D:
17522+ return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
17523+ op->src[1]->type == GGML_TYPE_F32 &&
17524+ op->type == GGML_TYPE_F32 &&
17525+ ggml_is_contiguous(op->src[0]) &&
17526+ ggml_is_contiguous(op->src[1]) &&
17527+ ggml_is_contiguous(op);
1730417528 default:
1730517529 return false;
1730617530 }
@@ -18144,6 +18368,20 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1814418368 const int32_t d0 = tensor->op_params[4];
1814518369 const int32_t d1 = tensor->op_params[5];
1814618370 tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
18371+ } else if (tensor->op == GGML_OP_CONV_3D) {
18372+ const int32_t s0 = tensor->op_params[0];
18373+ const int32_t s1 = tensor->op_params[1];
18374+ const int32_t s2 = tensor->op_params[2];
18375+ const int32_t p0 = tensor->op_params[3];
18376+ const int32_t p1 = tensor->op_params[4];
18377+ const int32_t p2 = tensor->op_params[5];
18378+ const int32_t d0 = tensor->op_params[6];
18379+ const int32_t d1 = tensor->op_params[7];
18380+ const int32_t d2 = tensor->op_params[8];
18381+ const int32_t IC = tensor->op_params[9];
18382+ const int32_t N = tensor->op_params[10];
18383+ const int32_t OC = tensor->op_params[11];
18384+ tensor_clone = ggml_conv_3d_direct(ggml_ctx, src_clone[0], src_clone[1], s0, s1, s2, p0, p1, p2, d0, d1, d2, IC, N, OC);
1814718385 } else if (tensor->op == GGML_OP_CONV_2D_DW) {
1814818386 const int32_t s0 = tensor->op_params[0];
1814918387 const int32_t s1 = tensor->op_params[1];
0 commit comments