Skip to content

Commit 459b75b

Browse files
committed
templatize multi_token_path
1 parent 3183b72 commit 459b75b

2 files changed

Lines changed: 82 additions & 38 deletions

File tree

ggml/src/ggml-cuda/mmvf.cu

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include "mmvf.cuh"
55
#include "convert.cuh"
66

7-
template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false>
7+
template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false, bool is_multi_token_id = false>
88
static __global__ void mul_mat_vec_f(
99
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
1010
const int ncols2, const uint3 nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
@@ -14,19 +14,38 @@ static __global__ void mul_mat_vec_f(
1414
const int row = blockIdx.x;
1515
// for MUL_MAT_ID: blockIdx.y indexes the used experts (n_expert_used), blockIdx.z indexes the tokens (ncols_dst)
1616
const int channel_dst = blockIdx.y;
17-
const int token_idx = ids ? blockIdx.z : 0;
18-
const int channel_x = ids ? ids[blockIdx.y + token_idx * ids_stride] : fastdiv((uint32_t) channel_dst, channel_ratio);
19-
const int channel_y = ids ? fastmodulo(blockIdx.y, nchannels_y) : channel_dst;
20-
const int sample_dst = ids ? 0 : blockIdx.z;
17+
const int tid = threadIdx.x;
18+
19+
int token_idx;
20+
int channel_x;
21+
int channel_y;
22+
int sample_dst;
23+
24+
if constexpr (is_multi_token_id) {
25+
// Multi-token MUL_MAT_ID path. It is selected at compile time because adding this per-token indexing to the normal path causes a perf regression for the n_tokens=1 case.
26+
token_idx = blockIdx.z;
27+
channel_x = ids[channel_dst + token_idx * ids_stride];
28+
channel_y = fastmodulo(channel_dst, nchannels_y);
29+
sample_dst = 0;
30+
} else {
31+
token_idx = ids ? blockIdx.z : 0;
32+
channel_x = ids ? ids[blockIdx.y + token_idx * ids_stride] : fastdiv((uint32_t) channel_dst, channel_ratio);
33+
channel_y = ids ? fastmodulo(blockIdx.y, nchannels_y) : channel_dst;
34+
sample_dst = ids ? 0 : blockIdx.z;
35+
}
36+
2137
const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio);
2238
const int sample_y = sample_dst;
23-
const int tid = threadIdx.x;
2439

2540
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2641

2742
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
28-
y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y + token_idx*stride_col_y2*2;
29-
dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst + token_idx*stride_col_dst;
43+
y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
44+
dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
45+
if constexpr (is_multi_token_id) {
46+
y += token_idx*stride_col_y2*2;
47+
dst += token_idx*stride_col_dst;
48+
}
3049

3150
bool use_gate = false;
3251
bool use_bias = false;
@@ -354,7 +373,7 @@ static __global__ void mul_mat_vec_f(
354373
}
355374
}
356375

357-
template<typename T, typename type_acc, int ncols_dst, int block_size>
376+
template<typename T, typename type_acc, int ncols_dst, int block_size, bool is_multi_token_id = false>
358377
static void mul_mat_vec_f_switch_fusion(
359378
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
360379
const int64_t ncols, const uint3 nchannels_y,
@@ -366,7 +385,7 @@ static void mul_mat_vec_f_switch_fusion(
366385
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
367386
if constexpr (ncols_dst == 1) {
368387
if (has_fusion) {
369-
mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true><<<block_nums, block_dims, nbytes_shared, stream>>>
388+
mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
370389
(x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
371390
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
372391
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
@@ -376,14 +395,14 @@ static void mul_mat_vec_f_switch_fusion(
376395

377396
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
378397

379-
mul_mat_vec_f<T, type_acc, ncols_dst, block_size><<<block_nums, block_dims, nbytes_shared, stream>>>
398+
mul_mat_vec_f<T, type_acc, ncols_dst, block_size, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
380399
(x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
381400
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
382401
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
383402

384403
}
385404

386-
template <typename T, typename type_acc, int ncols_dst>
405+
template <typename T, typename type_acc, int ncols_dst, bool is_multi_token_id = false>
387406
void launch_mul_mat_vec_f_cuda(
388407
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
389408
const int64_t ncols, const int64_t nrows,
@@ -425,49 +444,49 @@ void launch_mul_mat_vec_f_cuda(
425444
const dim3 block_dims(block_size_best, 1, 1);
426445
switch (block_size_best) {
427446
case 32: {
428-
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32>
447+
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32, is_multi_token_id>
429448
(x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
430449
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
431450
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
432451
} break;
433452
case 64: {
434-
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64>
453+
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64, is_multi_token_id>
435454
(x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
436455
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
437456
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
438457
} break;
439458
case 96: {
440-
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96>
459+
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96, is_multi_token_id>
441460
(x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
442461
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
443462
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
444463
} break;
445464
case 128: {
446-
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128>
465+
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128, is_multi_token_id>
447466
(x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
448467
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
449468
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
450469
} break;
451470
case 160: {
452-
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160>
471+
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160, is_multi_token_id>
453472
(x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
454473
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
455474
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
456475
} break;
457476
case 192: {
458-
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192>
477+
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192, is_multi_token_id>
459478
(x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
460479
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
461480
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
462481
} break;
463482
case 224: {
464-
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224>
483+
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224, is_multi_token_id>
465484
(x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
466485
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
467486
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
468487
} break;
469488
case 256: {
470-
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256>
489+
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256, is_multi_token_id>
471490
(x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
472491
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
473492
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
@@ -490,8 +509,19 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(
490509

491510
const bool has_ids = ids != nullptr;
492511

512+
if (has_ids && ncols_dst > 1) {
513+
// Multi-token MUL_MAT_ID path only - single-token goes through regular path below
514+
constexpr int c_ncols_dst = 1;
515+
launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst, true>
516+
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
517+
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
518+
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
519+
ncols_dst, ids_stride, stream);
520+
return;
521+
}
522+
493523
if (has_ids) {
494-
// note: batching ncols_dst is not possible because tokens use different experts, so we use ncols_dst = 1 and iterate via blockIdx.z
524+
// Single-token MUL_MAT_ID path
495525
constexpr int c_ncols_dst = 1;
496526
launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst>
497527
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
137137
return 1;
138138
}
139139

140-
// tell the compiler to use as many registers as it wants, see nwarps definition below
141-
template <ggml_type type, int ncols_dst, bool has_fusion>
140+
template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false>
142141
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
143142
static __global__ void mul_mat_vec_q(
144143
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
@@ -163,15 +162,23 @@ static __global__ void mul_mat_vec_q(
163162
const int blocks_per_row_x = ncols_x / qk;
164163
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
165164

166-
// for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens)
167165
const uint32_t channel_dst = blockIdx.y;
168-
const uint32_t token_idx = blockIdx.z;
169-
const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst + token_idx * ids_stride] : fastdiv(channel_dst, channel_ratio);
170-
const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
171-
uint32_t sample_dst = blockIdx.z;
172166

173-
if constexpr (ncols_dst == 1) {
174-
sample_dst *= !ids_stride; // sample_dst for ids is 0
167+
uint32_t token_idx = 0;
168+
uint32_t channel_x;
169+
uint32_t channel_y;
170+
uint32_t sample_dst;
171+
172+
if constexpr (is_multi_token_id) {
173+
// Multi-token MUL_MAT_ID path. It is selected at compile time because adding this per-token indexing to the normal path causes a perf regression for the n_tokens=1 case.
174+
token_idx = blockIdx.z;
175+
channel_x = ids[channel_dst + token_idx * ids_stride];
176+
channel_y = fastmodulo(channel_dst, nchannels_y);
177+
sample_dst = 0;
178+
} else {
179+
channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
180+
channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
181+
sample_dst = blockIdx.z;
175182
}
176183

177184
const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
@@ -228,7 +235,10 @@ static __global__ void mul_mat_vec_q(
228235
float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
229236
float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};
230237

231-
const block_q8_1 * y = ((const block_q8_1 *) vy) + token_idx*stride_col_y + sample_y*stride_sample_y + channel_y*stride_channel_y;
238+
const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
239+
if constexpr (is_multi_token_id) {
240+
y += token_idx*stride_col_y;
241+
}
232242
const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
233243

234244
for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
@@ -280,7 +290,11 @@ static __global__ void mul_mat_vec_q(
280290
return;
281291
}
282292

283-
dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + token_idx*stride_col_dst + row0;
293+
dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;
294+
295+
if constexpr (is_multi_token_id) {
296+
dst += token_idx*stride_col_dst;
297+
}
284298

285299
// sum up partial sums and write back result
286300
#pragma unroll
@@ -350,7 +364,7 @@ static std::pair<dim3, dim3> calc_launch_params(
350364
return {block_nums, block_dims};
351365
}
352366

353-
template<ggml_type type, int c_ncols_dst>
367+
template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false>
354368
static void mul_mat_vec_q_switch_fusion(
355369
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
356370
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
@@ -363,7 +377,7 @@ static void mul_mat_vec_q_switch_fusion(
363377
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
364378
if constexpr (c_ncols_dst == 1) {
365379
if (has_fusion) {
366-
mul_mat_vec_q<type, c_ncols_dst, true><<<block_nums, block_dims, nbytes_shared, stream>>>
380+
mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
367381
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
368382
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
369383
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
@@ -373,7 +387,7 @@ static void mul_mat_vec_q_switch_fusion(
373387

374388
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
375389

376-
mul_mat_vec_q<type, c_ncols_dst, false><<<block_nums, block_dims, nbytes_shared, stream>>>
390+
mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
377391
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
378392
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
379393
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
@@ -403,11 +417,11 @@ static void mul_mat_vec_q_switch_ncols_dst(
403417
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
404418
const bool has_ids = ids != nullptr;
405419

406-
if (has_ids) {
407-
// note: batching ncols_dst is not possible because token use different experts, so we use ncols_dst = 1 and iterate via blockIdx.z
420+
if (has_ids && ncols_dst > 1) {
421+
// Multi-token MUL_MAT_ID path only - single-token goes through regular path below
408422
constexpr int c_ncols_dst = 1;
409423
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
410-
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
424+
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
411425
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
412426
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
413427
dims.first, dims.second, 0, ids_stride, stream);

0 commit comments

Comments
 (0)