Skip to content

Commit dc60b16

Browse files
committed
Add 8D local-scale and AW refinement for PQ3_K/PQ4_K
- switch PQ3_K/PQ4_K to per-8D local scales
- add greedy R2P1 imatrix-aware refinement on the quant path
- update CUDA mmq, vecdot and dequant support for the new layout
1 parent 949f427 commit dc60b16

8 files changed

Lines changed: 1386 additions & 154 deletions

File tree

ggml/src/ggml-common.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -405,19 +405,19 @@ typedef struct {
405405
static_assert(sizeof(block_pq2_K) == 2*sizeof(ggml_half) + QK_K/16 + QK_K/4, "wrong pq2_K block size/padding");
406406

407407
typedef struct {
408-
ggml_half d[2];
409-
uint8_t scales[K_SCALE_SIZE];
408+
ggml_half d[2]; // master band scales for two 128-wide bands
409+
uint8_t scales[QK_K/16]; // 4-bit local scales for each 8D sub-block
410410
uint8_t hmask[QK_K/8];
411411
uint8_t qs[QK_K/4];
412412
} block_pq3_K;
413-
static_assert(sizeof(block_pq3_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/8 + QK_K/4, "wrong pq3_K block size/padding");
413+
static_assert(sizeof(block_pq3_K) == 2*sizeof(ggml_half) + QK_K/16 + QK_K/8 + QK_K/4, "wrong pq3_K block size/padding");
414414

415415
typedef struct {
416-
ggml_half d[2];
417-
uint8_t scales[K_SCALE_SIZE];
416+
ggml_half d[2]; // master band scales for two 128-wide bands
417+
uint8_t scales[QK_K/16]; // 4-bit local scales for each 8D sub-block
418418
uint8_t qs[QK_K/2];
419419
} block_pq4_K;
420-
static_assert(sizeof(block_pq4_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2, "wrong pq4_K block size/padding");
420+
static_assert(sizeof(block_pq4_K) == 2*sizeof(ggml_half) + QK_K/16 + QK_K/2, "wrong pq4_K block size/padding");
421421

422422
// 5-bit quantization
423423
// 8 blocks of 32 elements each

ggml/src/ggml-cuda/common.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -995,14 +995,14 @@ template<>
995995
struct ggml_cuda_type_traits<GGML_TYPE_PQ3_K> {
996996
static constexpr int qk = QK_K;
997997
static constexpr int qr = 1;
998-
static constexpr int qi = 16;
998+
static constexpr int qi = GGML_PQ3_K_SUBBLOCK_COUNT;
999999
};
10001000

10011001
template<>
10021002
struct ggml_cuda_type_traits<GGML_TYPE_PQ4_K> {
10031003
static constexpr int qk = QK_K;
10041004
static constexpr int qr = 1;
1005-
static constexpr int qi = 16;
1005+
static constexpr int qi = GGML_PQ4_K_SUBBLOCK_COUNT;
10061006
};
10071007

10081008
template<>

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 141 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ typedef void (*mmq_write_back_t)(const float * __restrict__ sum, const int32_t *
2323

2424
template <ggml_type type>
2525
static constexpr __host__ __device__ bool mmq_type_has_mma() {
26-
return type != GGML_TYPE_PQ2_K;
26+
return type != GGML_TYPE_PQ2_K && type != GGML_TYPE_PQ3_K && type != GGML_TYPE_PQ4_K;
2727
}
2828

2929
template <ggml_type type>
@@ -111,20 +111,22 @@ static constexpr __device__ int mmq_get_granularity_device_for_type(const int mm
111111

112112
template <ggml_type type>
113113
static void mmq_log_selected_path_once(const int cc, const int mmq_x, const int mmq_y) {
114-
if constexpr (type != GGML_TYPE_PQ2_K) {
114+
if constexpr (type != GGML_TYPE_PQ2_K && type != GGML_TYPE_PQ3_K && type != GGML_TYPE_PQ4_K) {
115115
return;
116116
}
117117

118-
static const bool enabled = getenv("GGML_CUDA_LOG_PQ2_K_MMQ") != nullptr;
118+
static const bool enabled = getenv("GGML_CUDA_LOG_PQK_MMQ") != nullptr;
119119
static bool logged = false;
120120

121121
if (!enabled || logged) {
122122
return;
123123
}
124124

125125
logged = true;
126-
GGML_LOG_INFO("%s: PQ2_K MMQ using %s path (cc=%d, mmq_x=%d, mmq_y=%d)\n",
127-
__func__, mmq_uses_mma_host<type>(cc) ? "mma" : "dp4a", cc, mmq_x, mmq_y);
126+
const char * type_name = type == GGML_TYPE_PQ2_K ? "PQ2_K" : (type == GGML_TYPE_PQ3_K ? "PQ3_K" : "PQ4_K");
127+
GGML_LOG_INFO("%s: %s MMQ using %s path (cc=%d, mmq_x=%d, mmq_y=%d)\n",
128+
__func__, type_name,
129+
mmq_uses_mma_host<type>(cc) ? "mma" : "dp4a", cc, mmq_x, mmq_y);
128130
}
129131

130132
enum mmq_q8_1_ds_layout {
@@ -308,8 +310,8 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
308310
case GGML_TYPE_PQ3_0: return MMQ_DP4A_TXS_Q8_0;
309311
case GGML_TYPE_PQ4_0: return MMQ_DP4A_TXS_Q8_0;
310312
case GGML_TYPE_PQ2_K: return MMQ_DP4A_TXS_PQ2_K;
311-
case GGML_TYPE_PQ3_K: return MMQ_DP4A_TXS_Q8_0_16;
312-
case GGML_TYPE_PQ4_K: return MMQ_DP4A_TXS_Q8_0_16;
313+
case GGML_TYPE_PQ3_K: return MMQ_DP4A_TXS_PQ2_K;
314+
case GGML_TYPE_PQ4_K: return MMQ_DP4A_TXS_PQ2_K;
313315
case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0;
314316
case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
315317
case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
@@ -1455,6 +1457,132 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
14551457
}
14561458
}
14571459

1460+
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_pq3_K(
1461+
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1462+
constexpr int nwarps = mmq_get_nwarps_device();
1463+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1464+
1465+
// PQ3_K uses the same 8D local-scale shared-memory layout as PQ2_K.
1466+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_PQ3_K, mmq_y);
1467+
int * x_qs = (int *) x_tile;
1468+
half2 * x_d = (half2 *) (x_qs + txs.qs);
1469+
1470+
constexpr int threads_per_row = 32;
1471+
constexpr int nrows = warp_size / threads_per_row;
1472+
const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
1473+
1474+
#pragma unroll
1475+
for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
1476+
int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
1477+
1478+
if (need_check) {
1479+
i = min(i, i_max);
1480+
}
1481+
1482+
const block_pq3_K * bxi = (const block_pq3_K *) x + kbx0 + i*stride;
1483+
const int elem0 = 4*txi;
1484+
const int elem1 = 128 + 4*txi;
1485+
const uint8_t high0 = (bxi->hmask[elem0 >> 3] >> (elem0 & 7)) & 0x0Fu;
1486+
const uint8_t high1 = (bxi->hmask[elem1 >> 3] >> (elem1 & 7)) & 0x0Fu;
1487+
const uint8_t qb0 = bxi->qs[txi];
1488+
const uint8_t qb1 = bxi->qs[MMQ_TILE_NE_K + txi];
1489+
const int q4_0 = ((((qb0 >> 0) & 0x03u) | ((high0 & 0x01u) << 2)) << 0)
1490+
| ((((qb0 >> 2) & 0x03u) | ((high0 & 0x02u) << 1)) << 4)
1491+
| ((((qb0 >> 4) & 0x03u) | (high0 & 0x04u)) << 8)
1492+
| ((((qb0 >> 6) & 0x03u) | ((high0 & 0x08u) >> 1)) << 12);
1493+
const int q4_1 = ((((qb1 >> 0) & 0x03u) | ((high1 & 0x01u) << 2)) << 0)
1494+
| ((((qb1 >> 2) & 0x03u) | ((high1 & 0x02u) << 1)) << 4)
1495+
| ((((qb1 >> 4) & 0x03u) | (high1 & 0x04u)) << 8)
1496+
| ((((qb1 >> 6) & 0x03u) | ((high1 & 0x08u) >> 1)) << 12);
1497+
const int2 vp = get_int_from_table_16(q4_0 | (q4_1 << 16), PQK_DP4A_VAL_3BIT_16);
1498+
const int qs0 = __byte_perm(vp.x, vp.y, 0x5140);
1499+
const int qs1 = __byte_perm(vp.x, vp.y, 0x7362);
1500+
1501+
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = qs0;
1502+
x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = qs1;
1503+
}
1504+
1505+
constexpr int scale_pairs_per_row = GGML_PQ3_K_SUBBLOCK_COUNT / 2;
1506+
constexpr int scale_rows_per_warp = warp_size / scale_pairs_per_row;
1507+
const int ksp = threadIdx.x % scale_pairs_per_row;
1508+
1509+
#pragma unroll
1510+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*scale_rows_per_warp) {
1511+
int i = i0 + threadIdx.y*scale_rows_per_warp + threadIdx.x/scale_pairs_per_row;
1512+
1513+
if (need_check) {
1514+
i = min(i, i_max);
1515+
}
1516+
1517+
const block_pq3_K * bxi = (const block_pq3_K *) x + kbx0 + i*stride;
1518+
const int sub0 = 2*ksp;
1519+
const int band = sub0 / GGML_PQ3_K_SUBBLOCKS_PER_BAND;
1520+
const float dbase = __half2float(bxi->d[band]) * PQK_DP4A_INV_SCALE_3BIT;
1521+
const uint8_t qscale_pair = bxi->scales[ksp];
1522+
const half2 d_pair = make_half2(
1523+
dbase * PQ3_K_LOCAL_SCALE_LUT[qscale_pair & 0x0Fu],
1524+
dbase * PQ3_K_LOCAL_SCALE_LUT[qscale_pair >> 4]);
1525+
1526+
x_d[i*(MMQ_TILE_NE_K/2 + 1) + ksp] = d_pair;
1527+
}
1528+
}
1529+
1530+
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_pq4_K(
1531+
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1532+
constexpr int nwarps = mmq_get_nwarps_device();
1533+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1534+
1535+
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_PQ4_K, mmq_y);
1536+
int * x_qs = (int *) x_tile;
1537+
half2 * x_d = (half2 *) (x_qs + txs.qs);
1538+
1539+
constexpr int threads_per_row = 32;
1540+
constexpr int nrows = warp_size / threads_per_row;
1541+
const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
1542+
1543+
#pragma unroll
1544+
for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
1545+
int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
1546+
1547+
if (need_check) {
1548+
i = min(i, i_max);
1549+
}
1550+
1551+
const block_pq4_K * bxi = (const block_pq4_K *) x + kbx0 + i*stride;
1552+
const uint16_t * q16 = (const uint16_t *) bxi->qs;
1553+
const int2 vp = get_int_from_table_16((int) q16[txi] | ((int) q16[MMQ_TILE_NE_K + txi] << 16), PQK_DP4A_VAL_4BIT);
1554+
const int qs0 = __byte_perm(vp.x, vp.y, 0x5140);
1555+
const int qs1 = __byte_perm(vp.x, vp.y, 0x7362);
1556+
1557+
x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = qs0;
1558+
x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = qs1;
1559+
}
1560+
1561+
constexpr int scale_pairs_per_row = GGML_PQ4_K_SUBBLOCK_COUNT / 2;
1562+
constexpr int scale_rows_per_warp = warp_size / scale_pairs_per_row;
1563+
const int ksp = threadIdx.x % scale_pairs_per_row;
1564+
1565+
#pragma unroll
1566+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*scale_rows_per_warp) {
1567+
int i = i0 + threadIdx.y*scale_rows_per_warp + threadIdx.x/scale_pairs_per_row;
1568+
1569+
if (need_check) {
1570+
i = min(i, i_max);
1571+
}
1572+
1573+
const block_pq4_K * bxi = (const block_pq4_K *) x + kbx0 + i*stride;
1574+
const int sub0 = 2*ksp;
1575+
const int band = sub0 / GGML_PQ4_K_SUBBLOCKS_PER_BAND;
1576+
const float dbase = __half2float(bxi->d[band]) * PQK_DP4A_INV_SCALE_4BIT;
1577+
const uint8_t qscale_pair = bxi->scales[ksp];
1578+
const half2 d_pair = make_half2(
1579+
dbase * PQ4_K_LOCAL_SCALE_LUT[qscale_pair & 0x0Fu],
1580+
dbase * PQ4_K_LOCAL_SCALE_LUT[qscale_pair >> 4]);
1581+
1582+
x_d[i*(MMQ_TILE_NE_K/2 + 1) + ksp] = d_pair;
1583+
}
1584+
}
1585+
14581586
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_mxfp4(
14591587
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
14601588
constexpr int nwarps = mmq_get_nwarps_device();
@@ -3946,17 +4074,17 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_PQ2_K> {
39464074
template <int mmq_x, int mmq_y, bool need_check>
39474075
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_PQ3_K> {
39484076
static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
3949-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_pq_K<GGML_TYPE_PQ3_K, mmq_y, need_check>;
3950-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
3951-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
4077+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_pq3_K<mmq_y, need_check>;
4078+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_8_q8_1_mma<mmq_x, mmq_y>;
4079+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_8_q8_1_dp4a<mmq_x, mmq_y>;
39524080
};
39534081

39544082
template <int mmq_x, int mmq_y, bool need_check>
39554083
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_PQ4_K> {
39564084
static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
3957-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_pq_K<GGML_TYPE_PQ4_K, mmq_y, need_check>;
3958-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
3959-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
4085+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_pq4_K<mmq_y, need_check>;
4086+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_8_q8_1_mma<mmq_x, mmq_y>;
4087+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_8_q8_1_dp4a<mmq_x, mmq_y>;
39604088
};
39614089

39624090
template <int mmq_x, int mmq_y, bool need_check>

ggml/src/ggml-cuda/pq-tq-common.cuh

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,20 @@ __constant__ static const float PQ2_K_LOCAL_SCALE_LUT[16] = {
309309
0.5520447568f, 0.6729500963f, 0.820335356f, 1.0f
310310
};
311311

312+
__constant__ static const float PQ3_K_LOCAL_SCALE_LUT[16] = {
313+
0.0f, 0.0625f, 0.07618835339f, 0.09287464307f,
314+
0.113215458f, 0.1380111892f, 0.1682375241f, 0.205083839f,
315+
0.25f, 0.3047534136f, 0.3714985723f, 0.4528618321f,
316+
0.5520447568f, 0.6729500963f, 0.820335356f, 1.0f
317+
};
318+
319+
__constant__ static const float PQ4_K_LOCAL_SCALE_LUT[16] = {
320+
0.0f, 0.0625f, 0.07618835339f, 0.09287464307f,
321+
0.113215458f, 0.1380111892f, 0.1682375241f, 0.205083839f,
322+
0.25f, 0.3047534136f, 0.3714985723f, 0.4528618321f,
323+
0.5520447568f, 0.6729500963f, 0.820335356f, 1.0f
324+
};
325+
312326
#undef PQK_DP4A_VAL4_ENTRY
313327
#undef PQK_DP4A_VAL3_ENTRY
314328
#undef PQK_DP4A_VAL2_ENTRY

ggml/src/ggml-cuda/pq-tq-dequant-wht.cuh

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,18 @@ static __device__ __forceinline__ float pq2_k_dequant_scale(const block_pq2_K *
7777
return ggml_pq2_k_decode_local_scale(master, ggml_pq2_k_scale_get(x[ib].scales, subblock));
7878
}
7979

80+
static __device__ __forceinline__ float pq3_k_dequant_scale(const block_pq3_K * x, const int ib, const int subblock) {
81+
const int band = subblock / GGML_PQ3_K_SUBBLOCKS_PER_BAND;
82+
const float master = __half2float(x[ib].d[band]);
83+
return ggml_pq3_k_decode_local_scale(master, ggml_pq3_k_scale_get(x[ib].scales, subblock));
84+
}
85+
86+
static __device__ __forceinline__ float pq4_k_dequant_scale(const block_pq4_K * x, const int ib, const int subblock) {
87+
const int band = subblock / GGML_PQ4_K_SUBBLOCKS_PER_BAND;
88+
const float master = __half2float(x[ib].d[band]);
89+
return ggml_pq4_k_decode_local_scale(master, ggml_pq4_k_scale_get(x[ib].scales, subblock));
90+
}
91+
8092
static __device__ __forceinline__ float pq_dequant_elem_2_k(const void * vx, int64_t global_elem) {
8193
const block_pq2_K * x = (const block_pq2_K *) vx;
8294
const int ib = global_elem / QK_K;
@@ -91,8 +103,8 @@ static __device__ __forceinline__ float pq_dequant_elem_3_k(const void * vx, int
91103
const block_pq3_K * x = (const block_pq3_K *) vx;
92104
const int ib = global_elem / QK_K;
93105
const int il = global_elem % QK_K;
94-
const int subblock = il / GGML_PQK_SUBBLOCK_SIZE;
95-
const float scale = pqk_dequant_scale(x, ib, subblock);
106+
const int subblock = il / GGML_PQ3_K_SUBBLOCK_SIZE;
107+
const float scale = pq3_k_dequant_scale(x, ib, subblock);
96108
const uint8_t ql = (x[ib].qs[il / 4] >> (2 * (il & 3))) & 0x3u;
97109
const uint8_t qh = (x[ib].hmask[il / 8] >> (il & 7)) & 0x1u;
98110
return ggml_pqk_centroid_3bit((uint8_t)(ql | (qh << 2))) * scale;
@@ -102,7 +114,7 @@ static __device__ __forceinline__ float pq_dequant_elem_4_k(const void * vx, int
102114
const block_pq4_K * x = (const block_pq4_K *) vx;
103115
const int ib = global_elem / QK_K;
104116
const int il = global_elem % QK_K;
105-
const float scale = pqk_dequant_scale(x, ib, il / GGML_PQK_SUBBLOCK_SIZE);
117+
const float scale = pq4_k_dequant_scale(x, ib, il / GGML_PQ4_K_SUBBLOCK_SIZE);
106118
const uint8_t q = (x[ib].qs[il / 2] >> (4 * (il & 1))) & 0xFu;
107119
return ggml_pqk_centroid_4bit(q) * scale;
108120
}
@@ -306,8 +318,8 @@ __device__ __forceinline__ float2 pq_tq_dequant_pair<PqTqTypeTag::P3_K>(const vo
306318
const int ib = global_pair / pairs_per_block;
307319
const int ip = global_pair % pairs_per_block;
308320
const int il = 2 * ip;
309-
const int subblock = il / GGML_PQK_SUBBLOCK_SIZE;
310-
const float scale = pqk_dequant_scale(x, ib, subblock);
321+
const int subblock = il / GGML_PQ3_K_SUBBLOCK_SIZE;
322+
const float scale = pq3_k_dequant_scale(x, ib, subblock);
311323
const uint8_t qb = x[ib].qs[il / 4];
312324
const int shift = 2 * (il & 3);
313325
const uint8_t q0 = ((qb >> shift) & 0x3u) | (((x[ib].hmask[il / 8] >> (il & 7)) & 0x1u) << 2);
@@ -323,7 +335,7 @@ __device__ __forceinline__ float2 pq_tq_dequant_pair<PqTqTypeTag::P4_K>(const vo
323335
const int ib = global_pair / pairs_per_block;
324336
const int ip = global_pair % pairs_per_block;
325337
const int il = 2 * ip;
326-
const float scale = pqk_dequant_scale(x, ib, il / GGML_PQK_SUBBLOCK_SIZE);
338+
const float scale = pq4_k_dequant_scale(x, ib, il / GGML_PQ4_K_SUBBLOCK_SIZE);
327339
const uint8_t qb = x[ib].qs[ip];
328340
return make_float2(ggml_pqk_centroid_4bit(qb & 0xFu) * scale, ggml_pqk_centroid_4bit(qb >> 4) * scale);
329341
}

ggml/src/ggml-cuda/vecdotq.cuh

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -973,17 +973,17 @@ static __device__ __forceinline__ float vec_dot_pq3_K_q8_1(
973973
const block_pq3_K * bq = (const block_pq3_K *) vbq + kbx;
974974

975975
const int sub = iqs;
976-
const int q8_block = sub >> 1;
977-
const int q8_i32 = (sub & 1) * 4;
978-
const int band = sub / PQK_MMVQ_SUBBLOCKS_PER_BAND;
979-
const int elem_base = sub * PQK_MMVQ_SUBBLOCK_SIZE;
976+
const int q8_block = sub >> 2;
977+
const int q8_i32 = (sub & 3) * 2;
978+
const int band = sub / GGML_PQ3_K_SUBBLOCKS_PER_BAND;
979+
const int elem_base = sub * GGML_PQ3_K_SUBBLOCK_SIZE;
980980

981981
int sumi = 0;
982982
#pragma unroll
983-
for (int i = 0; i < 4; ++i) {
983+
for (int i = 0; i < 2; ++i) {
984984
const int elem = elem_base + 4*i;
985985
const uint8_t high = (bq->hmask[elem >> 3] >> (elem & 7)) & 0x0Fu;
986-
const uint8_t qb = bq->qs[4*sub + i];
986+
const uint8_t qb = bq->qs[2*sub + i];
987987
const int q4 = ((((qb >> 0) & 0x03u) | ((high & 0x01u) << 2)) << 0)
988988
| ((((qb >> 2) & 0x03u) | ((high & 0x02u) << 1)) << 4)
989989
| ((((qb >> 4) & 0x03u) | (high & 0x04u)) << 8)
@@ -994,8 +994,8 @@ static __device__ __forceinline__ float vec_dot_pq3_K_q8_1(
994994
sumi = ggml_cuda_dp4a(v, u, sumi);
995995
}
996996

997-
const uint8_t qscale = pqk_vec_scale_get(bq->scales, sub);
998-
const float d = __half2float(bq->d[band]) * PQK_LOCAL_SCALE_LUT[qscale] * PQK_DP4A_INV_SCALE_3BIT;
997+
const uint8_t qscale = ggml_pq3_k_scale_get(bq->scales, sub);
998+
const float d = __half2float(bq->d[band]) * PQ3_K_LOCAL_SCALE_LUT[qscale] * PQK_DP4A_INV_SCALE_3BIT;
999999
return d * __low2float(bq8_1[q8_block].ds) * sumi;
10001000
}
10011001

@@ -1005,22 +1005,22 @@ static __device__ __forceinline__ float vec_dot_pq4_K_q8_1(
10051005
const block_pq4_K * bq = (const block_pq4_K *) vbq + kbx;
10061006

10071007
const int sub = iqs;
1008-
const int q8_block = sub >> 1;
1009-
const int q8_i32 = (sub & 1) * 4;
1010-
const int band = sub / PQK_MMVQ_SUBBLOCKS_PER_BAND;
1008+
const int q8_block = sub >> 2;
1009+
const int q8_i32 = (sub & 3) * 2;
1010+
const int band = sub / GGML_PQ4_K_SUBBLOCKS_PER_BAND;
10111011
const uint16_t * q16 = (const uint16_t *) bq->qs;
10121012

10131013
int sumi = 0;
10141014
#pragma unroll
1015-
for (int i = 0; i < 4; ++i) {
1016-
const int2 vp = get_int_from_table_16((int) q16[4*sub + i], PQK_DP4A_VAL_4BIT);
1015+
for (int i = 0; i < 2; ++i) {
1016+
const int2 vp = get_int_from_table_16((int) q16[2*sub + i], PQK_DP4A_VAL_4BIT);
10171017
const int v = __byte_perm(vp.x, vp.y, 0x5140);
10181018
const int u = get_int_b4(bq8_1[q8_block].qs, q8_i32 + i);
10191019
sumi = ggml_cuda_dp4a(v, u, sumi);
10201020
}
10211021

1022-
const uint8_t qscale = pqk_vec_scale_get(bq->scales, sub);
1023-
const float d = __half2float(bq->d[band]) * PQK_LOCAL_SCALE_LUT[qscale] * PQK_DP4A_INV_SCALE_4BIT;
1022+
const uint8_t qscale = ggml_pq4_k_scale_get(bq->scales, sub);
1023+
const float d = __half2float(bq->d[band]) * PQ4_K_LOCAL_SCALE_LUT[qscale] * PQK_DP4A_INV_SCALE_4BIT;
10241024
return d * __low2float(bq8_1[q8_block].ds) * sumi;
10251025
}
10261026

0 commit comments

Comments
 (0)