Skip to content

Commit 167652c

Browse files
committed
Replaced q1_0_g128 AVX2 with zcattacz's code
1 parent 93e192f commit 167652c

1 file changed

Lines changed: 31 additions & 30 deletions

File tree

ggml/src/ggml-cpu/arch/x86/quants.c

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -708,44 +708,45 @@ void ggml_vec_dot_q1_0_g128_q8_0(int n, float * GGML_RESTRICT s, size_t bs, cons
708708
const block_q8_0 * GGML_RESTRICT y = vy;
709709

710710
#if defined(__AVX2__)
711-
// AVX2: expand each 32-bit sign stream to a byte mask, sign-flip qy
712-
// directly in the byte domain, then reduce two Q8_0 sub-blocks in
713-
// parallel before folding the pair into the outer block sum.
714711
const __m256i ones_8 = _mm256_set1_epi8(1);
715712
const __m256i ones_16 = _mm256_set1_epi16(1);
713+
const __m256i byte_shuf = _mm256_setr_epi8(
714+
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
715+
2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3);
716+
const __m256i bit_masks = _mm256_setr_epi8(
717+
1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128,
718+
1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128);
716719
const __m256i zero = _mm256_setzero_si256();
717720
__m256 acc = _mm256_setzero_ps();
718721

719722
for (int ib = 0; ib < nb; ++ib) {
720723
const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d);
721-
__m256 acc_block_0 = _mm256_setzero_ps();
722-
__m256 acc_block_1 = _mm256_setzero_ps();
723-
724-
for (int k = 0; k < 4; k += 2) {
725-
const block_q8_0 * GGML_RESTRICT yb_0 = &y[ib * 4 + k + 0];
726-
const block_q8_0 * GGML_RESTRICT yb_1 = &y[ib * 4 + k + 1];
727-
const __m256i bit_mask_0 = bytes_from_bits_32(&x[ib].qs[(k + 0) * 4]);
728-
const __m256i bit_mask_1 = bytes_from_bits_32(&x[ib].qs[(k + 1) * 4]);
729-
const __m256i qy_0 = _mm256_loadu_si256((const __m256i *) yb_0->qs);
730-
const __m256i qy_1 = _mm256_loadu_si256((const __m256i *) yb_1->qs);
731-
const __m256i sign_mask_0 = _mm256_cmpeq_epi8(bit_mask_0, zero);
732-
const __m256i sign_mask_1 = _mm256_cmpeq_epi8(bit_mask_1, zero);
733-
const __m256i sy_0 = _mm256_sub_epi8(_mm256_xor_si256(qy_0, sign_mask_0), sign_mask_0);
734-
const __m256i sy_1 = _mm256_sub_epi8(_mm256_xor_si256(qy_1, sign_mask_1), sign_mask_1);
735-
const __m256i sum16_0 = _mm256_maddubs_epi16(ones_8, sy_0);
736-
const __m256i sum16_1 = _mm256_maddubs_epi16(ones_8, sy_1);
737-
const __m256i sum32_0 = _mm256_madd_epi16(sum16_0, ones_16);
738-
const __m256i sum32_1 = _mm256_madd_epi16(sum16_1, ones_16);
739-
const __m256 q_0 = _mm256_cvtepi32_ps(sum32_0);
740-
const __m256 q_1 = _mm256_cvtepi32_ps(sum32_1);
741-
const __m256 d1_0 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb_0->d));
742-
const __m256 d1_1 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(yb_1->d));
743-
744-
acc_block_0 = _mm256_fmadd_ps(d1_0, q_0, acc_block_0);
745-
acc_block_1 = _mm256_fmadd_ps(d1_1, q_1, acc_block_1);
746-
}
724+
const uint32_t * GGML_RESTRICT qs32 = (const uint32_t *) x[ib].qs;
725+
const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4];
747726

748-
acc = _mm256_fmadd_ps(_mm256_set1_ps(d0), _mm256_add_ps(acc_block_0, acc_block_1), acc);
727+
__m256 acc_block;
728+
{
729+
const __m256i qy = _mm256_loadu_si256((const __m256i *) y_ptr[0].qs);
730+
const __m256i sm = _mm256_cmpeq_epi8(
731+
_mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs32[0]), byte_shuf), bit_masks), zero);
732+
const __m256i sy = _mm256_sub_epi8(_mm256_xor_si256(qy, sm), sm);
733+
const __m256i s32 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy), ones_16);
734+
acc_block = _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[0].d)), _mm256_cvtepi32_ps(s32));
735+
}
736+
#define Q1_AVX2_BLOCK(K) \
737+
{ \
738+
const __m256i qy = _mm256_loadu_si256((const __m256i *) y_ptr[K].qs); \
739+
const __m256i sm = _mm256_cmpeq_epi8( \
740+
_mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs32[K]), byte_shuf), bit_masks), zero); \
741+
const __m256i sy = _mm256_sub_epi8(_mm256_xor_si256(qy, sm), sm); \
742+
const __m256i s32 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy), ones_16); \
743+
acc_block = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[K].d)), _mm256_cvtepi32_ps(s32), acc_block); \
744+
}
745+
Q1_AVX2_BLOCK(1)
746+
Q1_AVX2_BLOCK(2)
747+
Q1_AVX2_BLOCK(3)
748+
#undef Q1_AVX2_BLOCK
749+
acc = _mm256_fmadd_ps(_mm256_set1_ps(d0), acc_block, acc);
749750
}
750751

751752
*s = hsum_float_8(acc);

0 commit comments

Comments
 (0)