Skip to content

Commit e77056f

Browse files
authored
CUDA: use fastdiv for batch index split in get_rows (ggml-org#22650)
1 parent 935a340 commit e77056f

1 file changed

Lines changed: 20 additions & 10 deletions

File tree

ggml/src/ggml-cuda/getrows.cu

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,18 @@ template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
66
static __global__ void k_get_rows(
77
const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
88
const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
9-
/*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/
9+
/*const int64_t ne10,*/ const int64_t ne11, const uint3 ne12_fdv, /*const int64_t ne13,*/
1010
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
1111
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
1212
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
1313

14-
for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {
14+
for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) {
1515
for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {
1616
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
1717
const int i10 = blockIdx.x;
18-
const int i11 = z / ne12; // TODO fastdiv
19-
const int i12 = z % ne12;
18+
const uint2 dm = fast_div_modulo((uint32_t)z, ne12_fdv);
19+
const int i11 = dm.x;
20+
const int i12 = dm.y;
2021

2122
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
2223

@@ -42,17 +43,18 @@ template<typename src0_t, typename dst_t>
4243
static __global__ void k_get_rows_float(
4344
const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
4445
const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
45-
/*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/
46+
/*const int64_t ne10,*/ const int64_t ne11, const uint3 ne12_fdv, /*const int64_t ne13,*/
4647
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
4748
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
4849
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
4950

50-
for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {
51+
for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) {
5152
for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {
5253
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
5354
const int i10 = blockIdx.x;
54-
const int i11 = z / ne12; // TODO fastdiv
55-
const int i12 = z % ne12;
55+
const uint2 dm = fast_div_modulo((uint32_t)z, ne12_fdv);
56+
const int i11 = dm.x;
57+
const int i12 = dm.y;
5658

5759
if (i00 >= ne00) {
5860
return;
@@ -115,10 +117,14 @@ static void get_rows_cuda_q(
115117

116118
GGML_ASSERT(ne00 % 2 == 0);
117119

120+
GGML_ASSERT(ne12 > 0);
121+
GGML_ASSERT(ne11 <= std::numeric_limits<uint32_t>::max() / ne12);
122+
const uint3 ne12_fdv = init_fastdiv_values(ne12);
123+
118124
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
119125
src0_d, src1_d, dst_d,
120126
ne00, /*ne01, ne02, ne03,*/
121-
/*ne10,*/ ne11, ne12, /*ne13,*/
127+
/*ne10,*/ ne11, ne12_fdv, /*ne13,*/
122128
/* s0,*/ s1, s2, s3,
123129
/* nb00,*/ nb01, nb02, nb03,
124130
s10, s11, s12/*, s13*/);
@@ -146,10 +152,14 @@ static void get_rows_cuda_float(
146152
const size_t s12 = nb12 / sizeof(int32_t);
147153
// const size_t s13 = nb13 / sizeof(int32_t);
148154

155+
GGML_ASSERT(ne12 > 0);
156+
GGML_ASSERT(ne11 <= std::numeric_limits<uint32_t>::max() / ne12);
157+
const uint3 ne12_fdv = init_fastdiv_values(ne12);
158+
149159
k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
150160
src0_d, src1_d, dst_d,
151161
ne00, /*ne01, ne02, ne03,*/
152-
/*ne10,*/ ne11, ne12, /*ne13,*/
162+
/*ne10,*/ ne11, ne12_fdv, /*ne13,*/
153163
/* s0,*/ s1, s2, s3,
154164
/* nb00,*/ nb01, nb02, nb03,
155165
s10, s11, s12/*, s13*/);

0 commit comments

Comments
 (0)