@@ -4696,59 +4696,59 @@ kernel void kernel_im2col(
// Explicit instantiations of kernel_im2col for float and half output.
// The host_name string is the name under which each instantiation is looked up
// from the host side, so it must not contain stray whitespace.
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
46984698
4699- // TODO: obsolete -- remove
4700- // typedef void (im2col_ext_t)(
4701- // constant ggml_metal_kargs_im2col & args,
4702- // device const float * x,
4703- // device char * dst,
4704- // uint3 tgpig[[threadgroup_position_in_grid]],
4705- // uint3 tgpg[[threadgroups_per_grid]],
4706- // uint3 tpitg[[thread_position_in_threadgroup]],
4707- // uint3 ntg[[threads_per_threadgroup]]);
4708- //
4709- // template <typename T>
4710- // kernel void kernel_im2col_ext(
4711- // constant ggml_metal_kargs_im2col & args,
4712- // device const float * x,
4713- // device char * dst,
4714- // uint3 tgpig[[threadgroup_position_in_grid]],
4715- // uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
4716- // uint3 tpitg[[thread_position_in_threadgroup]],
4717- // uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
4718- // const int64_t KHW = (int64_t)args.KHW;
4719- //
4720- // const int64_t d = tgpig[0] / args.CHW;
4721- // const int64_t chw = tgpig[0] % args.CHW;
4722- // const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
4723- // const int64_t HW = tgpig[0] % KHW;
4724- //
4725- // const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
4726- // if (tpitg_0 >= args.N) {
4727- // return;
4728- // }
4729- //
4730- // const int64_t tpitg_1 = HW / args.KW;
4731- // const int64_t tpitg_2 = HW % args.KW;
4732- //
4733- // const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
4734- // const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
4735- //
4736- // const int64_t offset_dst =
4737- // (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
4738- // (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
4739- //
4740- // device T * pdst = (device T *) (dst);
4741- //
4742- // if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
4743- // pdst[offset_dst] = 0.0f;
4744- // } else {
4745- // const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
4746- // pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
4747- // }
4748- // }
4749- //
4750- // template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
4751- // template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
// TODO: optimize
// Function type of kernel_im2col_ext<T>; used by its explicit instantiations.
typedef void (im2col_ext_t)(
        constant ggml_metal_kargs_im2col & args,
        device const float * x,
        device        char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3  tgpg[[threadgroups_per_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]);
4708+
// im2col variant in which the x dimension of the grid carries an extra chunk
// index `d` on top of the flattened (channel, kernel-tap) index.
//
// How the launch indices are consumed by the code below:
//   tgpig[0]           = d * CHW + chw, with chw further split into a channel
//                        index (chw / KHW) and a kernel tap (HW)
//   tgpig[1], tgpig[2] = output position; scaled by s1 / s0 to get the input row / column
//   tpitg[0]           = thread within the threadgroup; d * ntg[0] + tpitg[0] is
//                        the index checked against args.N (presumably the batch
//                        index -- confirm against the host-side dispatch)
//
// Each thread that passes the bounds check writes exactly one element of `dst`,
// interpreted as an array of T. Taps that land outside the IH x IW input after
// stride / dilation / padding are written as zero.
template <typename T>
kernel void kernel_im2col_ext(
        constant ggml_metal_kargs_im2col & args,
        device const float * x,
        device        char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3  tgpg[[threadgroups_per_grid]],      // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {  // [M, 1, 1]
    const int64_t KHW = (int64_t)args.KHW;

    // decompose the flattened x grid index
    const int64_t d       = tgpig[0] / args.CHW;
    const int64_t chw     = tgpig[0] % args.CHW;
    const int64_t tgpig_0 = chw / KHW;  // 0 ~ (IC - 1)
    // NOTE(review): equals chw % KHW only if CHW is a multiple of KHW, as the
    // "CHW = IC x KH x KW" note above implies -- confirm
    const int64_t HW      = tgpig[0] % KHW;

    // threads past the end of the N dimension have nothing to write
    const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
    if (tpitg_0 >= args.N) {
        return;
    }

    // kernel tap -> (row, column) within the kernel window
    const int64_t tpitg_1 = HW / args.KW;
    const int64_t tpitg_2 = HW % args.KW;

    // input coordinates sampled by this tap; negative or too large means padding
    const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
    const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;

    const int64_t offset_dst =
        (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
        (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);

    device T * pdst = (device T *) (dst);

    if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
        pdst[offset_dst] = 0.0f;
    } else {
        const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
        pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
    }
}
4749+
// Explicit instantiations of kernel_im2col_ext for float and half output.
// The host_name string is the name under which each instantiation is looked up
// from the host side, so it must not contain stray whitespace.
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
47524752
47534753template <typename TK>
47544754kernel void kernel_conv_2d (
0 commit comments