Skip to content

Commit e10bfbb

Browse files
committed
templatize mul_mat_id
1 parent 89198d7 commit e10bfbb

1 file changed

Lines changed: 112 additions & 49 deletions

File tree

ggml/src/ggml-cuda/mmf.cu

Lines changed: 112 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@ using namespace ggml_cuda_mma;
77

88
#define MMF_ROWS_PER_BLOCK 32
99

10-
template <typename T, int rows_per_block, int cols_per_block, int nwarps>
10+
template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
1111
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
1212
static __global__ void mul_mat_f(
1313
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
@@ -27,9 +27,8 @@ static __global__ void mul_mat_f(
2727

2828
const int row0 = blockIdx.x * rows_per_block;
2929

30-
const bool has_ids = ids != nullptr;
31-
const int expert_idx = has_ids ? blockIdx.y : 0;
32-
const int channel_dst = has_ids ? 0 : blockIdx.y;
30+
const int expert_idx = has_ids ? blockIdx.y : 0;
31+
const int channel_dst = has_ids ? 0 : blockIdx.y;
3332

3433
const int channel_x = has_ids ? expert_idx : (channel_dst / channel_ratio);
3534
const int channel_y = channel_dst;
@@ -47,13 +46,13 @@ static __global__ void mul_mat_f(
4746

4847
char * shmem_base = data_mmv;
4948
int * slot_map = (int *) shmem_base;
50-
char * compute_base = has_ids ? (shmem_base + cols_per_block * sizeof(int32_t)) : shmem_base;
49+
char * compute_base = has_ids ? (shmem_base + cols_per_block * sizeof(int)) : shmem_base;
5150

5251
tile_C C[ntA][ntB];
5352

5453
T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
5554

56-
if (has_ids) {
55+
if constexpr (has_ids) {
5756
__shared__ int has_any;
5857
if (threadIdx.y == 0) {
5958
int local_has_any = 0;
@@ -100,7 +99,7 @@ static __global__ void mul_mat_f(
10099
for (int j0 = 0; j0 < tile_B::I; ++j0) {
101100
const int j = j0 + itB*tile_B::I;
102101

103-
if (!has_ids) {
102+
if constexpr (!has_ids) {
104103
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
105104
} else {
106105
float val = 0.0f;
@@ -118,7 +117,7 @@ static __global__ void mul_mat_f(
118117
for (int j0 = 0; j0 < tile_B::I; ++j0) {
119118
const int j = j0 + itB*tile_B::I;
120119

121-
if (!has_ids) {
120+
if constexpr (!has_ids) {
122121
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
123122
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
124123
} else {
@@ -188,7 +187,7 @@ static __global__ void mul_mat_f(
188187
sum += buf_iw[j*kiw + i];
189188
}
190189

191-
if (!has_ids) {
190+
if constexpr (!has_ids) {
192191
dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
193192
} else {
194193
const int slot = (j < cols_per_block) ? slot_map[j] : -1;
@@ -255,60 +254,124 @@ static void mul_mat_f_cuda(
255254

256255
switch (nwarps_best) {
257256
case 1: {
258-
mul_mat_f<T, rows_per_block, cols_per_block, 1><<<block_nums, block_dims, nbytes_shared_total, stream>>>
259-
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
260-
stride_col_id, stride_row_id,
261-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
262-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
257+
if (ids) {
258+
mul_mat_f<T, rows_per_block, cols_per_block, 1, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
259+
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
260+
stride_col_id, stride_row_id,
261+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
262+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
263+
} else {
264+
mul_mat_f<T, rows_per_block, cols_per_block, 1, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
265+
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
266+
stride_col_id, stride_row_id,
267+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
268+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
269+
}
263270
} break;
264271
case 2: {
265-
mul_mat_f<T, rows_per_block, cols_per_block, 2><<<block_nums, block_dims, nbytes_shared_total, stream>>>
266-
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
267-
stride_col_id, stride_row_id,
268-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
269-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
272+
if (ids) {
273+
mul_mat_f<T, rows_per_block, cols_per_block, 2, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
274+
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
275+
stride_col_id, stride_row_id,
276+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
277+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
278+
} else {
279+
mul_mat_f<T, rows_per_block, cols_per_block, 2, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
280+
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
281+
stride_col_id, stride_row_id,
282+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
283+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
284+
}
270285
} break;
271286
case 3: {
272-
mul_mat_f<T, rows_per_block, cols_per_block, 3><<<block_nums, block_dims, nbytes_shared_total, stream>>>
273-
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
274-
stride_col_id, stride_row_id,
275-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
276-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
287+
if (ids) {
288+
mul_mat_f<T, rows_per_block, cols_per_block, 3, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
289+
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
290+
stride_col_id, stride_row_id,
291+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
292+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
293+
} else {
294+
mul_mat_f<T, rows_per_block, cols_per_block, 3, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
295+
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
296+
stride_col_id, stride_row_id,
297+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
298+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
299+
}
277300
} break;
278301
case 4: {
279-
mul_mat_f<T, rows_per_block, cols_per_block, 4><<<block_nums, block_dims, nbytes_shared_total, stream>>>
280-
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
281-
stride_col_id, stride_row_id,
282-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
283-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
302+
if (ids) {
303+
mul_mat_f<T, rows_per_block, cols_per_block, 4, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
304+
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
305+
stride_col_id, stride_row_id,
306+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
307+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
308+
} else {
309+
mul_mat_f<T, rows_per_block, cols_per_block, 4, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
310+
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
311+
stride_col_id, stride_row_id,
312+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
313+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
314+
}
284315
} break;
285316
case 5: {
286-
mul_mat_f<T, rows_per_block, cols_per_block, 5><<<block_nums, block_dims, nbytes_shared_total, stream>>>
287-
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
288-
stride_col_id, stride_row_id,
289-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
290-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
317+
if (ids) {
318+
mul_mat_f<T, rows_per_block, cols_per_block, 5, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
319+
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
320+
stride_col_id, stride_row_id,
321+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
322+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
323+
} else {
324+
mul_mat_f<T, rows_per_block, cols_per_block, 5, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
325+
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
326+
stride_col_id, stride_row_id,
327+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
328+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
329+
}
291330
} break;
292331
case 6: {
293-
mul_mat_f<T, rows_per_block, cols_per_block, 6><<<block_nums, block_dims, nbytes_shared_total, stream>>>
294-
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
295-
stride_col_id, stride_row_id,
296-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
297-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
332+
if (ids) {
333+
mul_mat_f<T, rows_per_block, cols_per_block, 6, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
334+
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
335+
stride_col_id, stride_row_id,
336+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
337+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
338+
} else {
339+
mul_mat_f<T, rows_per_block, cols_per_block, 6, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
340+
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
341+
stride_col_id, stride_row_id,
342+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
343+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
344+
}
298345
} break;
299346
case 7: {
300-
mul_mat_f<T, rows_per_block, cols_per_block, 7><<<block_nums, block_dims, nbytes_shared_total, stream>>>
301-
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
302-
stride_col_id, stride_row_id,
303-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
304-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
347+
if (ids) {
348+
mul_mat_f<T, rows_per_block, cols_per_block, 7, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
349+
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
350+
stride_col_id, stride_row_id,
351+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
352+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
353+
} else {
354+
mul_mat_f<T, rows_per_block, cols_per_block, 7, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
355+
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
356+
stride_col_id, stride_row_id,
357+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
358+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
359+
}
305360
} break;
306361
case 8: {
307-
mul_mat_f<T, rows_per_block, cols_per_block, 8><<<block_nums, block_dims, nbytes_shared_total, stream>>>
308-
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
309-
stride_col_id, stride_row_id,
310-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
311-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
362+
if (ids) {
363+
mul_mat_f<T, rows_per_block, cols_per_block, 8, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
364+
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
365+
stride_col_id, stride_row_id,
366+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
367+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
368+
} else {
369+
mul_mat_f<T, rows_per_block, cols_per_block, 8, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
370+
(x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
371+
stride_col_id, stride_row_id,
372+
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
373+
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
374+
}
312375
} break;
313376
default: {
314377
GGML_ABORT("fatal error");

0 commit comments

Comments
 (0)