Skip to content

Commit 2d9b7c8

Browse files
authored
metal : restore im2col implementation for large kernels (ggml-org#23901)
1 parent e674b12 commit 2d9b7c8

4 files changed

Lines changed: 79 additions & 61 deletions

File tree

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1732,14 +1732,20 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_
17321732
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) {
17331733
assert(op->op == GGML_OP_IM2COL);
17341734

1735+
GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne);
1736+
17351737
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
17361738
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
17371739
GGML_ASSERT(op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
17381740

17391741
char base[256];
17401742
char name[256];
17411743

1742-
snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type));
1744+
if (ne00*ne01 <= 1024) {
1745+
snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type));
1746+
} else {
1747+
snprintf(base, 256, "kernel_im2col_ext_%s", ggml_type_name(op->type));
1748+
}
17431749
snprintf(name, 256, "%s", base);
17441750

17451751
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3635,16 +3635,26 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
36353635

36363636
auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
36373637

3638-
GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
3638+
if (KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3639+
const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N);
36393640

3640-
const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N);
3641+
ggml_metal_encoder_set_pipeline(enc, pipeline);
3642+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3643+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
3644+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
36413645

3642-
ggml_metal_encoder_set_pipeline(enc, pipeline);
3643-
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3644-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
3645-
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3646+
ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW);
3647+
} else {
3648+
const uint64_t n_threads = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), N);
3649+
const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
3650+
3651+
ggml_metal_encoder_set_pipeline(enc, pipeline);
3652+
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3653+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
3654+
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
36463655

3647-
ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW);
3656+
ggml_metal_encoder_dispatch_threadgroups(enc, quotient * CHW, OH, OW, n_threads, 1, 1);
3657+
}
36483658

36493659
return 1;
36503660
}

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 53 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -4696,59 +4696,59 @@ kernel void kernel_im2col(
46964696
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
46974697
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
46984698

4699-
// TODO: obsolete -- remove
4700-
//typedef void (im2col_ext_t)(
4701-
// constant ggml_metal_kargs_im2col & args,
4702-
// device const float * x,
4703-
// device char * dst,
4704-
// uint3 tgpig[[threadgroup_position_in_grid]],
4705-
// uint3 tgpg[[threadgroups_per_grid]],
4706-
// uint3 tpitg[[thread_position_in_threadgroup]],
4707-
// uint3 ntg[[threads_per_threadgroup]]);
4708-
//
4709-
//template <typename T>
4710-
//kernel void kernel_im2col_ext(
4711-
// constant ggml_metal_kargs_im2col & args,
4712-
// device const float * x,
4713-
// device char * dst,
4714-
// uint3 tgpig[[threadgroup_position_in_grid]],
4715-
// uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
4716-
// uint3 tpitg[[thread_position_in_threadgroup]],
4717-
// uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
4718-
// const int64_t KHW = (int64_t)args.KHW;
4719-
//
4720-
// const int64_t d = tgpig[0] / args.CHW;
4721-
// const int64_t chw = tgpig[0] % args.CHW;
4722-
// const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
4723-
// const int64_t HW = tgpig[0] % KHW;
4724-
//
4725-
// const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
4726-
// if (tpitg_0 >= args.N) {
4727-
// return;
4728-
// }
4729-
//
4730-
// const int64_t tpitg_1 = HW / args.KW;
4731-
// const int64_t tpitg_2 = HW % args.KW;
4732-
//
4733-
// const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
4734-
// const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
4735-
//
4736-
// const int64_t offset_dst =
4737-
// (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
4738-
// (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
4739-
//
4740-
// device T * pdst = (device T *) (dst);
4741-
//
4742-
// if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
4743-
// pdst[offset_dst] = 0.0f;
4744-
// } else {
4745-
// const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
4746-
// pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
4747-
// }
4748-
//}
4749-
//
4750-
//template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
4751-
//template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
4699+
// TODO: optimize
4700+
typedef void (im2col_ext_t)(
4701+
constant ggml_metal_kargs_im2col & args,
4702+
device const float * x,
4703+
device char * dst,
4704+
uint3 tgpig[[threadgroup_position_in_grid]],
4705+
uint3 tgpg[[threadgroups_per_grid]],
4706+
uint3 tpitg[[thread_position_in_threadgroup]],
4707+
uint3 ntg[[threads_per_threadgroup]]);
4708+
4709+
template <typename T>
4710+
kernel void kernel_im2col_ext(
4711+
constant ggml_metal_kargs_im2col & args,
4712+
device const float * x,
4713+
device char * dst,
4714+
uint3 tgpig[[threadgroup_position_in_grid]],
4715+
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
4716+
uint3 tpitg[[thread_position_in_threadgroup]],
4717+
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
4718+
const int64_t KHW = (int64_t)args.KHW;
4719+
4720+
const int64_t d = tgpig[0] / args.CHW;
4721+
const int64_t chw = tgpig[0] % args.CHW;
4722+
const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
4723+
const int64_t HW = tgpig[0] % KHW;
4724+
4725+
const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
4726+
if (tpitg_0 >= args.N) {
4727+
return;
4728+
}
4729+
4730+
const int64_t tpitg_1 = HW / args.KW;
4731+
const int64_t tpitg_2 = HW % args.KW;
4732+
4733+
const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
4734+
const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
4735+
4736+
const int64_t offset_dst =
4737+
(tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
4738+
(tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
4739+
4740+
device T * pdst = (device T *) (dst);
4741+
4742+
if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
4743+
pdst[offset_dst] = 0.0f;
4744+
} else {
4745+
const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
4746+
pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
4747+
}
4748+
}
4749+
4750+
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
4751+
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
47524752

47534753
template <typename TK>
47544754
kernel void kernel_conv_2d(

tests/test-backend-ops.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7812,6 +7812,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
78127812
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true));
78137813
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {5, 5, 1, 32}, {3, 4, 1, 32}, 1, 1, 0, 0, 1, 1, true));
78147814
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {2, 2, 1536, 729}, {2, 2, 1536, 4096}, 1, 1, 0, 0, 1, 1, true));
7815+
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {128, 128, 1, 2}, {32, 33, 1, 2}, 1, 1, 1, 1, 1, 1, true));
7816+
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {128, 128, 2, 1}, {33, 34, 2, 1}, 1, 1, 1, 1, 1, 1, true));
78157817

78167818
// im2col 3D
78177819
test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));

0 commit comments

Comments
 (0)