Skip to content

Commit 2d6d422

Browse files
committed
Merge remote-tracking branch 'upstream/master'
2 parents 41e2b85 + 57ebaf4 commit 2d6d422

4 files changed

Lines changed: 94 additions & 67 deletions

File tree

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1897,7 +1897,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_l
18971897
char base[256];
18981898
char name[256];
18991899

1900-
snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type));
1900+
// note: this is slower
1901+
//const bool is_c4 = op->src[0]->ne[0] % 4 == 0 && op->ne[0] % 4 == 0;
1902+
const bool is_c4 = false;
1903+
1904+
snprintf(base, 256, "kernel_pad_%s%s", ggml_type_name(op->src[0]->type), is_c4 ? "_4" : "");
19011905
snprintf(name, 256, "%s", base);
19021906

19031907
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
@@ -1907,6 +1911,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_l
19071911

19081912
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
19091913

1914+
res.c4 = is_c4;
1915+
19101916
return res;
19111917
}
19121918

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -816,9 +816,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
816816
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
817817
} else {
818818
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
819-
820819
const int nth = MIN(args.ne00, nth_max);
821-
822820
const int nk0 = (args.ne00 + nth - 1)/nth;
823821

824822
ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
@@ -1863,7 +1861,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
18631861
nk0 = ne00/ggml_blck_size(op->type);
18641862
}
18651863

1866-
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1864+
int nth = std::min<int>(nk0*ne01, 256);
18671865

18681866
// when rows are small, we can batch them together in a single threadgroup
18691867
int nrptg = 1;
@@ -1874,7 +1872,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
18741872
nrptg = (nth + nk0 - 1)/nk0;
18751873
nth = nk0;
18761874

1877-
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1875+
if (nrptg*nth > 256) {
18781876
nrptg--;
18791877
}
18801878
}
@@ -4039,14 +4037,21 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
40394037

40404038
auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
40414039

4042-
const int nth = std::min(1024, ne0);
4040+
if (pipeline.c4) {
4041+
args.ne00 = ne00/4;
4042+
args.ne0 = ne0/4;
4043+
}
4044+
4045+
const int nth_max = MIN(64, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
4046+
const int nth = MIN(args.ne0, nth_max);
4047+
const int nk0 = (args.ne0 + 1024 - 1)/1024; // note: 1024 is hardcoded in the kernel!
40434048

40444049
ggml_metal_encoder_set_pipeline(enc, pipeline);
40454050
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
40464051
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
40474052
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
40484053

4049-
ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
4054+
ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne1, ne2, ne3, nth, 1, 1);
40504055

40514056
return 1;
40524057
}

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 72 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2643,7 +2643,7 @@ kernel void kernel_gated_delta_net_impl(
26432643
b_ptr += args.ne21;
26442644
g_ptr += args.ne21*G;
26452645

2646-
if (K > 1u) {
2646+
if (K > 1) {
26472647
const int target_slot = (int)t - shift;
26482648
if (target_slot >= 0 && target_slot < (int)K) {
26492649
device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base;
@@ -2655,7 +2655,7 @@ kernel void kernel_gated_delta_net_impl(
26552655
}
26562656
}
26572657

2658-
if (K == 1u) {
2658+
if (K == 1) {
26592659
device float * dst_state = (device float *) (dst) + attn_size + state_out_base;
26602660
FOR_UNROLL (short j = 0; j < NSG; j++) {
26612661
const short is = tx*NSG + j;
@@ -5104,7 +5104,7 @@ kernel void kernel_upscale_bilinear_f32(
51045104
for (int64_t sx = x_min; sx < x_max; ++sx) {
51055105
const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0);
51065106
const float w = wx * wy;
5107-
const device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
5107+
device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
51085108
sum += (*src_ptr) * w;
51095109
wsum += w;
51105110
}
@@ -5286,7 +5286,7 @@ kernel void kernel_upscale_bicubic_f32(
52865286
const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx));
52875287
const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3;
52885288

5289-
const device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
5289+
device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
52905290
sum += (*src_ptr) * wx * wy;
52915291
}
52925292
}
@@ -5329,42 +5329,46 @@ kernel void kernel_roll_f32(
53295329
}
53305330
}
53315331

5332-
kernel void kernel_pad_f32(
5332+
template <typename T>
5333+
kernel void kernel_pad_impl(
53335334
constant ggml_metal_kargs_pad & args,
53345335
device const char * src0,
53355336
device char * dst,
53365337
uint3 tgpig[[threadgroup_position_in_grid]],
53375338
uint3 tpitg[[thread_position_in_threadgroup]],
53385339
uint3 ntg[[threads_per_threadgroup]]) {
5340+
const int32_t i3 = tgpig.z;
5341+
const int32_t i2 = tgpig.y;
5342+
const int32_t k0 = tgpig.x/args.ne1;
5343+
const int32_t i1 = tgpig.x - k0*args.ne1;
53395344

5340-
const int64_t i3 = tgpig.z;
5341-
const int64_t i2 = tgpig.y;
5342-
const int64_t i1 = tgpig.x;
5345+
const int32_t i03 = i3;
5346+
const int32_t i02 = i2;
5347+
const int32_t i01 = i1;
53435348

5344-
const int64_t i03 = i3;
5345-
const int64_t i02 = i2;
5346-
const int64_t i01 = i1;
5347-
5348-
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
5349-
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
5349+
device const T * src0_ptr = (device const T *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
5350+
device T * dst_ptr = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
53505351

5351-
if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
5352-
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
5353-
if (i0 < args.ne00) {
5354-
dst_ptr[i0] = src0_ptr[i0];
5355-
} else {
5356-
dst_ptr[i0] = 0.0f;
5357-
}
5352+
for (int32_t l0 = 0; l0 < 1024; l0 += ntg.x) {
5353+
const int32_t i0 = k0*1024 + tpitg.x + l0;
5354+
if (i0 >= args.ne0) {
5355+
break;
53585356
}
53595357

5360-
return;
5361-
}
5362-
5363-
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
5364-
dst_ptr[i0] = 0.0f;
5358+
if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
5359+
dst_ptr[i0] = src0_ptr[i0];
5360+
} else {
5361+
dst_ptr[i0] = 0.0f;
5362+
}
53655363
}
53665364
}
53675365

5366+
typedef decltype(kernel_pad_impl<float>) kernel_pad_t;
5367+
5368+
template [[host_name("kernel_pad_f32")]] kernel kernel_pad_t kernel_pad_impl<float>;
5369+
template [[host_name("kernel_pad_f32_4")]] kernel kernel_pad_t kernel_pad_impl<float4>;
5370+
5371+
// TODO: this is slow - optimize
53685372
kernel void kernel_pad_reflect_1d_f32(
53695373
constant ggml_metal_kargs_pad_reflect_1d & args,
53705374
device const char * src0,
@@ -7328,23 +7332,27 @@ kernel void kernel_cpy_t_t(
73287332
device const char * src0,
73297333
device char * dst,
73307334
uint3 tgpig[[threadgroup_position_in_grid]],
7331-
ushort tiitg[[thread_index_in_threadgroup]],
7335+
ushort3 tpitg[[thread_position_in_threadgroup]],
73327336
ushort3 ntg[[threads_per_threadgroup]]) {
7333-
const int i03 = tgpig[2];
7334-
const int i02 = tgpig[1];
7335-
const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
7336-
const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
7337+
const int32_t i03 = tgpig[2];
7338+
const int32_t i02 = tgpig[1];
7339+
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
7340+
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
7341+
7342+
if (i01 >= args.ne01) {
7343+
return;
7344+
}
73377345

73387346
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
73397347

7340-
const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
7341-
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
7342-
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
7343-
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
7348+
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
7349+
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
7350+
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
7351+
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
73447352

73457353
device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
73467354

7347-
for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) {
7355+
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.ne00;) {
73487356
device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
73497357
dst_data[i00] = (T1) src[0];
73507358
break;
@@ -7376,23 +7384,27 @@ kernel void kernel_cpy_f32_q(
73767384
device const char * src0,
73777385
device char * dst,
73787386
uint3 tgpig[[threadgroup_position_in_grid]],
7379-
ushort tiitg[[thread_index_in_threadgroup]],
7387+
ushort3 tpitg[[thread_position_in_threadgroup]],
73807388
ushort3 ntg[[threads_per_threadgroup]]) {
7381-
const int i03 = tgpig[2];
7382-
const int i02 = tgpig[1];
7383-
const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
7384-
const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
7389+
const int32_t i03 = tgpig[2];
7390+
const int32_t i02 = tgpig[1];
7391+
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
7392+
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
7393+
7394+
if (i01 >= args.ne01) {
7395+
return;
7396+
}
73857397

73867398
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
73877399

7388-
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
7389-
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
7390-
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
7391-
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
7400+
const int32_t i3 = n / (args.ne2*args.ne1*args.ne0);
7401+
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
7402+
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
7403+
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
73927404

73937405
device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
73947406

7395-
for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
7407+
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
73967408
device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);
73977409

73987410
quantize_func(src, dst_data[i00]);
@@ -7417,24 +7429,28 @@ kernel void kernel_cpy_q_f32(
74177429
device const char * src0,
74187430
device char * dst,
74197431
uint3 tgpig[[threadgroup_position_in_grid]],
7420-
ushort tiitg[[thread_index_in_threadgroup]],
7432+
ushort3 tpitg[[thread_position_in_threadgroup]],
74217433
ushort3 ntg[[threads_per_threadgroup]]) {
7422-
const int i03 = tgpig[2];
7423-
const int i02 = tgpig[1];
7424-
const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
7425-
const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
7434+
const int32_t i03 = tgpig[2];
7435+
const int32_t i02 = tgpig[1];
7436+
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
7437+
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
7438+
7439+
if (i01 >= args.ne01) {
7440+
return;
7441+
}
74267442

74277443
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
74287444

7429-
const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
7430-
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
7431-
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
7432-
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
7445+
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
7446+
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
7447+
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
7448+
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
74337449

74347450
device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
74357451
device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
74367452

7437-
for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
7453+
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
74387454
T4x4 temp;
74397455
dequantize_func(src_data + i00/nl, i00%nl, temp);
74407456
dst_data[i00] = temp;

src/models/delta-net-base.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -562,13 +562,13 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn(
562562
}
563563

564564
const int64_t D = S_v * S_v * H_v;
565-
const int64_t K = (int64_t) cparams.n_rs_seq + 1;
565+
const int64_t K = cparams.n_rs_seq + 1;
566566

567567
// TODO: remove pad + simplify
568-
ggml_tensor * state_in_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs);
569-
ggml_tensor * state_3d = ggml_pad(ctx0, state_in_3d, 0, K - 1, 0, 0);
568+
ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs);
569+
ggml_tensor * s_3d_pad = ggml_pad (ctx0, s_3d, 0, K - 1, 0, 0);
570570

571-
ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, state_3d);
571+
ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d_pad);
572572
if (n_seq_tokens > 1) {
573573
cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il);
574574
} else {

0 commit comments

Comments
 (0)