@@ -85,6 +85,32 @@ static __dpct_inline__ int get_int_from_uint8_aligned(
8585 (const int *)(x8 + sizeof (int ) * i32 )); // assume at least 4 byte alignment
8686}
8787
88+ static __dpct_inline__ int byte_sub_4 (const int a, const int b) {
89+ const uint32_t ua = static_cast <uint32_t >(a);
90+ const uint32_t ub = static_cast <uint32_t >(b);
91+ return static_cast <int >(((ua | 0x80808080u ) - ub) ^ 0x80808080u );
92+ }
93+
94+ static __dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq_scalar (
95+ const int vl, const int vh, const int u0, const int u1, const int8_t sc0,
96+ const int8_t sc1, const float d, const float d80, const float d81) {
97+ static_assert (QR6_K == 2 , " q6_K MMVQ scalar fast path assumes QR6_K == 2" );
98+
99+ const int vil0 = (vl >> 0 ) & 0x0F0F0F0F ;
100+ const int vih0 = ((vh >> 0 ) << 4 ) & 0x30303030 ;
101+ const int vi0 = byte_sub_4 (vil0 | vih0, 0x20202020 );
102+
103+ const int vil1 = (vl >> 4 ) & 0x0F0F0F0F ;
104+ const int vih1 = ((vh >> 4 ) << 4 ) & 0x30303030 ;
105+ const int vi1 = byte_sub_4 (vil1 | vih1, 0x20202020 );
106+
107+ const float sumf =
108+ d80 * (dpct::dp4a (vi0, u0, 0 ) * sc0) +
109+ d81 * (dpct::dp4a (vi1, u1, 0 ) * sc1);
110+
111+ return d * sumf;
112+ }
113+
88114static __dpct_inline__ void get_int_from_table_16 (const uint32_t &q4,
89115 const uint8_t *values,
90116 int &val1, int &val2) {
@@ -279,24 +305,8 @@ vec_dot_q6_K_q8_1_impl_mmvq(const int &vl, const int &vh,
279305 const int *__restrict__ u,
280306 const int8_t *__restrict__ scales, const float &d,
281307 const float *__restrict__ d8) {
282-
283- float sumf = 0 .0f ;
284-
285- #pragma unroll
286- for (int i = 0 ; i < QR6_K ; ++i) {
287- const int sc = scales[4 *i];
288-
289- const int vil = (vl >> (4 *i)) & 0x0F0F0F0F ;
290-
291- const int vih = ((vh >> (4 *i)) << 4 ) & 0x30303030 ;
292-
293- const int vi = dpct::vectorized_binary<sycl::char4>(
294- (vil | vih), 0x20202020 , dpct::sub_sat ()); // vi = (vil | vih) - 32
295-
296- sumf += d8[i] * (dpct::dp4a (vi, u[i], 0 ) * sc); // SIMD dot product
297- }
298-
299- return d*sumf;
308+ return vec_dot_q6_K_q8_1_impl_mmvq_scalar (
309+ vl, vh, u[0 ], u[1 ], scales[0 ], scales[4 ], d, d8[0 ], d8[1 ]);
300310}
301311
302312// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
@@ -542,23 +552,8 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K> {
542552 __dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq (const int vl, const int vh, const int * __restrict__ u,
543553 const int8_t * __restrict__ scales, const float d,
544554 const float * __restrict__ d8) {
545- float sumf = 0 .0f ;
546-
547- #pragma unroll
548- for (int i = 0 ; i < QR6_K ; ++i) {
549- const int sc = scales[4 * i];
550-
551- const int vil = (vl >> (4 * i)) & 0x0F0F0F0F ;
552-
553- const int vih = ((vh >> (4 * i)) << 4 ) & 0x30303030 ;
554-
555- const int vi = dpct::vectorized_binary<sycl::char4>((vil | vih), 0x20202020 ,
556- dpct::sub_sat ()); // vi = (vil | vih) - 32
557-
558- sumf += d8[i] * (dpct::dp4a (vi, u[i], 0 ) * sc); // SIMD dot product
559- }
560-
561- return d * sumf;
555+ return vec_dot_q6_K_q8_1_impl_mmvq_scalar (
556+ vl, vh, u[0 ], u[1 ], scales[0 ], scales[4 ], d, d8[0 ], d8[1 ]);
562557 }
563558
564559 __dpct_inline__ float operator ()(const void * __restrict__ vbq, const std::pair<int , int > ibx_offset,
@@ -579,16 +574,15 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K> {
579574
580575 const int8_t * scs = scales + scale_offset;
581576
582- int u[QR6_K ];
583- float d8[QR6_K ];
577+ const int u0 = get_int_from_int8_aligned (
578+ q8_1_quant_ptr + bq8_offset * QK8_1 , iqs % QI8_1 );
579+ const int u1 = get_int_from_int8_aligned (
580+ q8_1_quant_ptr + (bq8_offset + 2 ) * QK8_1 , iqs % QI8_1 );
581+ const float d80 = (*(q8_1_ds + bq8_offset + 0 ))[0 ];
582+ const float d81 = (*(q8_1_ds + bq8_offset + 2 ))[0 ];
584583
585- #pragma unroll
586- for (int i = 0 ; i < QR6_K ; ++i) {
587- u[i] = get_int_from_int8_aligned (q8_1_quant_ptr + (bq8_offset + 2 * i) * QK8_1 , iqs % QI8_1 );
588- const sycl::half2 ds_values = *(q8_1_ds + bq8_offset + 2 * i);
589- d8[i] = ds_values[0 ];
590- }
591- return vec_dot_q6_K_q8_1_impl_mmvq (vl, vh, u, scs, *d, d8);
584+ return vec_dot_q6_K_q8_1_impl_mmvq_scalar (
585+ vl, vh, u0, u1, scs[0 ], scs[4 ], *d, d80, d81);
592586 }
593587};
594588#define VDR_Q4_0_Q8_1_MMVQ 2
@@ -1167,16 +1161,15 @@ vec_dot_q6_K_q8_1(const void *__restrict__ vbq,
11671161
11681162 const int8_t * scales = bq6_K->scales + scale_offset;
11691163
1170- int u[QR6_K ];
1171- float d8[QR6_K ];
1172-
1173- #pragma unroll
1174- for (int i = 0 ; i < QR6_K ; ++i) {
1175- u[i] = get_int_from_int8_aligned (bq8_1[bq8_offset + 2 *i].qs , iqs % QI8_1 );
1176- d8[i] = bq8_1[bq8_offset + 2 * i].ds [0 ];
1177- }
1164+ const int u0 = get_int_from_int8_aligned (
1165+ bq8_1[bq8_offset + 0 ].qs , iqs % QI8_1 );
1166+ const int u1 = get_int_from_int8_aligned (
1167+ bq8_1[bq8_offset + 2 ].qs , iqs % QI8_1 );
1168+ const float d80 = bq8_1[bq8_offset + 0 ].ds [0 ];
1169+ const float d81 = bq8_1[bq8_offset + 2 ].ds [0 ];
11781170
1179- return vec_dot_q6_K_q8_1_impl_mmvq (vl, vh, u, scales, bq6_K->d , d8);
1171+ return vec_dot_q6_K_q8_1_impl_mmvq_scalar (
1172+ vl, vh, u0, u1, scales[0 ], scales[4 ], bq6_K->d , d80, d81);
11801173}
11811174
11821175
0 commit comments