@@ -23,7 +23,7 @@ typedef void (*mmq_write_back_t)(const float * __restrict__ sum, const int32_t *
2323
2424template <ggml_type type>
2525static constexpr __host__ __device__ bool mmq_type_has_mma () {
26- return type != GGML_TYPE_PQ2_K;
26+ return type != GGML_TYPE_PQ2_K && type != GGML_TYPE_PQ3_K && type != GGML_TYPE_PQ4_K ;
2727}
2828
2929template <ggml_type type>
@@ -111,20 +111,22 @@ static constexpr __device__ int mmq_get_granularity_device_for_type(const int mm
111111
112112template <ggml_type type>
113113static void mmq_log_selected_path_once (const int cc, const int mmq_x, const int mmq_y) {
114- if constexpr (type != GGML_TYPE_PQ2_K) {
114+ if constexpr (type != GGML_TYPE_PQ2_K && type != GGML_TYPE_PQ3_K && type != GGML_TYPE_PQ4_K ) {
115115 return ;
116116 }
117117
118- static const bool enabled = getenv (" GGML_CUDA_LOG_PQ2_K_MMQ " ) != nullptr ;
118+ static const bool enabled = getenv (" GGML_CUDA_LOG_PQK_MMQ " ) != nullptr ;
119119 static bool logged = false ;
120120
121121 if (!enabled || logged) {
122122 return ;
123123 }
124124
125125 logged = true ;
126- GGML_LOG_INFO (" %s: PQ2_K MMQ using %s path (cc=%d, mmq_x=%d, mmq_y=%d)\n " ,
127- __func__, mmq_uses_mma_host<type>(cc) ? " mma" : " dp4a" , cc, mmq_x, mmq_y);
126+ const char * type_name = type == GGML_TYPE_PQ2_K ? " PQ2_K" : (type == GGML_TYPE_PQ3_K ? " PQ3_K" : " PQ4_K" );
127+ GGML_LOG_INFO (" %s: %s MMQ using %s path (cc=%d, mmq_x=%d, mmq_y=%d)\n " ,
128+ __func__, type_name,
129+ mmq_uses_mma_host<type>(cc) ? " mma" : " dp4a" , cc, mmq_x, mmq_y);
128130}
129131
130132enum mmq_q8_1_ds_layout {
@@ -308,8 +310,8 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
308310 case GGML_TYPE_PQ3_0: return MMQ_DP4A_TXS_Q8_0;
309311 case GGML_TYPE_PQ4_0: return MMQ_DP4A_TXS_Q8_0;
310312 case GGML_TYPE_PQ2_K: return MMQ_DP4A_TXS_PQ2_K;
311- case GGML_TYPE_PQ3_K: return MMQ_DP4A_TXS_Q8_0_16 ;
312- case GGML_TYPE_PQ4_K: return MMQ_DP4A_TXS_Q8_0_16 ;
313+ case GGML_TYPE_PQ3_K: return MMQ_DP4A_TXS_PQ2_K ;
314+ case GGML_TYPE_PQ4_K: return MMQ_DP4A_TXS_PQ2_K ;
313315 case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0;
314316 case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
315317 case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
@@ -1455,6 +1457,132 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
14551457 }
14561458}
14571459
1460+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_pq3_K (
1461+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1462+ constexpr int nwarps = mmq_get_nwarps_device ();
1463+ constexpr int warp_size = ggml_cuda_get_physical_warp_size ();
1464+
1465+ // PQ3_K uses the same 8D local-scale shared-memory layout as PQ2_K.
1466+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes (GGML_TYPE_PQ3_K, mmq_y);
1467+ int * x_qs = (int *) x_tile;
1468+ half2 * x_d = (half2 *) (x_qs + txs.qs );
1469+
1470+ constexpr int threads_per_row = 32 ;
1471+ constexpr int nrows = warp_size / threads_per_row;
1472+ const int txi = warp_size > threads_per_row ? threadIdx .x % threads_per_row : threadIdx .x ;
1473+
1474+ #pragma unroll
1475+ for (int i0 = 0 ; i0 < mmq_y; i0 += nrows*nwarps) {
1476+ int i = i0 + (nrows == 1 ? threadIdx .y : threadIdx .y *nrows + threadIdx .x /threads_per_row);
1477+
1478+ if (need_check) {
1479+ i = min (i, i_max);
1480+ }
1481+
1482+ const block_pq3_K * bxi = (const block_pq3_K *) x + kbx0 + i*stride;
1483+ const int elem0 = 4 *txi;
1484+ const int elem1 = 128 + 4 *txi;
1485+ const uint8_t high0 = (bxi->hmask [elem0 >> 3 ] >> (elem0 & 7 )) & 0x0Fu ;
1486+ const uint8_t high1 = (bxi->hmask [elem1 >> 3 ] >> (elem1 & 7 )) & 0x0Fu ;
1487+ const uint8_t qb0 = bxi->qs [txi];
1488+ const uint8_t qb1 = bxi->qs [MMQ_TILE_NE_K + txi];
1489+ const int q4_0 = ((((qb0 >> 0 ) & 0x03u ) | ((high0 & 0x01u ) << 2 )) << 0 )
1490+ | ((((qb0 >> 2 ) & 0x03u ) | ((high0 & 0x02u ) << 1 )) << 4 )
1491+ | ((((qb0 >> 4 ) & 0x03u ) | (high0 & 0x04u )) << 8 )
1492+ | ((((qb0 >> 6 ) & 0x03u ) | ((high0 & 0x08u ) >> 1 )) << 12 );
1493+ const int q4_1 = ((((qb1 >> 0 ) & 0x03u ) | ((high1 & 0x01u ) << 2 )) << 0 )
1494+ | ((((qb1 >> 2 ) & 0x03u ) | ((high1 & 0x02u ) << 1 )) << 4 )
1495+ | ((((qb1 >> 4 ) & 0x03u ) | (high1 & 0x04u )) << 8 )
1496+ | ((((qb1 >> 6 ) & 0x03u ) | ((high1 & 0x08u ) >> 1 )) << 12 );
1497+ const int2 vp = get_int_from_table_16 (q4_0 | (q4_1 << 16 ), PQK_DP4A_VAL_3BIT_16);
1498+ const int qs0 = __byte_perm (vp.x , vp.y , 0x5140 );
1499+ const int qs1 = __byte_perm (vp.x , vp.y , 0x7362 );
1500+
1501+ x_qs[i*(2 *MMQ_TILE_NE_K + 1 ) + 0 + txi] = qs0;
1502+ x_qs[i*(2 *MMQ_TILE_NE_K + 1 ) + MMQ_TILE_NE_K + txi] = qs1;
1503+ }
1504+
1505+ constexpr int scale_pairs_per_row = GGML_PQ3_K_SUBBLOCK_COUNT / 2 ;
1506+ constexpr int scale_rows_per_warp = warp_size / scale_pairs_per_row;
1507+ const int ksp = threadIdx .x % scale_pairs_per_row;
1508+
1509+ #pragma unroll
1510+ for (int i0 = 0 ; i0 < mmq_y; i0 += nwarps*scale_rows_per_warp) {
1511+ int i = i0 + threadIdx .y *scale_rows_per_warp + threadIdx .x /scale_pairs_per_row;
1512+
1513+ if (need_check) {
1514+ i = min (i, i_max);
1515+ }
1516+
1517+ const block_pq3_K * bxi = (const block_pq3_K *) x + kbx0 + i*stride;
1518+ const int sub0 = 2 *ksp;
1519+ const int band = sub0 / GGML_PQ3_K_SUBBLOCKS_PER_BAND;
1520+ const float dbase = __half2float (bxi->d [band]) * PQK_DP4A_INV_SCALE_3BIT;
1521+ const uint8_t qscale_pair = bxi->scales [ksp];
1522+ const half2 d_pair = make_half2 (
1523+ dbase * PQ3_K_LOCAL_SCALE_LUT[qscale_pair & 0x0Fu ],
1524+ dbase * PQ3_K_LOCAL_SCALE_LUT[qscale_pair >> 4 ]);
1525+
1526+ x_d[i*(MMQ_TILE_NE_K/2 + 1 ) + ksp] = d_pair;
1527+ }
1528+ }
1529+
1530+ template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_pq4_K (
1531+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1532+ constexpr int nwarps = mmq_get_nwarps_device ();
1533+ constexpr int warp_size = ggml_cuda_get_physical_warp_size ();
1534+
1535+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes (GGML_TYPE_PQ4_K, mmq_y);
1536+ int * x_qs = (int *) x_tile;
1537+ half2 * x_d = (half2 *) (x_qs + txs.qs );
1538+
1539+ constexpr int threads_per_row = 32 ;
1540+ constexpr int nrows = warp_size / threads_per_row;
1541+ const int txi = warp_size > threads_per_row ? threadIdx .x % threads_per_row : threadIdx .x ;
1542+
1543+ #pragma unroll
1544+ for (int i0 = 0 ; i0 < mmq_y; i0 += nrows*nwarps) {
1545+ int i = i0 + (nrows == 1 ? threadIdx .y : threadIdx .y *nrows + threadIdx .x /threads_per_row);
1546+
1547+ if (need_check) {
1548+ i = min (i, i_max);
1549+ }
1550+
1551+ const block_pq4_K * bxi = (const block_pq4_K *) x + kbx0 + i*stride;
1552+ const uint16_t * q16 = (const uint16_t *) bxi->qs ;
1553+ const int2 vp = get_int_from_table_16 ((int ) q16[txi] | ((int ) q16[MMQ_TILE_NE_K + txi] << 16 ), PQK_DP4A_VAL_4BIT);
1554+ const int qs0 = __byte_perm (vp.x , vp.y , 0x5140 );
1555+ const int qs1 = __byte_perm (vp.x , vp.y , 0x7362 );
1556+
1557+ x_qs[i*(2 *MMQ_TILE_NE_K + 1 ) + 0 + txi] = qs0;
1558+ x_qs[i*(2 *MMQ_TILE_NE_K + 1 ) + MMQ_TILE_NE_K + txi] = qs1;
1559+ }
1560+
1561+ constexpr int scale_pairs_per_row = GGML_PQ4_K_SUBBLOCK_COUNT / 2 ;
1562+ constexpr int scale_rows_per_warp = warp_size / scale_pairs_per_row;
1563+ const int ksp = threadIdx .x % scale_pairs_per_row;
1564+
1565+ #pragma unroll
1566+ for (int i0 = 0 ; i0 < mmq_y; i0 += nwarps*scale_rows_per_warp) {
1567+ int i = i0 + threadIdx .y *scale_rows_per_warp + threadIdx .x /scale_pairs_per_row;
1568+
1569+ if (need_check) {
1570+ i = min (i, i_max);
1571+ }
1572+
1573+ const block_pq4_K * bxi = (const block_pq4_K *) x + kbx0 + i*stride;
1574+ const int sub0 = 2 *ksp;
1575+ const int band = sub0 / GGML_PQ4_K_SUBBLOCKS_PER_BAND;
1576+ const float dbase = __half2float (bxi->d [band]) * PQK_DP4A_INV_SCALE_4BIT;
1577+ const uint8_t qscale_pair = bxi->scales [ksp];
1578+ const half2 d_pair = make_half2 (
1579+ dbase * PQ4_K_LOCAL_SCALE_LUT[qscale_pair & 0x0Fu ],
1580+ dbase * PQ4_K_LOCAL_SCALE_LUT[qscale_pair >> 4 ]);
1581+
1582+ x_d[i*(MMQ_TILE_NE_K/2 + 1 ) + ksp] = d_pair;
1583+ }
1584+ }
1585+
14581586template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_mxfp4 (
14591587 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
14601588 constexpr int nwarps = mmq_get_nwarps_device ();
@@ -3946,17 +4074,17 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_PQ2_K> {
39464074template <int mmq_x, int mmq_y, bool need_check>
39474075struct mmq_type_traits <mmq_x, mmq_y, need_check, GGML_TYPE_PQ3_K> {
39484076 static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
3949- static constexpr load_tiles_mmq_t load_tiles = load_tiles_pq_K<GGML_TYPE_PQ3_K, mmq_y, need_check>;
3950- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma <mmq_x, mmq_y>;
3951- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a <mmq_x, mmq_y>;
4077+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_pq3_K< mmq_y, need_check>;
4078+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_8_q8_1_mma <mmq_x, mmq_y>;
4079+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_8_q8_1_dp4a <mmq_x, mmq_y>;
39524080};
39534081
39544082template <int mmq_x, int mmq_y, bool need_check>
39554083struct mmq_type_traits <mmq_x, mmq_y, need_check, GGML_TYPE_PQ4_K> {
39564084 static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
3957- static constexpr load_tiles_mmq_t load_tiles = load_tiles_pq_K<GGML_TYPE_PQ4_K, mmq_y, need_check>;
3958- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma <mmq_x, mmq_y>;
3959- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a <mmq_x, mmq_y>;
4085+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_pq4_K< mmq_y, need_check>;
4086+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_8_q8_1_mma <mmq_x, mmq_y>;
4087+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_8_q8_1_dp4a <mmq_x, mmq_y>;
39604088};
39614089
39624090template <int mmq_x, int mmq_y, bool need_check>
0 commit comments