Skip to content

Commit e380897

Browse files
committed
cuda: Q2_0
1 parent 0eed534 commit e380897

12 files changed

Lines changed: 235 additions & 1 deletion

File tree

ggml/src/ggml-cpu/arch-fallback.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
1818
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
1919
#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0
20+
#define ggml_vec_dot_q2_0_q8_0_generic ggml_vec_dot_q2_0_q8_0
2021
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
2122
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
2223
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
@@ -83,6 +84,7 @@
8384
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
8485
// quants.c
8586
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
87+
#define ggml_vec_dot_q2_0_q8_0_generic ggml_vec_dot_q2_0_q8_0
8688
// repack.cpp
8789
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
8890
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
@@ -114,6 +116,7 @@
114116
#define quantize_row_q8_K_generic quantize_row_q8_K
115117
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
116118
#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0
119+
#define ggml_vec_dot_q2_0_q8_0_generic ggml_vec_dot_q2_0_q8_0
117120
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
118121
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
119122
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
@@ -163,6 +166,7 @@
163166
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
164167
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
165168
#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0
169+
#define ggml_vec_dot_q2_0_q8_0_generic ggml_vec_dot_q2_0_q8_0
166170
// repack.cpp
167171
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
168172
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -204,6 +208,7 @@
204208
// quants.c
205209
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
206210
#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0
211+
#define ggml_vec_dot_q2_0_q8_0_generic ggml_vec_dot_q2_0_q8_0
207212
// repack.cpp
208213
#define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1
209214
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
@@ -245,6 +250,7 @@
245250
#define quantize_row_q8_K_generic quantize_row_q8_K
246251
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
247252
#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0
253+
#define ggml_vec_dot_q2_0_q8_0_generic ggml_vec_dot_q2_0_q8_0
248254
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
249255
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
250256
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
@@ -309,6 +315,7 @@
309315
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
310316
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
311317
#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0
318+
#define ggml_vec_dot_q2_0_q8_0_generic ggml_vec_dot_q2_0_q8_0
312319
// repack.cpp
313320
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
314321
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8

ggml/src/ggml-cuda/common.cuh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -931,6 +931,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_Q1_0> {
931931
static constexpr int qi = QI1_0;
932932
};
933933

934+
// CUDA type traits for GGML_TYPE_Q2_0, mirroring the neighboring Q1_0/Q4_0
// specializations: block size, quant ratio, and ints of quant data per block.
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q2_0> {
    static constexpr int qk = QK2_0;
    static constexpr int qr = QR2_0;
    static constexpr int qi = QI2_0;
};
940+
934941
template<>
935942
struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
936943
static constexpr int qk = QK4_0;

ggml/src/ggml-cuda/convert.cu

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
713713
switch (type) {
714714
case GGML_TYPE_Q1_0:
715715
return dequantize_block_cont_cuda<QK1_0, QR1_0, dequantize_q1_0>;
716+
case GGML_TYPE_Q2_0:
717+
return dequantize_block_cont_cuda<QK2_0, QR2_0, dequantize_q2_0>;
716718
case GGML_TYPE_Q4_0:
717719
return dequantize_row_q4_0_cuda;
718720
case GGML_TYPE_Q4_1:
@@ -771,6 +773,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
771773
switch (type) {
772774
case GGML_TYPE_Q1_0:
773775
return dequantize_block_cont_cuda<QK1_0, QR1_0, dequantize_q1_0>;
776+
case GGML_TYPE_Q2_0:
777+
return dequantize_block_cont_cuda<QK2_0, QR2_0, dequantize_q2_0>;
774778
case GGML_TYPE_Q4_0:
775779
return dequantize_row_q4_0_cuda;
776780
case GGML_TYPE_Q4_1:
@@ -828,6 +832,8 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
828832
return convert_unary_cuda<float>;
829833
case GGML_TYPE_Q1_0:
830834
return dequantize_block_cuda<QK1_0, QR1_0, dequantize_q1_0>;
835+
case GGML_TYPE_Q2_0:
836+
return dequantize_block_cuda<QK2_0, QR2_0, dequantize_q2_0>;
831837
case GGML_TYPE_Q4_0:
832838
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
833839
case GGML_TYPE_Q4_1:
@@ -851,6 +857,8 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
851857
return convert_unary_cuda<float, nv_bfloat16>;
852858
case GGML_TYPE_Q1_0:
853859
return dequantize_block_cuda<QK1_0, QR1_0, dequantize_q1_0>;
860+
case GGML_TYPE_Q2_0:
861+
return dequantize_block_cuda<QK2_0, QR2_0, dequantize_q2_0>;
854862
case GGML_TYPE_Q4_0:
855863
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
856864
case GGML_TYPE_Q4_1:
@@ -874,6 +882,8 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
874882
return convert_unary_cuda<half, float>;
875883
case GGML_TYPE_Q1_0:
876884
return dequantize_block_cuda<QK1_0, QR1_0, dequantize_q1_0>;
885+
case GGML_TYPE_Q2_0:
886+
return dequantize_block_cuda<QK2_0, QR2_0, dequantize_q2_0>;
877887
case GGML_TYPE_Q4_0:
878888
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
879889
case GGML_TYPE_Q4_1:

ggml/src/ggml-cuda/dequantize.cuh

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,26 @@ static __device__ __forceinline__ void dequantize_q1_0(const void * vx, const in
2222
v.y = (2*bit_1 - 1) * d;
2323
}
2424

25+
// Dequantize two consecutive Q2_0 values starting at element index iqs of
// block ib. Codes are 2 bits each, four per byte, LSB-first; a stored code
// c in {0,1,2,3} decodes to (c - 1)*d, i.e. one of {-d, 0, +d, +2d}.
static __device__ __forceinline__ void dequantize_q2_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
    const block_q2_0 * x = (const block_q2_0 *) vx;

    const float d = x[ib].d;

    const int i0 = iqs;
    const int i1 = iqs + 1;

    // Extract the 2-bit code for each element: byte index = elem/4,
    // bit offset within the byte = 2*(elem%4).
    const int c0 = (x[ib].qs[i0 / 4] >> (2 * (i0 % 4))) & 0x3;
    const int c1 = (x[ib].qs[i1 / 4] >> (2 * (i1 % 4))) & 0x3;

    v.x = d * (c0 - 1);
    v.y = d * (c1 - 1);
}
44+
2545
static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
2646
const block_q4_0 * x = (const block_q4_0 *) vx;
2747

ggml/src/ggml-cuda/getrows.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,10 @@ static void ggml_cuda_get_rows_switch_src0_type(
183183
get_rows_cuda_q<QK1_0, QR1_0, dequantize_q1_0>(src0_d, src1_d, dst_d,
184184
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
185185
break;
186+
case GGML_TYPE_Q2_0:
187+
get_rows_cuda_q<QK2_0, QR2_0, dequantize_q2_0>(src0_d, src1_d, dst_d,
188+
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
189+
break;
186190
case GGML_TYPE_Q4_0:
187191
get_rows_cuda_q<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_d, dst_d,
188192
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4841,6 +4841,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
48414841
case GGML_TYPE_F32:
48424842
case GGML_TYPE_F16:
48434843
case GGML_TYPE_Q1_0:
4844+
case GGML_TYPE_Q2_0:
48444845
case GGML_TYPE_Q4_0:
48454846
case GGML_TYPE_Q4_1:
48464847
case GGML_TYPE_Q5_0:
@@ -4879,6 +4880,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
48794880
case GGML_TYPE_BF16:
48804881
case GGML_TYPE_I32:
48814882
case GGML_TYPE_Q1_0:
4883+
case GGML_TYPE_Q2_0:
48824884
case GGML_TYPE_Q4_0:
48834885
case GGML_TYPE_Q4_1:
48844886
case GGML_TYPE_Q5_0:

ggml/src/ggml-cuda/mmq.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, con
88
case GGML_TYPE_Q1_0:
99
mul_mat_q_case<GGML_TYPE_Q1_0>(ctx, args, stream);
1010
break;
11+
case GGML_TYPE_Q2_0:
12+
mul_mat_q_case<GGML_TYPE_Q2_0>(ctx, args, stream);
13+
break;
1114
case GGML_TYPE_Q4_0:
1215
mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
1316
break;
@@ -274,6 +277,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
274277

275278
switch (type) {
276279
case GGML_TYPE_Q1_0:
280+
case GGML_TYPE_Q2_0:
277281
case GGML_TYPE_Q4_0:
278282
case GGML_TYPE_Q4_1:
279283
case GGML_TYPE_Q5_0:

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected b
5858
static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
5959
switch (type_x) {
6060
case GGML_TYPE_Q1_0:
61+
case GGML_TYPE_Q2_0:
6162
return MMQ_Q8_1_DS_LAYOUT_D4;
6263
case GGML_TYPE_Q4_0:
6364
case GGML_TYPE_Q4_1:
@@ -188,6 +189,7 @@ static constexpr __device__ int get_mmq_y_device() {
188189
static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
189190
switch (type) {
190191
case GGML_TYPE_Q1_0: return MMQ_DP4A_TXS_Q8_0;
192+
case GGML_TYPE_Q2_0: return MMQ_DP4A_TXS_Q8_0;
191193
case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0;
192194
case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
193195
case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
@@ -233,6 +235,7 @@ static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding.");
233235
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
234236
switch (type) {
235237
case GGML_TYPE_Q1_0: return MMQ_MMA_TILE_X_K_Q8_0;
238+
case GGML_TYPE_Q2_0: return MMQ_MMA_TILE_X_K_Q8_0;
236239
case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0;
237240
case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1;
238241
case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
@@ -387,6 +390,101 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
387390
}
388391
}
389392

393+
// Load a tile of Q2_0 data into shared memory for the MMQ kernels.
// The 2-bit codes are expanded in-register to signed int8 values in
// {-1, 0, 1, 2} and stored in the Q8_0 tile layout (quants in x_qs, one
// float scale per 32-element chunk in x_df), so the generic q8_0 x q8_1
// dot-product kernels (mma and dp4a) can be reused unchanged.
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q2_0(
    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
    constexpr int nwarps = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    int   * x_qs = (int   *)  x_tile;
    float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
#else
    // Use this block's own type for the tile sizes (was GGML_TYPE_Q8_0; the
    // values are identical since Q2_0 maps to MMQ_DP4A_TXS_Q8_0, but using
    // the block's type is consistent with the other load_tiles_* functions
    // and robust against future changes to the Q8_0 mapping).
    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_0, mmq_y);
    int   * x_qs = (int   *)  x_tile;
    float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

    constexpr int blocks_per_iter = MMQ_ITER_K / QK2_0;
    constexpr int threads_per_row = blocks_per_iter * QI2_0;
    constexpr int nrows = warp_size / threads_per_row;               // assumes threads_per_row divides warp_size
    constexpr int scale_entries_per_block = QK2_0 / QK8_1;
    constexpr int scale_entries_per_row   = blocks_per_iter * scale_entries_per_block;

    const int txi  = threadIdx.x % threads_per_row;
    const int kbx  = txi / QI2_0;  // which Q2_0 block within the row slice
    const int kqsx = txi % QI2_0;  // which 32-element chunk within the block

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q2_0 * bxi = (const block_q2_0 *) x + kbx0 + i*stride + kbx;

        // Each 32-element chunk occupies 8 bytes of qs (32 elements * 2 bits = 64 bits);
        // assemble them into two little-endian 32-bit words.
        const int qs_offset = 8*kqsx;
        int qs_words[2];
#pragma unroll
        for (int w = 0; w < 2; ++w) {
            const int b = qs_offset + 4*w;
            qs_words[w] = bxi->qs[b + 0] | (bxi->qs[b + 1] << 8) |
                (bxi->qs[b + 2] << 16) | (bxi->qs[b + 3] << 24);
        }

        // Unpack 32 2-bit codes into 8 int32s, each holding 4 signed int8s in {-1,0,1,2}.
        // A stored code c in {0,1,2,3} maps to the symbol s = c - 1.
        int unpacked_bytes[8];
#pragma unroll
        for (int j = 0; j < 8; ++j) {
            const int shift = (j % 4) * 8;
            const int codes = (qs_words[j/4] >> shift) & 0xFF;
            const int c0 = ((codes >> 0) & 0x3) - 1;
            const int c1 = ((codes >> 2) & 0x3) - 1;
            const int c2 = ((codes >> 4) & 0x3) - 1;
            const int c3 = ((codes >> 6) & 0x3) - 1;
            unpacked_bytes[j] = (c0 & 0xFF) | ((c1 & 0xFF) << 8) | ((c2 & 0xFF) << 16) | ((c3 & 0xFF) << 24);
        }

        const int dst_offset = kbx*(scale_entries_per_block*QI8_0) + kqsx*QI8_0;
#pragma unroll
        for (int j = 0; j < 8; ++j) {
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + dst_offset + j] = unpacked_bytes[j];
#else
            x_qs[i*(2*MMQ_TILE_NE_K + 1) + dst_offset + j] = unpacked_bytes[j];
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        }
    }

    // Scales: one float per QK8_1-sized chunk; all chunks of a Q2_0 block share
    // the block scale d. Threads with threadIdx.x >= scale_entries_per_row
    // redundantly rewrite the same entries (benign).
    // NOTE(review): assumes warp_size >= scale_entries_per_row so every scale
    // entry is covered — TODO confirm for all MMQ_ITER_K/QK2_0 configurations.
    const int ksx = threadIdx.x % scale_entries_per_row;
    const int scale_block = ksx / scale_entries_per_block;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q2_0 * bxi = (const block_q2_0 *) x + kbx0 + i*stride + scale_block;

#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + ksx] = bxi->d;
#else
        x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + ksx] = bxi->d;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    }
}
487+
390488
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
391489
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
392490
constexpr int nwarps = mmq_get_nwarps_device();
@@ -3383,6 +3481,14 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q1_0> {
33833481
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
33843482
};
33853483

3484+
// MMQ kernel selection for Q2_0: load_tiles_q2_0 expands the 2-bit codes to
// the Q8_0 tile layout, so the generic q8_0 x q8_1 dot products are reused
// (D4 scale layout, matching mmq_get_q8_1_ds_layout for this type).
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_0> {
    static constexpr int              vdr          = VDR_Q2_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q2_0<mmq_y, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
3491+
33863492
template <int mmq_x, int mmq_y, bool need_check>
33873493
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
33883494
static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_
1010
static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
1111
switch (type) {
1212
case GGML_TYPE_Q1_0: return vec_dot_q1_0_q8_1;
13+
case GGML_TYPE_Q2_0: return vec_dot_q2_0_q8_1;
1314
case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1;
1415
case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1;
1516
case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1;
@@ -38,6 +39,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
3839
static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) {
3940
switch (type) {
4041
case GGML_TYPE_Q1_0: return VDR_Q1_0_Q8_1_MMVQ;
42+
case GGML_TYPE_Q2_0: return VDR_Q2_0_Q8_1_MMVQ;
4143
case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ;
4244
case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ;
4345
case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ;
@@ -894,6 +896,12 @@ static void mul_mat_vec_q_switch_type(
894896
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
895897
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
896898
break;
899+
case GGML_TYPE_Q2_0:
900+
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_0>
901+
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
902+
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
903+
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
904+
break;
897905
case GGML_TYPE_Q4_0:
898906
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
899907
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,

ggml/src/ggml-cuda/template-instances/generate_cu_files.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n"
3333

3434
TYPES_MMQ = [
35-
"GGML_TYPE_Q1_0",
35+
"GGML_TYPE_Q1_0", "GGML_TYPE_Q2_0",
3636
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
3737
"GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
3838
"GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",

0 commit comments

Comments
 (0)