Skip to content

Commit 2dd8416

Browse files
authored
ggml-cpu: optimize avx2 q6_k (#22345)
1 parent f454bd7 commit 2dd8416

1 file changed

Lines changed: 19 additions & 27 deletions

File tree

ggml/src/ggml-cpu/arch/x86/quants.c

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2300,9 +2300,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
23002300

23012301
#if defined __AVX2__
23022302

2303-
const __m256i m4 = _mm256_set1_epi8(0xF);
2304-
const __m256i m2 = _mm256_set1_epi8(3);
2305-
const __m256i m32s = _mm256_set1_epi8(32);
2303+
const __m256i m3 = _mm256_set1_epi8(3);
2304+
const __m256i m15 = _mm256_set1_epi8(15);
23062305

23072306
__m256 acc = _mm256_setzero_ps();
23082307

@@ -2314,53 +2313,45 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
23142313
const uint8_t * GGML_RESTRICT qh = x[i].qh;
23152314
const int8_t * GGML_RESTRICT q8 = y[i].qs;
23162315

2316+
const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
23172317
const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
2318+
const __m256i scales_16 = _mm256_cvtepi8_epi16(scales);
2319+
const __m256i q8sclsub = _mm256_slli_epi32(_mm256_madd_epi16(q8sums, scales_16), 5);
23182320

23192321
__m256i sumi = _mm256_setzero_si256();
23202322

23212323
int is = 0;
23222324

23232325
for (int j = 0; j < QK_K/128; ++j) {
2324-
2325-
const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
2326-
const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
2327-
const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
2328-
const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
2329-
is += 4;
2330-
23312326
const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
23322327
const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
23332328
const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32;
23342329

2335-
const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4);
2336-
const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4);
2337-
const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4);
2338-
const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4);
2330+
const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m3), 4);
2331+
const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, _mm256_set1_epi8(12)), 2);
2332+
const __m256i q4h_2 = _mm256_and_si256(q4bitsH, _mm256_set1_epi8(48));
2333+
const __m256i q4h_3 = _mm256_srli_epi16(_mm256_and_si256(q4bitsH, _mm256_set1_epi8(-64)), 2);
23392334

2340-
const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
2341-
const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
2342-
const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
2343-
const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);
2335+
const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m15), q4h_0);
2336+
const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m15), q4h_1);
2337+
const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m15), q4h_2);
2338+
const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m15), q4h_3);
23442339

23452340
const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
23462341
const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
23472342
const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
23482343
const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
23492344

2350-
__m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
2351-
__m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
2352-
__m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
2353-
__m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);
2354-
23552345
__m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
23562346
__m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
23572347
__m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
23582348
__m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);
23592349

2360-
p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
2361-
p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
2362-
p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
2363-
p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
2350+
const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
2351+
const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
2352+
const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
2353+
const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
2354+
is += 4;
23642355

23652356
p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
23662357
p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
@@ -2372,6 +2363,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
23722363

23732364
}
23742365

2366+
sumi = _mm256_sub_epi32(sumi, q8sclsub);
23752367
acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
23762368
}
23772369

0 commit comments

Comments
 (0)