Skip to content

Commit 66c4f9d

Browse files
iacopPBK and JohannesGaessler
authored
ggml-cuda: ds_read_b128 for q4_0 and q4_1 mmq kernels (#21168)
* ds_read_b128 for q4_0 and q4_1 mmq kernels Current for loop generates ds_read_b32 instructions with hip compiler, the new solution generates ds_read_b128 instructions for the same operation, saving some LDS bandwidth. Tested on MI50 and RX6800XT, its faster on both. * Vectorized lds load update: used ggml_cuda_get_max_cpy_bytes and ggml_cuda_memcpy_1 functions for generic implementation * Explicit for loop in mmq, renamed vec into tmp * Fixed max_cpy usage in the loading loop * Fixed typo in q4_1 kernel * Update ggml/src/ggml-cuda/mmq.cuh Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * Update ggml/src/ggml-cuda/mmq.cuh Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * Update ggml/src/ggml-cuda/mmq.cuh Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * Renoved trailing white line 500 * Update mmq.cuh removed other whitelines * Remove trailing whitespaces --------- Co-authored-by: iacopPBK <iacopPBK@users.noreply.github.com> Co-authored-by: Johannes Gäßler <johannesg@5d6.de> Co-authored-by: iacopPBK <iacop@deneb.com>
1 parent 93bdc61 commit 66c4f9d

1 file changed

Lines changed: 27 additions & 10 deletions

File tree

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -386,17 +386,25 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
386386
#pragma unroll
387387
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
388388
const int i = i0 + threadIdx.x;
389-
390389
const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
391390

392391
int u[2*VDR_Q4_0_Q8_1_MMQ];
393392

394-
#pragma unroll
395-
for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
396-
u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
397-
u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];
393+
constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes();
394+
constexpr int mcpy_int = max_cpy / sizeof(int);
395+
static_assert(VDR_Q4_0_Q8_1_MMQ == 4, "bad VDR_Q4_0_Q8_1_MMQ");
396+
397+
int tmp0[4], tmp1[4];
398+
399+
#pragma unroll
400+
for (int l0 = 0; l0 < 4 / mcpy_int; ++l0) {
401+
ggml_cuda_memcpy_1<max_cpy>(tmp0 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + l0 * mcpy_int] );
402+
ggml_cuda_memcpy_1<max_cpy>(tmp1 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_0 + l0 * mcpy_int]);
398403
}
399404

405+
u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3];
406+
u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3];
407+
400408
sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
401409
(&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u,
402410
x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
@@ -489,17 +497,25 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
489497
#pragma unroll
490498
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
491499
const int i = i0 + threadIdx.x;
492-
493500
const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
494501

495502
int u[2*VDR_Q4_1_Q8_1_MMQ];
496503

497-
#pragma unroll
498-
for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
499-
u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
500-
u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];
504+
constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes();
505+
constexpr int mcpy_int = max_cpy / sizeof(int);
506+
static_assert(VDR_Q4_1_Q8_1_MMQ == 4, "bad VDR_Q4_1_Q8_1_MMQ"); // NOTE(review): original diff asserted VDR_Q4_0_Q8_1_MMQ here — a copy-paste from the q4_0 kernel; the q4_1 path must guard its own constant
507+
508+
int tmp0[4], tmp1[4];
509+
510+
#pragma unroll
511+
for (int l0 = 0; l0 < 4 / mcpy_int; ++l0) {
512+
ggml_cuda_memcpy_1<max_cpy>(tmp0 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + l0 * mcpy_int] );
513+
ggml_cuda_memcpy_1<max_cpy>(tmp1 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_1 + l0 * mcpy_int]);
501514
}
502515

516+
u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3];
517+
u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3];
518+
503519
sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
504520
(&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u,
505521
x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
@@ -4170,3 +4186,4 @@ void ggml_cuda_op_mul_mat_q(
41704186
const int64_t src1_padded_row_size, cudaStream_t stream);
41714187

41724188
bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts);
4189+

0 commit comments

Comments
 (0)