Skip to content

Commit c560636

Browse files
authored
vulkan: support CONV_3D (#24612)
* vulkan: support CONV_3D This is a pretty direct port of conv2d_mm.comp to CONV_3D, done by codex and cleaned up by me. * disable slower perf tests
1 parent 0eb874d commit c560636

4 files changed

Lines changed: 725 additions & 3 deletions

File tree

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 241 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,20 @@ struct vk_conv2d_pipeline_state {
493493
}
494494
};
495495

496+
struct vk_conv3d_pipeline_state {
497+
vk_conv3d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t s2, uint32_t p0, uint32_t p1, uint32_t p2,
498+
uint32_t d0, uint32_t d1, uint32_t d2, uint32_t KW, uint32_t KH, uint32_t KD, uint32_t aligned)
499+
: s0(s0), s1(s1), s2(s2), p0(p0), p1(p1), p2(p2), d0(d0), d1(d1), d2(d2), KW(KW), KH(KH), KD(KD), aligned(aligned) {}
500+
501+
uint32_t s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD;
502+
uint32_t aligned;
503+
504+
bool operator<(const vk_conv3d_pipeline_state &b) const {
505+
return std::tie(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned) <
506+
std::tie(b.s0, b.s1, b.s2, b.p0, b.p1, b.p2, b.d0, b.d1, b.d2, b.KW, b.KH, b.KD, b.aligned);
507+
}
508+
};
509+
496510
struct vk_solve_tri_pipeline_state {
497511
vk_solve_tri_pipeline_state(uint32_t N, uint32_t K)
498512
: N(N), K(K) {}
@@ -924,6 +938,8 @@ struct vk_device_struct {
924938
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
925939
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
926940
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
941+
std::map<vk_conv3d_pipeline_state, vk_pipeline> pipeline_conv3d_f32[CONV_SHAPE_COUNT];
942+
std::map<vk_conv3d_pipeline_state, vk_pipeline> pipeline_conv3d_f16_f32[CONV_SHAPE_COUNT];
927943
vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
928944
vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
929945

@@ -1669,6 +1685,41 @@ template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
16691685
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
16701686
}
16711687

1688+
struct vk_op_conv3d_push_constants {
1689+
uint32_t OC;
1690+
uint32_t IC;
1691+
uint32_t N;
1692+
1693+
uint32_t IW;
1694+
uint32_t IH;
1695+
uint32_t ID;
1696+
uint32_t OW;
1697+
uint32_t OH;
1698+
uint32_t OD;
1699+
1700+
uint32_t nb01;
1701+
uint32_t nb02;
1702+
uint32_t nb03;
1703+
1704+
uint32_t nb11;
1705+
uint32_t nb12;
1706+
uint32_t nb13;
1707+
1708+
uint32_t nb1;
1709+
uint32_t nb2;
1710+
uint32_t nb3;
1711+
1712+
uint32_t OWmp; uint32_t OWL;
1713+
uint32_t OWOHmp; uint32_t OWOHL;
1714+
uint32_t OWOHODmp; uint32_t OWOHODL;
1715+
};
1716+
1717+
template <> void init_pushconst_fastdiv(vk_op_conv3d_push_constants &p) {
1718+
init_fastdiv_values(p.OW, p.OWmp, p.OWL);
1719+
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
1720+
init_fastdiv_values(p.OW*p.OH*p.OD, p.OWOHODmp, p.OWOHODL);
1721+
}
1722+
16721723
struct vk_op_conv2d_dw_push_constants {
16731724
uint32_t ne;
16741725
uint32_t batches;
@@ -5330,7 +5381,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
53305381

53315382
ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
53325383

5333-
// conv2d, conv_transpose_2d
5384+
// conv2d, conv_transpose_2d, conv3d
53345385
for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
53355386
// smaller WG for the small-tile fallback gives more concurrent WGs per SM
53365387
uint32_t conv2d_WG_SIZE = (s == CONV_SHAPE_64x32) ? 128 : 256;
@@ -5393,8 +5444,8 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
53935444
return (conv2d_BS.K * (conv2d_BS.CRS + pad) + conv2d_BS.CRS * (conv2d_BS.NPQ + pad) + csh_elems) * elem_size;
53945445
};
53955446

5396-
// coopmat1 needs to store the output through shared memory, so check up front
5397-
// whether it'll fit and disable it before applying coopmat1 parameters.
5447+
// 2D, transpose-2D, and 3D conv use the same KxCRS @ CRSxNPQ shmem
5448+
// layout. cm1 needs Csh for output, so check before applying cm1 params.
53985449
if (conv2d_use_cm1 && device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_cm1_shmem_pad, true, true)) {
53995450
conv2d_use_cm1 = false;
54005451
}
@@ -5486,6 +5537,53 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
54865537
}
54875538
#undef CREATE_CONV
54885539
#undef CREATE_CONVS
5540+
5541+
std::vector<uint32_t> conv3d_spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, conv2d_SHMEM_PAD };
5542+
#define CREATE_CONV3D(type_suffix, spv_suffix) \
5543+
for (auto &c : device->pipeline_conv3d##type_suffix[s]) { \
5544+
const vk_conv3d_pipeline_state &state = c.first; \
5545+
std::vector<uint32_t> spec_constants_cpy = conv3d_spec_constants; \
5546+
spec_constants_cpy.push_back(state.s0); \
5547+
spec_constants_cpy.push_back(state.s1); \
5548+
spec_constants_cpy.push_back(state.s2); \
5549+
spec_constants_cpy.push_back(state.p0); \
5550+
spec_constants_cpy.push_back(state.p1); \
5551+
spec_constants_cpy.push_back(state.p2); \
5552+
spec_constants_cpy.push_back(state.d0); \
5553+
spec_constants_cpy.push_back(state.d1); \
5554+
spec_constants_cpy.push_back(state.d2); \
5555+
spec_constants_cpy.push_back(state.KW); \
5556+
spec_constants_cpy.push_back(state.KH); \
5557+
spec_constants_cpy.push_back(state.KD); \
5558+
spec_constants_cpy.push_back(state.aligned); \
5559+
spec_constants_cpy.push_back(conv2d_csh_store); \
5560+
spec_constants_cpy.push_back(conv2d_WM); \
5561+
spec_constants_cpy.push_back(conv2d_WN); \
5562+
ggml_vk_create_pipeline( \
5563+
device, c.second, "conv3d" #type_suffix, \
5564+
conv3d##type_suffix##spv_suffix##_len, conv3d##type_suffix##spv_suffix##_data, "main", 3, \
5565+
sizeof(vk_op_conv3d_push_constants), wg_denoms, spec_constants_cpy, 1, true, conv2d_required_subgroup_size != 0, conv2d_required_subgroup_size); \
5566+
}
5567+
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
5568+
if (device->coopmat2) {
5569+
CREATE_CONV3D(_f32, _cm2)
5570+
CREATE_CONV3D(_f16_f32, _cm2)
5571+
} else
5572+
#endif
5573+
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
5574+
if (conv2d_use_cm1) {
5575+
CREATE_CONV3D(_f32, _cm1)
5576+
CREATE_CONV3D(_f16_f32, _cm1)
5577+
} else
5578+
#endif
5579+
if (conv2d_UNROLL) {
5580+
CREATE_CONV3D(_f32, _unroll)
5581+
CREATE_CONV3D(_f16_f32, _unroll)
5582+
} else {
5583+
CREATE_CONV3D(_f32, )
5584+
CREATE_CONV3D(_f16_f32, )
5585+
}
5586+
#undef CREATE_CONV3D
54895587
}
54905588

54915589
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
@@ -10901,6 +10999,61 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
1090110999
}
1090211000
}
1090311001
return nullptr;
11002+
case GGML_OP_CONV_3D:
11003+
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
11004+
const uint32_t OC = (uint32_t)ggml_get_op_params_i32(dst, 11);
11005+
const uint32_t IC = (uint32_t)ggml_get_op_params_i32(dst, 9);
11006+
const uint32_t N = (uint32_t)ggml_get_op_params_i32(dst, 10);
11007+
const uint32_t NPQ = N * dst->ne[2] * dst->ne[1] * dst->ne[0];
11008+
const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, OC, NPQ);
11009+
11010+
const uint32_t KW = (uint32_t)src0->ne[0];
11011+
const uint32_t KH = (uint32_t)src0->ne[1];
11012+
const uint32_t KD = (uint32_t)src0->ne[2];
11013+
const uint32_t s0 = (uint32_t)ggml_get_op_params_i32(dst, 0);
11014+
const uint32_t s1 = (uint32_t)ggml_get_op_params_i32(dst, 1);
11015+
const uint32_t s2 = (uint32_t)ggml_get_op_params_i32(dst, 2);
11016+
const uint32_t p0 = (uint32_t)ggml_get_op_params_i32(dst, 3);
11017+
const uint32_t p1 = (uint32_t)ggml_get_op_params_i32(dst, 4);
11018+
const uint32_t p2 = (uint32_t)ggml_get_op_params_i32(dst, 5);
11019+
const uint32_t d0 = (uint32_t)ggml_get_op_params_i32(dst, 6);
11020+
const uint32_t d1 = (uint32_t)ggml_get_op_params_i32(dst, 7);
11021+
const uint32_t d2 = (uint32_t)ggml_get_op_params_i32(dst, 8);
11022+
11023+
const uint32_t CRS = IC * KW * KH * KD;
11024+
const uint32_t BS_K = vk_conv_block_sizes[shape].K;
11025+
const uint32_t BS_CRS = vk_conv_block_sizes[shape].CRS;
11026+
const uint32_t BS_NPQ = vk_conv_block_sizes[shape].NPQ;
11027+
const uint32_t aligned = ((OC % BS_K == 0) &&
11028+
(CRS % BS_CRS == 0) &&
11029+
(NPQ % BS_NPQ == 0)) ? 1u : 0u;
11030+
11031+
vk_conv3d_pipeline_state conv3d_pipeline_state(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned);
11032+
11033+
std::map<vk_conv3d_pipeline_state, vk_pipeline> *pipelines = nullptr;
11034+
if (src0->type == GGML_TYPE_F32) {
11035+
pipelines = &ctx->device->pipeline_conv3d_f32[shape];
11036+
} else if (src0->type == GGML_TYPE_F16) {
11037+
pipelines = &ctx->device->pipeline_conv3d_f16_f32[shape];
11038+
} else {
11039+
return nullptr;
11040+
}
11041+
11042+
vk_pipeline pipeline = nullptr;
11043+
11044+
{
11045+
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
11046+
auto it = pipelines->find(conv3d_pipeline_state);
11047+
if (it != pipelines->end()) {
11048+
pipeline = it->second;
11049+
} else {
11050+
(*pipelines)[conv3d_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
11051+
}
11052+
}
11053+
11054+
return pipeline;
11055+
}
11056+
return nullptr;
1090411057
case GGML_OP_ADD1:
1090511058
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
1090611059
return ctx->device->pipeline_add1_f16_f16;
@@ -11236,6 +11389,21 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
1123611389
GGML_ABORT("invalid push constant type for CONV_2D");
1123711390
}
1123811391
break;
11392+
case GGML_OP_CONV_3D:
11393+
if constexpr (std::is_same_v<PC, vk_op_conv3d_push_constants>) {
11394+
const uint32_t NPQ = pc.N * pc.OD * pc.OH * pc.OW;
11395+
const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, pc.OC, NPQ);
11396+
const uint32_t NPQ_blocks = CEIL_DIV(NPQ, vk_conv_block_sizes[shape].NPQ);
11397+
11398+
elements = { pc.OC, NPQ_blocks, 1 };
11399+
if (elements[1] > 512) {
11400+
elements[2] = CEIL_DIV(elements[1], 512);
11401+
elements[1] = 512;
11402+
}
11403+
} else {
11404+
GGML_ABORT("invalid push constant type for CONV_3D");
11405+
}
11406+
break;
1123911407
case GGML_OP_ADD:
1124011408
case GGML_OP_SUB:
1124111409
case GGML_OP_DIV:
@@ -13134,6 +13302,51 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx,
1313413302
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, std::move(p));
1313513303
}
1313613304

13305+
static void ggml_vk_conv_3d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
13306+
const ggml_tensor * src1, ggml_tensor * dst) {
13307+
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
13308+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
13309+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
13310+
13311+
GGML_TENSOR_BINARY_OP_LOCALS
13312+
GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
13313+
GGML_ASSERT(nb10 == sizeof(float));
13314+
GGML_ASSERT(nb0 == sizeof(float));
13315+
13316+
vk_op_conv3d_push_constants p{};
13317+
p.IC = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 9));
13318+
p.N = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 10));
13319+
p.OC = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 11));
13320+
GGML_ASSERT(src0->ne[3] == (int64_t)p.IC * p.OC);
13321+
GGML_ASSERT(src1->ne[3] == (int64_t)p.IC * p.N);
13322+
GGML_ASSERT(dst->ne[3] == (int64_t)p.OC * p.N);
13323+
13324+
p.IW = static_cast<uint32_t>(ne10);
13325+
p.IH = static_cast<uint32_t>(ne11);
13326+
p.ID = static_cast<uint32_t>(ne12);
13327+
p.OW = static_cast<uint32_t>(ne0);
13328+
p.OH = static_cast<uint32_t>(ne1);
13329+
p.OD = static_cast<uint32_t>(ne2);
13330+
13331+
// the shader clamps src addresses to p.IC * p.N * p.IW * p.IH * p.ID - 1 in uint32, so the
13332+
// total input element count must fit in a uint32.
13333+
GGML_ASSERT((uint64_t)p.IC * p.N * p.IW * p.IH * p.ID <= 0xFFFFFFFFull);
13334+
13335+
p.nb01 = static_cast<uint32_t>(nb01 / nb00);
13336+
p.nb02 = static_cast<uint32_t>(nb02 / nb00);
13337+
p.nb03 = static_cast<uint32_t>(nb03 / nb00);
13338+
13339+
p.nb11 = static_cast<uint32_t>(nb11 / nb10);
13340+
p.nb12 = static_cast<uint32_t>(nb12 / nb10);
13341+
p.nb13 = static_cast<uint32_t>(nb13 / nb10);
13342+
13343+
p.nb1 = static_cast<uint32_t>(nb1 / nb0);
13344+
p.nb2 = static_cast<uint32_t>(nb2 / nb0);
13345+
p.nb3 = static_cast<uint32_t>(nb3 / nb0);
13346+
13347+
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_3D, std::move(p));
13348+
}
13349+
1313713350
static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1313813351
vk_op_conv2d_dw_push_constants p{};
1313913352
p.ne = ggml_nelements(dst);
@@ -14531,6 +14744,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1453114744
case GGML_OP_CONV_TRANSPOSE_2D:
1453214745
ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node);
1453314746

14747+
break;
14748+
case GGML_OP_CONV_3D:
14749+
ggml_vk_conv_3d(ctx, compute_ctx, src0, src1, node);
14750+
1453414751
break;
1453514752
case GGML_OP_CONV_2D_DW:
1453614753
ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node);
@@ -17301,6 +17518,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1730117518
ggml_is_contiguous(op->src[1]) &&
1730217519
ggml_is_contiguous(op));
1730317520
}
17521+
case GGML_OP_CONV_3D:
17522+
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
17523+
op->src[1]->type == GGML_TYPE_F32 &&
17524+
op->type == GGML_TYPE_F32 &&
17525+
ggml_is_contiguous(op->src[0]) &&
17526+
ggml_is_contiguous(op->src[1]) &&
17527+
ggml_is_contiguous(op);
1730417528
default:
1730517529
return false;
1730617530
}
@@ -18144,6 +18368,20 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1814418368
const int32_t d0 = tensor->op_params[4];
1814518369
const int32_t d1 = tensor->op_params[5];
1814618370
tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
18371+
} else if (tensor->op == GGML_OP_CONV_3D) {
18372+
const int32_t s0 = tensor->op_params[0];
18373+
const int32_t s1 = tensor->op_params[1];
18374+
const int32_t s2 = tensor->op_params[2];
18375+
const int32_t p0 = tensor->op_params[3];
18376+
const int32_t p1 = tensor->op_params[4];
18377+
const int32_t p2 = tensor->op_params[5];
18378+
const int32_t d0 = tensor->op_params[6];
18379+
const int32_t d1 = tensor->op_params[7];
18380+
const int32_t d2 = tensor->op_params[8];
18381+
const int32_t IC = tensor->op_params[9];
18382+
const int32_t N = tensor->op_params[10];
18383+
const int32_t OC = tensor->op_params[11];
18384+
tensor_clone = ggml_conv_3d_direct(ggml_ctx, src_clone[0], src_clone[1], s0, s1, s2, p0, p1, p2, d0, d1, d2, IC, N, OC);
1814718385
} else if (tensor->op == GGML_OP_CONV_2D_DW) {
1814818386
const int32_t s0 = tensor->op_params[0];
1814918387
const int32_t s1 = tensor->op_params[1];

0 commit comments

Comments
 (0)