@@ -2300,9 +2300,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
23002300
23012301#if defined __AVX2__
23022302
2303- const __m256i m4 = _mm256_set1_epi8 (0xF );
2304- const __m256i m2 = _mm256_set1_epi8 (3 );
2305- const __m256i m32s = _mm256_set1_epi8 (32 );
2303+ const __m256i m3 = _mm256_set1_epi8 (3 );
2304+ const __m256i m15 = _mm256_set1_epi8 (15 );
23062305
23072306 __m256 acc = _mm256_setzero_ps ();
23082307
@@ -2314,53 +2313,45 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
23142313 const uint8_t * GGML_RESTRICT qh = x [i ].qh ;
23152314 const int8_t * GGML_RESTRICT q8 = y [i ].qs ;
23162315
2316+ const __m256i q8sums = _mm256_loadu_si256 ((const __m256i * )y [i ].bsums );
23172317 const __m128i scales = _mm_loadu_si128 ((const __m128i * )x [i ].scales );
2318+ const __m256i scales_16 = _mm256_cvtepi8_epi16 (scales );
2319+ const __m256i q8sclsub = _mm256_slli_epi32 (_mm256_madd_epi16 (q8sums , scales_16 ), 5 );
23182320
23192321 __m256i sumi = _mm256_setzero_si256 ();
23202322
23212323 int is = 0 ;
23222324
23232325 for (int j = 0 ; j < QK_K /128 ; ++ j ) {
2324-
2325- const __m128i scale_0 = _mm_shuffle_epi8 (scales , get_scale_shuffle (is + 0 ));
2326- const __m128i scale_1 = _mm_shuffle_epi8 (scales , get_scale_shuffle (is + 1 ));
2327- const __m128i scale_2 = _mm_shuffle_epi8 (scales , get_scale_shuffle (is + 2 ));
2328- const __m128i scale_3 = _mm_shuffle_epi8 (scales , get_scale_shuffle (is + 3 ));
2329- is += 4 ;
2330-
23312326 const __m256i q4bits1 = _mm256_loadu_si256 ((const __m256i * )q4 ); q4 += 32 ;
23322327 const __m256i q4bits2 = _mm256_loadu_si256 ((const __m256i * )q4 ); q4 += 32 ;
23332328 const __m256i q4bitsH = _mm256_loadu_si256 ((const __m256i * )qh ); qh += 32 ;
23342329
2335- const __m256i q4h_0 = _mm256_slli_epi16 (_mm256_and_si256 (q4bitsH , m2 ), 4 );
2336- const __m256i q4h_1 = _mm256_slli_epi16 (_mm256_and_si256 (_mm256_srli_epi16 ( q4bitsH , 2 ), m2 ), 4 );
2337- const __m256i q4h_2 = _mm256_slli_epi16 ( _mm256_and_si256 (_mm256_srli_epi16 ( q4bitsH , 4 ), m2 ), 4 );
2338- const __m256i q4h_3 = _mm256_slli_epi16 (_mm256_and_si256 (_mm256_srli_epi16 ( q4bitsH , 6 ), m2 ), 4 );
2330+ const __m256i q4h_0 = _mm256_slli_epi16 (_mm256_and_si256 (q4bitsH , m3 ), 4 );
2331+ const __m256i q4h_1 = _mm256_slli_epi16 (_mm256_and_si256 (q4bitsH , _mm256_set1_epi8 ( 12 ) ), 2 );
2332+ const __m256i q4h_2 = _mm256_and_si256 (q4bitsH , _mm256_set1_epi8 ( 48 ) );
2333+ const __m256i q4h_3 = _mm256_srli_epi16 (_mm256_and_si256 (q4bitsH , _mm256_set1_epi8 ( -64 ) ), 2 );
23392334
2340- const __m256i q4_0 = _mm256_or_si256 (_mm256_and_si256 (q4bits1 , m4 ), q4h_0 );
2341- const __m256i q4_1 = _mm256_or_si256 (_mm256_and_si256 (q4bits2 , m4 ), q4h_1 );
2342- const __m256i q4_2 = _mm256_or_si256 (_mm256_and_si256 (_mm256_srli_epi16 (q4bits1 , 4 ), m4 ), q4h_2 );
2343- const __m256i q4_3 = _mm256_or_si256 (_mm256_and_si256 (_mm256_srli_epi16 (q4bits2 , 4 ), m4 ), q4h_3 );
2335+ const __m256i q4_0 = _mm256_or_si256 (_mm256_and_si256 (q4bits1 , m15 ), q4h_0 );
2336+ const __m256i q4_1 = _mm256_or_si256 (_mm256_and_si256 (q4bits2 , m15 ), q4h_1 );
2337+ const __m256i q4_2 = _mm256_or_si256 (_mm256_and_si256 (_mm256_srli_epi16 (q4bits1 , 4 ), m15 ), q4h_2 );
2338+ const __m256i q4_3 = _mm256_or_si256 (_mm256_and_si256 (_mm256_srli_epi16 (q4bits2 , 4 ), m15 ), q4h_3 );
23442339
23452340 const __m256i q8_0 = _mm256_loadu_si256 ((const __m256i * )q8 ); q8 += 32 ;
23462341 const __m256i q8_1 = _mm256_loadu_si256 ((const __m256i * )q8 ); q8 += 32 ;
23472342 const __m256i q8_2 = _mm256_loadu_si256 ((const __m256i * )q8 ); q8 += 32 ;
23482343 const __m256i q8_3 = _mm256_loadu_si256 ((const __m256i * )q8 ); q8 += 32 ;
23492344
2350- __m256i q8s_0 = _mm256_maddubs_epi16 (m32s , q8_0 );
2351- __m256i q8s_1 = _mm256_maddubs_epi16 (m32s , q8_1 );
2352- __m256i q8s_2 = _mm256_maddubs_epi16 (m32s , q8_2 );
2353- __m256i q8s_3 = _mm256_maddubs_epi16 (m32s , q8_3 );
2354-
23552345 __m256i p16_0 = _mm256_maddubs_epi16 (q4_0 , q8_0 );
23562346 __m256i p16_1 = _mm256_maddubs_epi16 (q4_1 , q8_1 );
23572347 __m256i p16_2 = _mm256_maddubs_epi16 (q4_2 , q8_2 );
23582348 __m256i p16_3 = _mm256_maddubs_epi16 (q4_3 , q8_3 );
23592349
2360- p16_0 = _mm256_sub_epi16 (p16_0 , q8s_0 );
2361- p16_1 = _mm256_sub_epi16 (p16_1 , q8s_1 );
2362- p16_2 = _mm256_sub_epi16 (p16_2 , q8s_2 );
2363- p16_3 = _mm256_sub_epi16 (p16_3 , q8s_3 );
2350+ const __m128i scale_0 = _mm_shuffle_epi8 (scales , get_scale_shuffle (is + 0 ));
2351+ const __m128i scale_1 = _mm_shuffle_epi8 (scales , get_scale_shuffle (is + 1 ));
2352+ const __m128i scale_2 = _mm_shuffle_epi8 (scales , get_scale_shuffle (is + 2 ));
2353+ const __m128i scale_3 = _mm_shuffle_epi8 (scales , get_scale_shuffle (is + 3 ));
2354+ is += 4 ;
23642355
23652356 p16_0 = _mm256_madd_epi16 (_mm256_cvtepi8_epi16 (scale_0 ), p16_0 );
23662357 p16_1 = _mm256_madd_epi16 (_mm256_cvtepi8_epi16 (scale_1 ), p16_1 );
@@ -2372,6 +2363,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
23722363
23732364 }
23742365
2366+ sumi = _mm256_sub_epi32 (sumi , q8sclsub );
23752367 acc = _mm256_fmadd_ps (_mm256_broadcast_ss (& d ), _mm256_cvtepi32_ps (sumi ), acc );
23762368 }
23772369
0 commit comments