@@ -708,44 +708,45 @@ void ggml_vec_dot_q1_0_g128_q8_0(int n, float * GGML_RESTRICT s, size_t bs, cons
708708 const block_q8_0 * GGML_RESTRICT y = vy ;
709709
710710#if defined(__AVX2__ )
711- // AVX2: expand each 32-bit sign stream to a byte mask, sign-flip qy
712- // directly in the byte domain, then reduce two Q8_0 sub-blocks in
713- // parallel before folding the pair into the outer block sum.
714711 const __m256i ones_8 = _mm256_set1_epi8 (1 );
715712 const __m256i ones_16 = _mm256_set1_epi16 (1 );
713+ const __m256i byte_shuf = _mm256_setr_epi8 (
714+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
715+ 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 3 , 3 , 3 , 3 , 3 , 3 , 3 , 3 );
716+ const __m256i bit_masks = _mm256_setr_epi8 (
717+ 1 , 2 , 4 , 8 , 16 , 32 , 64 , (char ) -128 , 1 , 2 , 4 , 8 , 16 , 32 , 64 , (char ) -128 ,
718+ 1 , 2 , 4 , 8 , 16 , 32 , 64 , (char ) -128 , 1 , 2 , 4 , 8 , 16 , 32 , 64 , (char ) -128 );
716719 const __m256i zero = _mm256_setzero_si256 ();
717720 __m256 acc = _mm256_setzero_ps ();
718721
719722 for (int ib = 0 ; ib < nb ; ++ ib ) {
720723 const float d0 = GGML_CPU_FP16_TO_FP32 (x [ib ].d );
721- __m256 acc_block_0 = _mm256_setzero_ps ();
722- __m256 acc_block_1 = _mm256_setzero_ps ();
723-
724- for (int k = 0 ; k < 4 ; k += 2 ) {
725- const block_q8_0 * GGML_RESTRICT yb_0 = & y [ib * 4 + k + 0 ];
726- const block_q8_0 * GGML_RESTRICT yb_1 = & y [ib * 4 + k + 1 ];
727- const __m256i bit_mask_0 = bytes_from_bits_32 (& x [ib ].qs [(k + 0 ) * 4 ]);
728- const __m256i bit_mask_1 = bytes_from_bits_32 (& x [ib ].qs [(k + 1 ) * 4 ]);
729- const __m256i qy_0 = _mm256_loadu_si256 ((const __m256i * ) yb_0 -> qs );
730- const __m256i qy_1 = _mm256_loadu_si256 ((const __m256i * ) yb_1 -> qs );
731- const __m256i sign_mask_0 = _mm256_cmpeq_epi8 (bit_mask_0 , zero );
732- const __m256i sign_mask_1 = _mm256_cmpeq_epi8 (bit_mask_1 , zero );
733- const __m256i sy_0 = _mm256_sub_epi8 (_mm256_xor_si256 (qy_0 , sign_mask_0 ), sign_mask_0 );
734- const __m256i sy_1 = _mm256_sub_epi8 (_mm256_xor_si256 (qy_1 , sign_mask_1 ), sign_mask_1 );
735- const __m256i sum16_0 = _mm256_maddubs_epi16 (ones_8 , sy_0 );
736- const __m256i sum16_1 = _mm256_maddubs_epi16 (ones_8 , sy_1 );
737- const __m256i sum32_0 = _mm256_madd_epi16 (sum16_0 , ones_16 );
738- const __m256i sum32_1 = _mm256_madd_epi16 (sum16_1 , ones_16 );
739- const __m256 q_0 = _mm256_cvtepi32_ps (sum32_0 );
740- const __m256 q_1 = _mm256_cvtepi32_ps (sum32_1 );
741- const __m256 d1_0 = _mm256_set1_ps (GGML_CPU_FP16_TO_FP32 (yb_0 -> d ));
742- const __m256 d1_1 = _mm256_set1_ps (GGML_CPU_FP16_TO_FP32 (yb_1 -> d ));
743-
744- acc_block_0 = _mm256_fmadd_ps (d1_0 , q_0 , acc_block_0 );
745- acc_block_1 = _mm256_fmadd_ps (d1_1 , q_1 , acc_block_1 );
746- }
724+ const uint32_t * GGML_RESTRICT qs32 = (const uint32_t * ) x [ib ].qs ;
725+ const block_q8_0 * GGML_RESTRICT y_ptr = & y [ib * 4 ];
747726
748- acc = _mm256_fmadd_ps (_mm256_set1_ps (d0 ), _mm256_add_ps (acc_block_0 , acc_block_1 ), acc );
727+ __m256 acc_block ;
728+ {
729+ const __m256i qy = _mm256_loadu_si256 ((const __m256i * ) y_ptr [0 ].qs );
730+ const __m256i sm = _mm256_cmpeq_epi8 (
731+ _mm256_and_si256 (_mm256_shuffle_epi8 (_mm256_set1_epi32 ((int ) qs32 [0 ]), byte_shuf ), bit_masks ), zero );
732+ const __m256i sy = _mm256_sub_epi8 (_mm256_xor_si256 (qy , sm ), sm );
733+ const __m256i s32 = _mm256_madd_epi16 (_mm256_maddubs_epi16 (ones_8 , sy ), ones_16 );
734+ acc_block = _mm256_mul_ps (_mm256_set1_ps (GGML_CPU_FP16_TO_FP32 (y_ptr [0 ].d )), _mm256_cvtepi32_ps (s32 ));
735+ }
736+ #define Q1_AVX2_BLOCK (K ) \
737+ { \
738+ const __m256i qy = _mm256_loadu_si256((const __m256i *) y_ptr[K].qs); \
739+ const __m256i sm = _mm256_cmpeq_epi8( \
740+ _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs32[K]), byte_shuf), bit_masks), zero); \
741+ const __m256i sy = _mm256_sub_epi8(_mm256_xor_si256(qy, sm), sm); \
742+ const __m256i s32 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy), ones_16); \
743+ acc_block = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[K].d)), _mm256_cvtepi32_ps(s32), acc_block); \
744+ }
745+ Q1_AVX2_BLOCK (1 )
746+ Q1_AVX2_BLOCK (2 )
747+ Q1_AVX2_BLOCK (3 )
748+ #undef Q1_AVX2_BLOCK
749+ acc = _mm256_fmadd_ps (_mm256_set1_ps (d0 ), acc_block , acc );
749750 }
750751
751752 * s = hsum_float_8 (acc );
0 commit comments