Skip to content

Commit e16af37

Browse files
shawngu-quic authored and srossitto79 committed
opencl: generalize Adreno MoE kernels on M (ggml-org#23449)
1 parent 39003fe commit e16af37

18 files changed

Lines changed: 145 additions & 17 deletions

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4693,7 +4693,7 @@ inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, c
46934693
inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
46944694
GGML_UNUSED(backend_ctx);
46954695
int ne01 = tensor->ne[1];
4696-
return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0);
4696+
return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 32 == 0);
46974697
}
46984698

46994699
inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
@@ -14297,7 +14297,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
1429714297
CL_CHECK(status);
1429814298

1429914299
// set thread grid
14300-
global_size[0] = static_cast<size_t>(ne01);
14300+
global_size[0] = static_cast<size_t>(((ne01 + 63) / 64) * 64);
1430114301
global_size[1] = 4;
1430214302
global_size[2] = static_cast<size_t>(ne20);
1430314303
local_size[1] = 4;
@@ -14513,7 +14513,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
1451314513
CL_CHECK(status);
1451414514

1451514515
// set thread grid
14516-
global_size[0] = static_cast<size_t>(ne01);
14516+
global_size[0] = static_cast<size_t>(((ne01 + 63) / 64) * 64);
1451714517
global_size[1] = 4;
1451814518
global_size[2] = static_cast<size_t>(ne20);
1451914519
local_size[1] = 4;
@@ -14689,7 +14689,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
1468914689
CL_CHECK(status);
1469014690

1469114691
// set thread grid
14692-
global_size[0] = static_cast<size_t>(ne01);
14692+
global_size[0] = static_cast<size_t>(((ne01 + 63) / 64) * 64);
1469314693
global_size[1] = 4;
1469414694
global_size[2] = static_cast<size_t>(ne20);
1469514695
local_size[1] = 4;
@@ -14865,7 +14865,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
1486514865
CL_CHECK(status);
1486614866

1486714867
// set thread grid
14868-
global_size[0] = static_cast<size_t>(ne01);
14868+
global_size[0] = static_cast<size_t>(((ne01 + 63) / 64) * 64);
1486914869
global_size[1] = 4;
1487014870
global_size[2] = static_cast<size_t>(ne20);
1487114871
local_size[1] = 4;
@@ -15118,7 +15118,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
1511815118
CL_CHECK(status);
1511915119

1512015120
// set thread grid
15121-
global_size[0] = static_cast<size_t>(ne01);
15121+
global_size[0] = static_cast<size_t>(((ne01 + 63) / 64) * 64);
1512215122
global_size[1] = 4;
1512315123
global_size[2] = static_cast<size_t>(ne20);
1512415124
local_size[1] = 4;
@@ -15291,7 +15291,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
1529115291
CL_CHECK(status);
1529215292

1529315293
// set thread grid
15294-
global_size[0] = static_cast<size_t>(ne01);
15294+
global_size[0] = static_cast<size_t>(((ne01 + 63) / 64) * 64);
1529515295
global_size[1] = 4;
1529615296
global_size[2] = static_cast<size_t>(ne20);
1529715297
local_size[1] = 4;
@@ -15469,7 +15469,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
1546915469
CL_CHECK(status);
1547015470

1547115471
// set thread grid
15472-
global_size[0] = static_cast<size_t>(ne01);
15472+
global_size[0] = static_cast<size_t>(((ne01 + 63) / 64) * 64);
1547315473
global_size[1] = 4;
1547415474
global_size[2] = static_cast<size_t>(ne20);
1547515475
local_size[1] = 4;
@@ -15644,7 +15644,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
1564415644
CL_CHECK(status);
1564515645

1564615646
// set thread grid
15647-
global_size[0] = static_cast<size_t>(ne01);
15647+
global_size[0] = static_cast<size_t>(((ne01 + 63) / 64) * 64);
1564815648
global_size[1] = 4;
1564915649
global_size[2] = static_cast<size_t>(ne20);
1565015650
local_size[1] = 4;

ggml/src/ggml-opencl/kernels/cvt.cl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,10 @@ kernel void kernel_convert_block_q4_0_trans4_ns(
220220
uint i01 = get_global_id(0);
221221
uint i02 = get_global_id(2);
222222

223+
if (i01 >= ne01) {
224+
return;
225+
}
226+
223227
uint ne00_blk = ne00 / QK4_0;
224228
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
225229
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@@ -263,6 +267,10 @@ kernel void kernel_restore_block_q4_0_trans4_ns(
263267
uint i01 = get_global_id(0);
264268
uint i02 = get_global_id(2);
265269

270+
if (i01 >= ne01) {
271+
return;
272+
}
273+
266274
uint ne00_blk = ne00 / QK4_0;
267275
uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
268276
uint src_d_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@@ -401,6 +409,10 @@ kernel void kernel_convert_block_q4_1_trans4_ns(
401409
uint i01 = get_global_id(0);
402410
uint i02 = get_global_id(2);
403411

412+
if (i01 >= ne01) {
413+
return;
414+
}
415+
404416
uint ne00_blk = ne00 / QK4_1;
405417
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
406418
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@@ -446,6 +458,10 @@ kernel void kernel_restore_block_q4_1_trans4_ns(
446458
uint i01 = get_global_id(0);
447459
uint i02 = get_global_id(2);
448460

461+
if (i01 >= ne01) {
462+
return;
463+
}
464+
449465
uint ne00_blk = ne00 / QK4_1;
450466
uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
451467
uint src_dm_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@@ -491,6 +507,10 @@ kernel void kernel_convert_block_q5_0_trans4_ns(
491507
uint i01 = get_global_id(0);
492508
uint i02 = get_global_id(2);
493509

510+
if (i01 >= ne01) {
511+
return;
512+
}
513+
494514
uint ne00_blk = ne00 / QK5_0;
495515
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
496516
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@@ -536,6 +556,10 @@ kernel void kernel_restore_block_q5_0_trans4_ns(
536556
uint i01 = get_global_id(0);
537557
uint i02 = get_global_id(2);
538558

559+
if (i01 >= ne01) {
560+
return;
561+
}
562+
539563
uint ne00_blk = ne00 / QK5_0;
540564
uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
541565
uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@@ -583,6 +607,10 @@ kernel void kernel_convert_block_q5_1_trans4_ns(
583607
uint i01 = get_global_id(0);
584608
uint i02 = get_global_id(2);
585609

610+
if (i01 >= ne01) {
611+
return;
612+
}
613+
586614
uint ne00_blk = ne00 / QK5_1;
587615
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
588616
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@@ -630,6 +658,10 @@ kernel void kernel_restore_block_q5_1_trans4_ns(
630658
uint i01 = get_global_id(0);
631659
uint i02 = get_global_id(2);
632660

661+
if (i01 >= ne01) {
662+
return;
663+
}
664+
633665
uint ne00_blk = ne00 / QK5_1;
634666
uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
635667
uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@@ -679,6 +711,10 @@ kernel void kernel_convert_block_q4_k_trans4_ns(
679711
uint i01 = get_global_id(0);
680712
uint i02 = get_global_id(2);
681713

714+
if (i01 >= ne01) {
715+
return;
716+
}
717+
682718
uint ne00_blk = ne00 / QK_K;
683719
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
684720
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@@ -732,6 +768,10 @@ kernel void kernel_restore_block_q4_k_trans4_ns(
732768
uint i01 = get_global_id(0); // row index
733769
uint i02 = get_global_id(2); // batch index
734770

771+
if (i01 >= ne01) {
772+
return;
773+
}
774+
735775
uint ne00_blk = ne00 / QK_K;
736776

737777
uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@@ -784,6 +824,10 @@ kernel void kernel_convert_block_q5_k_trans4_ns(
784824
uint i01 = get_global_id(0);
785825
uint i02 = get_global_id(2);
786826

827+
if (i01 >= ne01) {
828+
return;
829+
}
830+
787831
uint ne00_blk = ne00 / QK_K;
788832
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
789833
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@@ -850,6 +894,10 @@ kernel void kernel_restore_block_q5_k_trans4_ns(
850894
uint i01 = get_global_id(0); // row index
851895
uint i02 = get_global_id(2); // batch index
852896

897+
if (i01 >= ne01) {
898+
return;
899+
}
900+
853901
uint ne00_blk = ne00 / QK_K;
854902

855903
uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@@ -916,6 +964,10 @@ kernel void kernel_convert_block_q6_k_trans4_ns(
916964
uint i01 = get_global_id(0);
917965
uint i02 = get_global_id(2);
918966

967+
if (i01 >= ne01) {
968+
return;
969+
}
970+
919971
uint ne00_blk = ne00 / QK_K;
920972

921973
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
@@ -993,6 +1045,10 @@ kernel void kernel_restore_block_q6_k_trans4_ns(
9931045
uint i01 = get_global_id(0); // row index
9941046
uint i02 = get_global_id(2); // batch index
9951047

1048+
if (i01 >= ne01) {
1049+
return;
1050+
}
1051+
9961052
uint ne00_blk = ne00 / QK_K;
9971053

9981054
uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@@ -1147,6 +1203,10 @@ kernel void kernel_convert_block_mxfp4_trans4_ns(
11471203
uint i01 = get_global_id(0);
11481204
uint i02 = get_global_id(2);
11491205

1206+
if (i01 >= ne01) {
1207+
return;
1208+
}
1209+
11501210
uint ne00_blk = ne00 / QK_MXFP4;
11511211
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
11521212
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
@@ -1190,6 +1250,10 @@ kernel void kernel_restore_block_mxfp4_trans4_ns(
11901250
uint i01 = get_global_id(0);
11911251
uint i02 = get_global_id(2);
11921252

1253+
if (i01 >= ne01) {
1254+
return;
1255+
}
1256+
11931257
uint ne00_blk = ne00 / QK_MXFP4;
11941258
uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
11951259
uint src_d_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;

ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ kernel void kernel_gemm_moe_mxfp4_f32_ns(
163163
uint block_id_n = get_global_id(2); // n_tile
164164

165165
// Boundary check
166-
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
166+
if (block_id_n >= total_tiles[0]) {
167167
return;
168168
}
169169

@@ -248,6 +248,10 @@ kernel void kernel_gemm_moe_mxfp4_f32_ns(
248248
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
249249
}
250250

251+
if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) {
252+
return;
253+
}
254+
251255
// Load post router and share in LM
252256
__local uint out_idx[TILESIZE_N];
253257

ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ kernel void kernel_gemm_moe_q4_0_f32_ns(
115115
uint block_id_n = get_global_id(2); // n_tile
116116

117117
// Boundary check
118-
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
118+
if (block_id_n >= total_tiles[0]) {
119119
return;
120120
}
121121

@@ -198,6 +198,10 @@ kernel void kernel_gemm_moe_q4_0_f32_ns(
198198
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
199199
}
200200

201+
if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) {
202+
return;
203+
}
204+
201205
// Load post router and share in LM
202206
__local uint out_idx[TILESIZE_N];
203207

ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ kernel void kernel_gemm_moe_q4_1_f32_ns(
116116
uint block_id_n = get_global_id(2); // n_tile
117117

118118
// Boundary check
119-
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
119+
if (block_id_n >= total_tiles[0]) {
120120
return;
121121
}
122122

@@ -200,6 +200,10 @@ kernel void kernel_gemm_moe_q4_1_f32_ns(
200200
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
201201
}
202202

203+
if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) {
204+
return;
205+
}
206+
203207
// Load post router and share in LM
204208
__local uint out_idx[TILESIZE_N];
205209

ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ kernel void kernel_gemm_moe_q4_k_f32_ns(
133133
uint block_id_n = get_global_id(2); // n_tile
134134

135135
// Boundary check
136-
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
136+
if (block_id_n >= total_tiles[0]) {
137137
return;
138138
}
139139

@@ -225,6 +225,10 @@ kernel void kernel_gemm_moe_q4_k_f32_ns(
225225
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
226226
}
227227

228+
if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) {
229+
return;
230+
}
231+
228232
// Load post router and share in LM
229233
__local uint out_idx[TILESIZE_N];
230234

ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ kernel void kernel_gemm_moe_q5_0_f32_ns(
116116
uint block_id_n = get_global_id(2); // n_tile
117117

118118
// Boundary check
119-
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
119+
if (block_id_n >= total_tiles[0]) {
120120
return;
121121
}
122122

@@ -202,6 +202,10 @@ kernel void kernel_gemm_moe_q5_0_f32_ns(
202202
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
203203
}
204204

205+
if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) {
206+
return;
207+
}
208+
205209
// Load post router and share in LM
206210
__local uint out_idx[TILESIZE_N];
207211

ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ kernel void kernel_gemm_moe_q5_1_f32_ns(
117117
uint block_id_n = get_global_id(2); // n_tile
118118

119119
// Boundary check
120-
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
120+
if (block_id_n >= total_tiles[0]) {
121121
return;
122122
}
123123

@@ -204,6 +204,10 @@ kernel void kernel_gemm_moe_q5_1_f32_ns(
204204
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
205205
}
206206

207+
if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) {
208+
return;
209+
}
210+
207211
// Load post router and share in LM
208212
__local uint out_idx[TILESIZE_N];
209213

ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ kernel void kernel_gemm_moe_q5_k_f32_ns(
134134
uint block_id_n = get_global_id(2); // n_tile
135135

136136
// Boundary check
137-
if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) {
137+
if (block_id_n >= total_tiles[0]) {
138138
return;
139139
}
140140

@@ -230,6 +230,10 @@ kernel void kernel_gemm_moe_q5_k_f32_ns(
230230
dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16);
231231
}
232232

233+
if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) {
234+
return;
235+
}
236+
233237
// Load post router and share in LM
234238
__local uint out_idx[TILESIZE_N];
235239

0 commit comments

Comments
 (0)