@@ -6,17 +6,18 @@ template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
66static __global__ void k_get_rows (
77 const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
88 const int64_t ne00, /* const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
9- /* const int64_t ne10,*/ const int64_t ne11, const int64_t ne12 , /* const int64_t ne13,*/
9+ /* const int64_t ne10,*/ const int64_t ne11, const uint3 ne12_fdv , /* const int64_t ne13,*/
1010 /* const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
1111 /* const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
1212 const size_t s10, const size_t s11, const size_t s12/* , const size_t s13*/ ) {
1313
14- for (int64_t z = blockIdx .z ; z < ne11*ne12 ; z += gridDim .z ) {
14+ for (int64_t z = blockIdx .z ; z < ne11*( int64_t )ne12_fdv. z ; z += gridDim .z ) {
1515 for (int64_t i00 = 2 *(blockIdx .y *blockDim .x + threadIdx .x ); i00 < ne00; i00 += gridDim .y *blockDim .x ) {
1616 // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
1717 const int i10 = blockIdx .x ;
18- const int i11 = z / ne12; // TODO fastdiv
19- const int i12 = z % ne12;
18+ const uint2 dm = fast_div_modulo ((uint32_t )z, ne12_fdv);
19+ const int i11 = dm.x ;
20+ const int i12 = dm.y ;
2021
2122 const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
2223
@@ -42,17 +43,18 @@ template<typename src0_t, typename dst_t>
4243static __global__ void k_get_rows_float (
4344 const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
4445 const int64_t ne00, /* const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
45- /* const int64_t ne10,*/ const int64_t ne11, const int64_t ne12 , /* const int64_t ne13,*/
46+ /* const int64_t ne10,*/ const int64_t ne11, const uint3 ne12_fdv , /* const int64_t ne13,*/
4647 /* const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
4748 /* const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
4849 const size_t s10, const size_t s11, const size_t s12/* , const size_t s13*/ ) {
4950
50- for (int64_t z = blockIdx .z ; z < ne11*ne12 ; z += gridDim .z ) {
51+ for (int64_t z = blockIdx .z ; z < ne11*( int64_t )ne12_fdv. z ; z += gridDim .z ) {
5152 for (int64_t i00 = blockIdx .y *blockDim .x + threadIdx .x ; i00 < ne00; i00 += gridDim .y *blockDim .x ) {
5253 // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
5354 const int i10 = blockIdx .x ;
54- const int i11 = z / ne12; // TODO fastdiv
55- const int i12 = z % ne12;
55+ const uint2 dm = fast_div_modulo ((uint32_t )z, ne12_fdv);
56+ const int i11 = dm.x ;
57+ const int i12 = dm.y ;
5658
5759 if (i00 >= ne00) {
5860 return ;
@@ -115,10 +117,14 @@ static void get_rows_cuda_q(
115117
116118 GGML_ASSERT (ne00 % 2 == 0 );
117119
120+ GGML_ASSERT (ne12 > 0 );
121+ GGML_ASSERT (ne11 <= std::numeric_limits<uint32_t >::max () / ne12);
122+ const uint3 ne12_fdv = init_fastdiv_values (ne12);
123+
118124 k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0 , stream>>> (
119125 src0_d, src1_d, dst_d,
120126 ne00, /* ne01, ne02, ne03,*/
121- /* ne10,*/ ne11, ne12 , /* ne13,*/
127+ /* ne10,*/ ne11, ne12_fdv , /* ne13,*/
122128 /* s0,*/ s1, s2, s3,
123129 /* nb00,*/ nb01, nb02, nb03,
124130 s10, s11, s12/* , s13*/ );
@@ -146,10 +152,14 @@ static void get_rows_cuda_float(
146152 const size_t s12 = nb12 / sizeof (int32_t );
147153 // const size_t s13 = nb13 / sizeof(int32_t);
148154
155+ GGML_ASSERT (ne12 > 0 );
156+ GGML_ASSERT (ne11 <= std::numeric_limits<uint32_t >::max () / ne12);
157+ const uint3 ne12_fdv = init_fastdiv_values (ne12);
158+
149159 k_get_rows_float<<<block_nums, block_dims, 0 , stream>>> (
150160 src0_d, src1_d, dst_d,
151161 ne00, /* ne01, ne02, ne03,*/
152- /* ne10,*/ ne11, ne12 , /* ne13,*/
162+ /* ne10,*/ ne11, ne12_fdv , /* ne13,*/
153163 /* s0,*/ s1, s2, s3,
154164 /* nb00,*/ nb01, nb02, nb03,
155165 s10, s11, s12/* , s13*/ );
0 commit comments