Skip to content

Commit 0caf2a1

Browse files
aicss-genaictao456
andauthored
sycl: scalar SWAR byte-subtract in Q6_K MMVQ dot product (#22156)
Signed-off-by: Chun Tao <chun.tao@intel.com> Co-authored-by: Chun Tao <chun.tao@intel.com>
1 parent 5511965 commit 0caf2a1

1 file changed

Lines changed: 46 additions & 53 deletions

File tree

ggml/src/ggml-sycl/vecdotq.hpp

Lines changed: 46 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,32 @@ static __dpct_inline__ int get_int_from_uint8_aligned(
8585
(const int*)(x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
8686
}
8787

88+
static __dpct_inline__ int byte_sub_4(const int a, const int b) {
89+
const uint32_t ua = static_cast<uint32_t>(a);
90+
const uint32_t ub = static_cast<uint32_t>(b);
91+
return static_cast<int>(((ua | 0x80808080u) - ub) ^ 0x80808080u);
92+
}
93+
94+
static __dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq_scalar(
95+
const int vl, const int vh, const int u0, const int u1, const int8_t sc0,
96+
const int8_t sc1, const float d, const float d80, const float d81) {
97+
static_assert(QR6_K == 2, "q6_K MMVQ scalar fast path assumes QR6_K == 2");
98+
99+
const int vil0 = (vl >> 0) & 0x0F0F0F0F;
100+
const int vih0 = ((vh >> 0) << 4) & 0x30303030;
101+
const int vi0 = byte_sub_4(vil0 | vih0, 0x20202020);
102+
103+
const int vil1 = (vl >> 4) & 0x0F0F0F0F;
104+
const int vih1 = ((vh >> 4) << 4) & 0x30303030;
105+
const int vi1 = byte_sub_4(vil1 | vih1, 0x20202020);
106+
107+
const float sumf =
108+
d80 * (dpct::dp4a(vi0, u0, 0) * sc0) +
109+
d81 * (dpct::dp4a(vi1, u1, 0) * sc1);
110+
111+
return d * sumf;
112+
}
113+
88114
static __dpct_inline__ void get_int_from_table_16(const uint32_t &q4,
89115
const uint8_t *values,
90116
int &val1, int &val2) {
@@ -279,24 +305,8 @@ vec_dot_q6_K_q8_1_impl_mmvq(const int &vl, const int &vh,
279305
const int *__restrict__ u,
280306
const int8_t *__restrict__ scales, const float &d,
281307
const float *__restrict__ d8) {
282-
283-
float sumf = 0.0f;
284-
285-
#pragma unroll
286-
for (int i = 0; i < QR6_K; ++i) {
287-
const int sc = scales[4*i];
288-
289-
const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
290-
291-
const int vih = ((vh >> (4*i)) << 4) & 0x30303030;
292-
293-
const int vi = dpct::vectorized_binary<sycl::char4>(
294-
(vil | vih), 0x20202020, dpct::sub_sat()); // vi = (vil | vih) - 32
295-
296-
sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product
297-
}
298-
299-
return d*sumf;
308+
return vec_dot_q6_K_q8_1_impl_mmvq_scalar(
309+
vl, vh, u[0], u[1], scales[0], scales[4], d, d8[0], d8[1]);
300310
}
301311

302312
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
@@ -542,23 +552,8 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K> {
542552
__dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq(const int vl, const int vh, const int * __restrict__ u,
543553
const int8_t * __restrict__ scales, const float d,
544554
const float * __restrict__ d8) {
545-
float sumf = 0.0f;
546-
547-
#pragma unroll
548-
for (int i = 0; i < QR6_K; ++i) {
549-
const int sc = scales[4 * i];
550-
551-
const int vil = (vl >> (4 * i)) & 0x0F0F0F0F;
552-
553-
const int vih = ((vh >> (4 * i)) << 4) & 0x30303030;
554-
555-
const int vi = dpct::vectorized_binary<sycl::char4>((vil | vih), 0x20202020,
556-
dpct::sub_sat()); // vi = (vil | vih) - 32
557-
558-
sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product
559-
}
560-
561-
return d * sumf;
555+
return vec_dot_q6_K_q8_1_impl_mmvq_scalar(
556+
vl, vh, u[0], u[1], scales[0], scales[4], d, d8[0], d8[1]);
562557
}
563558

564559
__dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
@@ -579,16 +574,15 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K> {
579574

580575
const int8_t * scs = scales + scale_offset;
581576

582-
int u[QR6_K];
583-
float d8[QR6_K];
577+
const int u0 = get_int_from_int8_aligned(
578+
q8_1_quant_ptr + bq8_offset * QK8_1, iqs % QI8_1);
579+
const int u1 = get_int_from_int8_aligned(
580+
q8_1_quant_ptr + (bq8_offset + 2) * QK8_1, iqs % QI8_1);
581+
const float d80 = (*(q8_1_ds + bq8_offset + 0))[0];
582+
const float d81 = (*(q8_1_ds + bq8_offset + 2))[0];
584583

585-
#pragma unroll
586-
for (int i = 0; i < QR6_K; ++i) {
587-
u[i] = get_int_from_int8_aligned(q8_1_quant_ptr + (bq8_offset + 2 * i) * QK8_1, iqs % QI8_1);
588-
const sycl::half2 ds_values = *(q8_1_ds + bq8_offset + 2 * i);
589-
d8[i] = ds_values[0];
590-
}
591-
return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scs, *d, d8);
584+
return vec_dot_q6_K_q8_1_impl_mmvq_scalar(
585+
vl, vh, u0, u1, scs[0], scs[4], *d, d80, d81);
592586
}
593587
};
594588
#define VDR_Q4_0_Q8_1_MMVQ 2
@@ -1167,16 +1161,15 @@ vec_dot_q6_K_q8_1(const void *__restrict__ vbq,
11671161

11681162
const int8_t * scales = bq6_K->scales + scale_offset;
11691163

1170-
int u[QR6_K];
1171-
float d8[QR6_K];
1172-
1173-
#pragma unroll
1174-
for (int i = 0; i < QR6_K; ++i) {
1175-
u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
1176-
d8[i] = bq8_1[bq8_offset + 2 * i].ds[0];
1177-
}
1164+
const int u0 = get_int_from_int8_aligned(
1165+
bq8_1[bq8_offset + 0].qs, iqs % QI8_1);
1166+
const int u1 = get_int_from_int8_aligned(
1167+
bq8_1[bq8_offset + 2].qs, iqs % QI8_1);
1168+
const float d80 = bq8_1[bq8_offset + 0].ds[0];
1169+
const float d81 = bq8_1[bq8_offset + 2].ds[0];
11781170

1179-
return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
1171+
return vec_dot_q6_K_q8_1_impl_mmvq_scalar(
1172+
vl, vh, u0, u1, scales[0], scales[4], bq6_K->d, d80, d81);
11801173
}
11811174

11821175

0 commit comments

Comments
 (0)