@@ -2643,7 +2643,7 @@ kernel void kernel_gated_delta_net_impl(
26432643 b_ptr += args.ne21 ;
26442644 g_ptr += args.ne21 *G;
26452645
2646- if (K > 1u ) {
2646+ if (K > 1 ) {
26472647 const int target_slot = (int )t - shift;
26482648 if (target_slot >= 0 && target_slot < (int )K) {
26492649 device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base;
@@ -2655,7 +2655,7 @@ kernel void kernel_gated_delta_net_impl(
26552655 }
26562656 }
26572657
2658- if (K == 1u ) {
2658+ if (K == 1 ) {
26592659 device float * dst_state = (device float *) (dst) + attn_size + state_out_base;
26602660 FOR_UNROLL (short j = 0 ; j < NSG ; j++) {
26612661 const short is = tx*NSG + j;
@@ -5104,7 +5104,7 @@ kernel void kernel_upscale_bilinear_f32(
51045104 for (int64_t sx = x_min; sx < x_max; ++sx) {
51055105 const float wx = MAX (0 .0f , 1 .0f - fabs ((float )sx - f00) * invscale0);
51065106 const float w = wx * wy;
5107- const device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00 );
5107+ device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00 );
51085108 sum += (*src_ptr) * w;
51095109 wsum += w;
51105110 }
@@ -5286,7 +5286,7 @@ kernel void kernel_upscale_bicubic_f32(
52865286 const int64_t ix = MAX (0 , MIN (args.ne00 - 1 , i00 + dx));
52875287 const float wx = (dx == -1 ) ? w_x0 : (dx == 0 ) ? w_x1 : (dx == 1 ) ? w_x2 : w_x3;
52885288
5289- const device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00 );
5289+ device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00 );
52905290 sum += (*src_ptr) * wx * wy;
52915291 }
52925292 }
@@ -5329,42 +5329,46 @@ kernel void kernel_roll_f32(
53295329 }
53305330}
53315331
5332- kernel void kernel_pad_f32 (
5332+ template <typename T>
5333+ kernel void kernel_pad_impl (
53335334 constant ggml_metal_kargs_pad & args,
53345335 device const char * src0,
53355336 device char * dst,
53365337 uint3 tgpig[[threadgroup_position_in_grid]] ,
53375338 uint3 tpitg[[thread_position_in_threadgroup]] ,
53385339 uint3 ntg[[threads_per_threadgroup]] ) {
5340+ const int32_t i3 = tgpig.z ;
5341+ const int32_t i2 = tgpig.y ;
5342+ const int32_t k0 = tgpig.x /args.ne1 ;
5343+ const int32_t i1 = tgpig.x - k0*args.ne1 ;
53395344
5340- const int64_t i3 = tgpig. z ;
5341- const int64_t i2 = tgpig. y ;
5342- const int64_t i1 = tgpig. x ;
5345+ const int32_t i03 = i3 ;
5346+ const int32_t i02 = i2 ;
5347+ const int32_t i01 = i1 ;
53435348
5344- const int64_t i03 = i3;
5345- const int64_t i02 = i2;
5346- const int64_t i01 = i1;
5347-
5348- device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 );
5349- device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 );
5349+ device const T * src0_ptr = (device const T *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 );
5350+ device T * dst_ptr = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 );
53505351
5351- if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03 ) {
5352- for (int i0 = tpitg.x ; i0 < args.ne0 ; i0 += ntg.x ) {
5353- if (i0 < args.ne00 ) {
5354- dst_ptr[i0] = src0_ptr[i0];
5355- } else {
5356- dst_ptr[i0] = 0 .0f ;
5357- }
5352+ for (int32_t l0 = 0 ; l0 < 1024 ; l0 += ntg.x ) {
5353+ const int32_t i0 = k0*1024 + tpitg.x + l0;
5354+ if (i0 >= args.ne0 ) {
5355+ break ;
53585356 }
53595357
5360- return ;
5361- }
5362-
5363- for ( int i0 = tpitg. x ; i0 < args. ne0 ; i0 += ntg. x ) {
5364- dst_ptr[i0] = 0 . 0f ;
5358+ if (i0 < args. ne00 && i1 < args. ne01 && i2 < args. ne02 && i3 < args. ne03 ) {
5359+ dst_ptr[i0] = src0_ptr[i0];
5360+ } else {
5361+ dst_ptr[i0] = 0 . 0f ;
5362+ }
53655363 }
53665364}
53675365
5366+ typedef decltype (kernel_pad_impl<float >) kernel_pad_t;
5367+
5368+ template [[host_name("kernel_pad_f32")]] kernel kernel_pad_t kernel_pad_impl<float >;
5369+ template [[host_name("kernel_pad_f32_4")]] kernel kernel_pad_t kernel_pad_impl<float4>;
5370+
5371+ // TODO: this is slow - optimize
53685372kernel void kernel_pad_reflect_1d_f32 (
53695373 constant ggml_metal_kargs_pad_reflect_1d & args,
53705374 device const char * src0,
@@ -7328,23 +7332,27 @@ kernel void kernel_cpy_t_t(
73287332 device const char * src0,
73297333 device char * dst,
73307334 uint3 tgpig[[threadgroup_position_in_grid]] ,
7331- ushort tiitg [[thread_index_in_threadgroup ]] ,
7335+ ushort3 tpitg [[thread_position_in_threadgroup ]] ,
73327336 ushort3 ntg[[threads_per_threadgroup]] ) {
7333- const int i03 = tgpig[2 ];
7334- const int i02 = tgpig[1 ];
7335- const int i01 = ntg[1 ] == 1 ? tgpig[0 ]%args.ne01 : tgpig[0 ]*ntg[1 ] + tiitg/ntg[0 ];
7336- const int iw0 = ntg[1 ] == 1 ? tgpig[0 ]/args.ne01 : 0 ;
7337+ const int32_t i03 = tgpig[2 ];
7338+ const int32_t i02 = tgpig[1 ];
7339+ const int32_t i01 = ntg[1 ] == 1 ? tgpig[0 ]%args.ne01 : tgpig[0 ]*ntg[1 ] + tpitg.y ;
7340+ const int32_t iw0 = ntg[1 ] == 1 ? tgpig[0 ]/args.ne01 : 0 ;
7341+
7342+ if (i01 >= args.ne01 ) {
7343+ return ;
7344+ }
73377345
73387346 const int64_t n = i03*args.ne02 *args.ne01 *args.ne00 + i02*args.ne01 *args.ne00 + i01*args.ne00 ;
73397347
7340- const int64_t i3 = n/(args.ne2 *args.ne1 *args.ne0 );
7341- const int64_t i2 = (n - i3*args.ne2 *args.ne1 *args.ne0 )/(args.ne1 *args.ne0 );
7342- const int64_t i1 = (n - i3*args.ne2 *args.ne1 *args.ne0 - i2*args.ne1 *args.ne0 )/args.ne0 ;
7343- const int64_t i0 = (n - i3*args.ne2 *args.ne1 *args.ne0 - i2*args.ne1 *args.ne0 - i1*args.ne0 );
7348+ const int32_t i3 = n/(args.ne2 *args.ne1 *args.ne0 );
7349+ const int32_t i2 = (n - i3*args.ne2 *args.ne1 *args.ne0 )/(args.ne1 *args.ne0 );
7350+ const int32_t i1 = (n - i3*args.ne2 *args.ne1 *args.ne0 - i2*args.ne1 *args.ne0 )/args.ne0 ;
7351+ const int32_t i0 = (n - i3*args.ne2 *args.ne1 *args.ne0 - i2*args.ne1 *args.ne0 - i1*args.ne0 );
73447352
73457353 device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0 );
73467354
7347- for (int64_t i00 = iw0*ntg[0 ] + tiitg%ntg[ 0 ] ; i00 < args.ne00 ; ) {
7355+ for (int32_t i00 = iw0*ntg[0 ] + tpitg. x ; i00 < args.ne00 ;) {
73487356 device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00 );
73497357 dst_data[i00] = (T1 ) src[0 ];
73507358 break ;
@@ -7376,23 +7384,27 @@ kernel void kernel_cpy_f32_q(
73767384 device const char * src0,
73777385 device char * dst,
73787386 uint3 tgpig[[threadgroup_position_in_grid]] ,
7379- ushort tiitg [[thread_index_in_threadgroup ]] ,
7387+ ushort3 tpitg [[thread_position_in_threadgroup ]] ,
73807388 ushort3 ntg[[threads_per_threadgroup]] ) {
7381- const int i03 = tgpig[2 ];
7382- const int i02 = tgpig[1 ];
7383- const int i01 = ntg[1 ] == 1 ? tgpig[0 ]%args.ne01 : tgpig[0 ]*ntg[1 ] + tiitg/ntg[0 ];
7384- const int iw0 = ntg[1 ] == 1 ? tgpig[0 ]/args.ne01 : 0 ;
7389+ const int32_t i03 = tgpig[2 ];
7390+ const int32_t i02 = tgpig[1 ];
7391+ const int32_t i01 = ntg[1 ] == 1 ? tgpig[0 ]%args.ne01 : tgpig[0 ]*ntg[1 ] + tpitg.y ;
7392+ const int32_t iw0 = ntg[1 ] == 1 ? tgpig[0 ]/args.ne01 : 0 ;
7393+
7394+ if (i01 >= args.ne01 ) {
7395+ return ;
7396+ }
73857397
73867398 const int64_t n = i03*args.ne02 *args.ne01 *args.ne00 + i02*args.ne01 *args.ne00 + i01*args.ne00 ;
73877399
7388- const int64_t i3 = n / (args.ne2 *args.ne1 *args.ne0 );
7389- const int64_t i2 = (n - i3*args.ne2 *args.ne1 *args.ne0 ) / (args.ne1 *args.ne0 );
7390- const int64_t i1 = (n - i3*args.ne2 *args.ne1 *args.ne0 - i2*args.ne1 *args.ne0 ) / args.ne0 ;
7391- const int64_t i0 = (n - i3*args.ne2 *args.ne1 *args.ne0 - i2*args.ne1 *args.ne0 - i1*args.ne0 )/QK ;
7400+ const int32_t i3 = n / (args.ne2 *args.ne1 *args.ne0 );
7401+ const int32_t i2 = (n - i3*args.ne2 *args.ne1 *args.ne0 ) / (args.ne1 *args.ne0 );
7402+ const int32_t i1 = (n - i3*args.ne2 *args.ne1 *args.ne0 - i2*args.ne1 *args.ne0 ) / args.ne0 ;
7403+ const int32_t i0 = (n - i3*args.ne2 *args.ne1 *args.ne0 - i2*args.ne1 *args.ne0 - i1*args.ne0 )/QK ;
73927404
73937405 device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0 );
73947406
7395- for (int64_t i00 = iw0*ntg[0 ] + tiitg%ntg[ 0 ] ; i00 < args.nk0 ; ) {
7407+ for (int32_t i00 = iw0*ntg[0 ] + tpitg. x ; i00 < args.nk0 ;) {
73967408 device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK )*args.nb00 );
73977409
73987410 quantize_func (src, dst_data[i00]);
@@ -7417,24 +7429,28 @@ kernel void kernel_cpy_q_f32(
74177429 device const char * src0,
74187430 device char * dst,
74197431 uint3 tgpig[[threadgroup_position_in_grid]] ,
7420- ushort tiitg [[thread_index_in_threadgroup ]] ,
7432+ ushort3 tpitg [[thread_position_in_threadgroup ]] ,
74217433 ushort3 ntg[[threads_per_threadgroup]] ) {
7422- const int i03 = tgpig[2 ];
7423- const int i02 = tgpig[1 ];
7424- const int i01 = ntg[1 ] == 1 ? tgpig[0 ]%args.ne01 : tgpig[0 ]*ntg[1 ] + tiitg/ntg[0 ];
7425- const int iw0 = ntg[1 ] == 1 ? tgpig[0 ]/args.ne01 : 0 ;
7434+ const int32_t i03 = tgpig[2 ];
7435+ const int32_t i02 = tgpig[1 ];
7436+ const int32_t i01 = ntg[1 ] == 1 ? tgpig[0 ]%args.ne01 : tgpig[0 ]*ntg[1 ] + tpitg.y ;
7437+ const int32_t iw0 = ntg[1 ] == 1 ? tgpig[0 ]/args.ne01 : 0 ;
7438+
7439+ if (i01 >= args.ne01 ) {
7440+ return ;
7441+ }
74267442
74277443 const int64_t n = i03*args.ne02 *args.ne01 *args.ne00 + i02*args.ne01 *args.ne00 + i01*args.ne00 ;
74287444
7429- const int64_t i3 = n/(args.ne2 *args.ne1 *args.ne0 );
7430- const int64_t i2 = (n - i3*args.ne2 *args.ne1 *args.ne0 )/(args.ne1 *args.ne0 );
7431- const int64_t i1 = (n - i3*args.ne2 *args.ne1 *args.ne0 - i2*args.ne1 *args.ne0 )/args.ne0 ;
7432- const int64_t i0 = (n - i3*args.ne2 *args.ne1 *args.ne0 - i2*args.ne1 *args.ne0 - i1*args.ne0 );
7445+ const int32_t i3 = n/(args.ne2 *args.ne1 *args.ne0 );
7446+ const int32_t i2 = (n - i3*args.ne2 *args.ne1 *args.ne0 )/(args.ne1 *args.ne0 );
7447+ const int32_t i1 = (n - i3*args.ne2 *args.ne1 *args.ne0 - i2*args.ne1 *args.ne0 )/args.ne0 ;
7448+ const int32_t i0 = (n - i3*args.ne2 *args.ne1 *args.ne0 - i2*args.ne1 *args.ne0 - i1*args.ne0 );
74337449
74347450 device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 );
74357451 device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0 );
74367452
7437- for (int64_t i00 = iw0*ntg[0 ] + tiitg%ntg[ 0 ] ; i00 < args.nk0 ; ) {
7453+ for (int32_t i00 = iw0*ntg[0 ] + tpitg. x ; i00 < args.nk0 ;) {
74387454 T4x4 temp;
74397455 dequantize_func (src_data + i00/nl, i00%nl, temp);
74407456 dst_data[i00] = temp;
0 commit comments