Skip to content

Commit 5416217

Browse files
committed
Pad shmem to 16 bytes, add helper function mul_mat_f_switch_ids
1 parent e10bfbb commit 5416217

4 files changed

Lines changed: 61 additions & 115 deletions

File tree

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2110,7 +2110,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
21102110
return;
21112111
}
21122112

2113-
if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2], ids)) {
2113+
if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2])) {
21142114
ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
21152115
return;
21162116
}

ggml/src/ggml-cuda/mmf.cu

Lines changed: 58 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ static __global__ void mul_mat_f(
4646

4747
char * shmem_base = data_mmv;
4848
int * slot_map = (int *) shmem_base;
49-
char * compute_base = has_ids ? (shmem_base + cols_per_block * sizeof(int)) : shmem_base;
49+
char * compute_base = has_ids ? (shmem_base + GGML_PAD(cols_per_block, 16) * sizeof(int)) : shmem_base;
5050

5151
tile_C C[ntA][ntB];
5252

@@ -206,6 +206,28 @@ static __global__ void mul_mat_f(
206206
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
207207
}
208208

209+
// Host-side launch helper: picks the mul_mat_f instantiation based on whether an
// ids pointer was supplied, then launches it with the given configuration.
//
// Template parameters:
//   T              - element type of x.
//   cols_per_block - forwarded unchanged as a mul_mat_f template argument.
//   nwarps         - forwarded unchanged as a mul_mat_f template argument.
// The rows-per-block template argument is always MMF_ROWS_PER_BLOCK.
//
// Arguments:
//   x, y, ids, dst and every size/stride/ratio value are passed through to the
//   kernel untouched, in the same order.
//   block_nums / block_dims - grid and block dimensions of the launch.
//   nbytes_shared_total     - dynamic shared memory size, in bytes, for the launch.
//   stream                  - CUDA stream the kernel is enqueued on.
//
// The last template argument is a compile-time bool (presumably the kernel's
// has_ids flag - confirm against the mul_mat_f declaration): true when ids is
// non-null, false otherwise.
// The launch result is not checked inside this helper.
template <typename T, int cols_per_block, int nwarps>
210+
static inline void mul_mat_f_switch_ids(
211+
const T * x, const float * y, const int32_t * ids, float * dst,
212+
const int64_t ncols_x, const int64_t nchannels_y, const int64_t nchannels_x, const int64_t nchannels_dst,
213+
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
214+
const int64_t stride_col_id, const int64_t stride_row_id,
215+
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
216+
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
217+
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
218+
// Turn the runtime null-check on ids into a compile-time template argument.
if (ids) {
219+
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
220+
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
221+
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
222+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
223+
} else {
224+
// Same launch and argument list; ids is null here and is still passed along.
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
225+
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
226+
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
227+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
228+
}
229+
}
230+
209231
template <typename T, int cols_per_block>
210232
static void mul_mat_f_cuda(
211233
const T * x, const float * y, const int32_t * ids, float * dst,
@@ -245,7 +267,7 @@ static void mul_mat_f_cuda(
245267
const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4;
246268
const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
247269
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
248-
const int nbytes_slotmap = ids ? (int)(cols_per_block * sizeof(int32_t)) : 0;
270+
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
249271
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
250272
const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present
251273

@@ -254,124 +276,52 @@ static void mul_mat_f_cuda(
254276

255277
switch (nwarps_best) {
256278
case 1: {
257-
if (ids) {
258-
mul_mat_f<T, rows_per_block, cols_per_block, 1, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
259-
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
260-
stride_col_id, stride_row_id,
261-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
262-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
263-
} else {
264-
mul_mat_f<T, rows_per_block, cols_per_block, 1, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
265-
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
266-
stride_col_id, stride_row_id,
267-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
268-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
269-
}
279+
mul_mat_f_switch_ids<T, cols_per_block, 1>(
280+
x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
281+
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
282+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
270283
} break;
271284
case 2: {
272-
if (ids) {
273-
mul_mat_f<T, rows_per_block, cols_per_block, 2, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
274-
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
275-
stride_col_id, stride_row_id,
276-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
277-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
278-
} else {
279-
mul_mat_f<T, rows_per_block, cols_per_block, 2, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
280-
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
281-
stride_col_id, stride_row_id,
282-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
283-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
284-
}
285+
mul_mat_f_switch_ids<T, cols_per_block, 2>(
286+
x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
287+
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
288+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
285289
} break;
286290
case 3: {
287-
if (ids) {
288-
mul_mat_f<T, rows_per_block, cols_per_block, 3, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
289-
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
290-
stride_col_id, stride_row_id,
291-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
292-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
293-
} else {
294-
mul_mat_f<T, rows_per_block, cols_per_block, 3, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
295-
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
296-
stride_col_id, stride_row_id,
297-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
298-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
299-
}
291+
mul_mat_f_switch_ids<T, cols_per_block, 3>(
292+
x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
293+
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
294+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
300295
} break;
301296
case 4: {
302-
if (ids) {
303-
mul_mat_f<T, rows_per_block, cols_per_block, 4, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
304-
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
305-
stride_col_id, stride_row_id,
306-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
307-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
308-
} else {
309-
mul_mat_f<T, rows_per_block, cols_per_block, 4, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
310-
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
311-
stride_col_id, stride_row_id,
312-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
313-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
314-
}
297+
mul_mat_f_switch_ids<T, cols_per_block, 4>(
298+
x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
299+
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
300+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
315301
} break;
316302
case 5: {
317-
if (ids) {
318-
mul_mat_f<T, rows_per_block, cols_per_block, 5, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
319-
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
320-
stride_col_id, stride_row_id,
321-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
322-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
323-
} else {
324-
mul_mat_f<T, rows_per_block, cols_per_block, 5, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
325-
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
326-
stride_col_id, stride_row_id,
327-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
328-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
329-
}
303+
mul_mat_f_switch_ids<T, cols_per_block, 5>(
304+
x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
305+
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
306+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
330307
} break;
331308
case 6: {
332-
if (ids) {
333-
mul_mat_f<T, rows_per_block, cols_per_block, 6, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
334-
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
335-
stride_col_id, stride_row_id,
336-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
337-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
338-
} else {
339-
mul_mat_f<T, rows_per_block, cols_per_block, 6, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
340-
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
341-
stride_col_id, stride_row_id,
342-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
343-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
344-
}
309+
mul_mat_f_switch_ids<T, cols_per_block, 6>(
310+
x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
311+
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
312+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
345313
} break;
346314
case 7: {
347-
if (ids) {
348-
mul_mat_f<T, rows_per_block, cols_per_block, 7, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
349-
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
350-
stride_col_id, stride_row_id,
351-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
352-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
353-
} else {
354-
mul_mat_f<T, rows_per_block, cols_per_block, 7, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
355-
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
356-
stride_col_id, stride_row_id,
357-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
358-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
359-
}
315+
mul_mat_f_switch_ids<T, cols_per_block, 7>(
316+
x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
317+
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
318+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
360319
} break;
361320
case 8: {
362-
if (ids) {
363-
mul_mat_f<T, rows_per_block, cols_per_block, 8, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
364-
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
365-
stride_col_id, stride_row_id,
366-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
367-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
368-
} else {
369-
mul_mat_f<T, rows_per_block, cols_per_block, 8, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
370-
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
371-
stride_col_id, stride_row_id,
372-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
373-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
374-
}
321+
mul_mat_f_switch_ids<T, cols_per_block, 8>(
322+
x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
323+
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
324+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
375325
} break;
376326
default: {
377327
GGML_ABORT("fatal error");
@@ -559,7 +509,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
559509
}
560510
}
561511

562-
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols, const ggml_tensor * ids) {
512+
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols) {
563513

564514
if (ggml_is_quantized(type)) {
565515
return false;
@@ -571,11 +521,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
571521
if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
572522
return false;
573523
}
574-
if (!ids && src1_ncols > 16) {
575-
return false;
576-
}
577-
578-
if (ids && src1_ncols > 16) {
524+
if (src1_ncols > 16) {
579525
return false;
580526
}
581527

ggml/src/ggml-cuda/mmf.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22

33
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
44

5-
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, const ggml_tensor * ids = nullptr);
5+
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols);

tests/test-backend-ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6261,7 +6261,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
62616261
for (int n_mats : {4, 8}) {
62626262
for (int n_used : {1, 2, 4}) {
62636263
for (bool b : {false, true}) {
6264-
for (int n : {1, 4, 32, 129}) {
6264+
for (int n : {1, 4, 5, 32, 129}) {
62656265
int m = 512;
62666266
int k = 256;
62676267
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));

0 commit comments

Comments
 (0)