@@ -58,6 +58,7 @@ static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected b
 static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
     switch (type_x) {
         case GGML_TYPE_Q1_0:
+        case GGML_TYPE_Q2_0:
             return MMQ_Q8_1_DS_LAYOUT_D4;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
@@ -188,6 +189,7 @@ static constexpr __device__ int get_mmq_y_device() {
 static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
     switch (type) {
         case GGML_TYPE_Q1_0: return MMQ_DP4A_TXS_Q8_0;
+        case GGML_TYPE_Q2_0: return MMQ_DP4A_TXS_Q8_0;
         case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0;
         case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
         case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
@@ -233,6 +235,7 @@ static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding.");
 static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q1_0: return MMQ_MMA_TILE_X_K_Q8_0;
+        case GGML_TYPE_Q2_0: return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
@@ -387,6 +390,101 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     }
 }
 
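+// Load a tile of q2_0 data into shared memory, expanding the packed 2-bit
+// codes to int8 so the existing q8_0 MMA/dp4a vec-dot kernels can be reused.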
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q2_0(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
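+    // Tile pointers: the MMA paths interleave values and scales per row with
+    // stride MMQ_MMA_TILE_X_K_Q8_0; the dp4a fallback uses the flat q8_0 tile layout.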
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+    int   * x_qs = (int   *) x_tile;
+    float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
+    int   * x_qs = (int   *) x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
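+    // Work partitioning: each row covers blocks_per_iter q2_0 blocks per
+    // MMQ_ITER_K columns; each thread owns one 32-element chunk of one block,
+    // and each block carries QK2_0/QK8_1 scale slots in the q8_0-style layout.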
+    constexpr int blocks_per_iter = MMQ_ITER_K / QK2_0;
+    constexpr int threads_per_row = blocks_per_iter * QI2_0;
+    constexpr int nrows = warp_size / threads_per_row;
+    constexpr int scale_entries_per_block = QK2_0 / QK8_1;
+    constexpr int scale_entries_per_row = blocks_per_iter * scale_entries_per_block;
+
+    const int txi  = threadIdx.x % threads_per_row;
+    const int kbx  = txi / QI2_0;
+    const int kqsx = txi % QI2_0;
+
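+    // First pass: each thread reads 8 bytes of packed codes from its block and
+    // writes them to the tile as 8 int32s (32 expanded int8 values).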
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q2_0 * bxi = (const block_q2_0 *) x + kbx0 + i*stride + kbx;
+        // Each 32-element chunk occupies 8 bytes of qs (32 elements * 2 bits = 64 bits).
+        const int qs_offset = 8*kqsx;
+        const int qs0 = bxi->qs[qs_offset + 0] | (bxi->qs[qs_offset + 1] <<  8) |
+                        (bxi->qs[qs_offset + 2] << 16) | (bxi->qs[qs_offset + 3] << 24);
+        const int qs1 = bxi->qs[qs_offset + 4] | (bxi->qs[qs_offset + 5] <<  8) |
+                        (bxi->qs[qs_offset + 6] << 16) | (bxi->qs[qs_offset + 7] << 24);
+
+        // Unpack 32 2-bit codes into 8 int32s, each holding 4 signed int8s in {-1, 0, 1, 2}.
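+        // Example: codes = 0xE4 (0b11'10'01'00) decodes to (c0,c1,c2,c3) = (-1,0,1,2),
+        // which packs little-endian into the int32 0x020100FF.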
+        int unpacked_bytes[8];
+#pragma unroll
+        for (int j = 0; j < 4; ++j) {
+            const int shift = j * 8;
+            const int codes = (qs0 >> shift) & 0xFF;
+            const int c0 = ((codes >> 0) & 0x3) - 1;
+            const int c1 = ((codes >> 2) & 0x3) - 1;
+            const int c2 = ((codes >> 4) & 0x3) - 1;
+            const int c3 = ((codes >> 6) & 0x3) - 1;
+            unpacked_bytes[j] = (c0 & 0xFF) | ((c1 & 0xFF) << 8) | ((c2 & 0xFF) << 16) | ((c3 & 0xFF) << 24);
+        }
+#pragma unroll
+        for (int j = 0; j < 4; ++j) {
+            const int shift = j * 8;
+            const int codes = (qs1 >> shift) & 0xFF;
+            const int c0 = ((codes >> 0) & 0x3) - 1;
+            const int c1 = ((codes >> 2) & 0x3) - 1;
+            const int c2 = ((codes >> 4) & 0x3) - 1;
+            const int c3 = ((codes >> 6) & 0x3) - 1;
+            unpacked_bytes[4 + j] = (c0 & 0xFF) | ((c1 & 0xFF) << 8) | ((c2 & 0xFF) << 16) | ((c3 & 0xFF) << 24);
+        }
+
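+        // Each q2_0 block expands to scale_entries_per_block*QI8_0 ints in the
+        // tile, so the destination offset is measured in expanded q8_0 units.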
+        const int dst_offset = kbx*(scale_entries_per_block*QI8_0) + kqsx*QI8_0;
+#pragma unroll
+        for (int j = 0; j < 8; ++j) {
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + dst_offset + j] = unpacked_bytes[j];
+#else
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + dst_offset + j] = unpacked_bytes[j];
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        }
+    }
+
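+    // Second pass: broadcast each block's scale d into one x_df slot per
+    // 32-value sub-block, matching the layout the q8_0 vec-dot kernels expect.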
+    const int ksx         = threadIdx.x % scale_entries_per_row;
+    const int scale_block = ksx / scale_entries_per_block;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q2_0 * bxi = (const block_q2_0 *) x + kbx0 + i*stride + scale_block;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + ksx] = bxi->d;
+#else
+        x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + ksx] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+    }
+}
+
 template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
     constexpr int nwarps = mmq_get_nwarps_device();
@@ -3383,6 +3481,14 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q1_0> {
     static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
 };
 
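+// q2_0 expands to int8 in the q8_0 tile layout at load time, so it can reuse
+// the q8_0 vec-dot kernels unchanged.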
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_0> {
+    static constexpr int vdr = VDR_Q2_0_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_0<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
 template <int mmq_x, int mmq_y, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
     static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
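
For reference, a minimal host-side sketch of the decode that `load_tiles_q2_0` performs on the fly, assuming a `block_q2_0` with one float scale `d` and 2-bit codes packed four per byte (little-endian within the byte, biased by 1, as implied by the unpacking above). `QK2_0`'s concrete value is not visible in this diff, so it is left as a template parameter; the `_ref` names are hypothetical and not part of the tree:

```cpp
#include <cstdint>

// Hypothetical reference layout mirroring block_q2_0 as used in the diff.
template <int QK2_0>
struct block_q2_0_ref {
    float   d;              // per-block scale
    uint8_t qs[QK2_0 / 4];  // 2-bit codes, 4 per byte
};

// Decode one block: code c in {0,1,2,3} maps to (c - 1)*d, i.e. {-d, 0, d, 2*d}.
template <int QK2_0>
void dequantize_block_q2_0_ref(const block_q2_0_ref<QK2_0> & b, float * dst) {
    for (int i = 0; i < QK2_0; ++i) {
        const int code = (b.qs[i/4] >> (2*(i % 4))) & 0x3;
        dst[i] = float(code - 1)*b.d;
    }
}
```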