@@ -274,6 +274,27 @@ static inline __m256 quad_mx_delta_float(const uint8_t x0, const float y0, const
 }
 #endif
 #elif defined(__SSSE3__)
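+// horizontally add 4 int32_t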
+static inline int hsum_i32_4(const __m128i a) {
+    const __m128i hi64  = _mm_unpackhi_epi64(a, a);
+    const __m128i sum64 = _mm_add_epi32(hi64, a);
+    const __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
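+// expand 16 packed bits into 16 bytes: 0xFF where the bit is set, 0x00 otherwise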
+static inline __m128i bytes_from_bits_16(const uint8_t * x) {
+    uint16_t x16;
+    memcpy(&x16, x, sizeof(uint16_t));
+
+    const __m128i shuf_mask = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
+    __m128i bytes = _mm_shuffle_epi8(_mm_set1_epi16((short) x16), shuf_mask);
+    const __m128i bit_mask = _mm_set_epi64x(0x7fbfdfeff7fbfdfe, 0x7fbfdfeff7fbfdfe);
+    bytes = _mm_or_si128(bytes, bit_mask);
+
+    return _mm_cmpeq_epi8(bytes, _mm_set1_epi64x(-1));
+}
+
 // horizontally add 4x4 floats
 static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
     __m128 res_0 = _mm_hadd_ps(a, b);
@@ -540,6 +561,184 @@ static inline __m128i get_scale_shuffle(int i) {
 }
 #endif
 
+void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    const int qk = QK1_0;
+    const int nb = n / qk;
+
+    assert(n % qk == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q1_0 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
+
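+    // each q1_0 block packs one fp16 scale and 128 sign bits (16 bytes),
+    // covering 4 consecutive q8_0 sub-blocks of 32 values each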
+#if defined(__AVX2__)
+    const __m256i ones_8  = _mm256_set1_epi8(1);
+    const __m256i ones_16 = _mm256_set1_epi16(1);
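+    // byte_shuf replicates each of the 4 packed bytes into 8 consecutive lanes;
+    // bit_masks then isolates a distinct bit per lane: one sign bit per q8 value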
+    const __m256i byte_shuf = _mm256_setr_epi8(
+        0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
+        2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3);
+    const __m256i bit_masks = _mm256_setr_epi8(
+        1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128,
+        1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128);
+    const __m256i zero = _mm256_setzero_si256();
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int ib = 0; ib < nb; ++ib) {
+        const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d);
+        const uint32_t * GGML_RESTRICT qs32 = (const uint32_t *) x[ib].qs;
+        const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib*4];
+
+        __m256 acc_block;
+        {
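+            // sm is 0xFF where a sign bit is 0; (qy ^ sm) - sm negates those lanes,
+            // and maddubs/madd (with all-ones operands) reduce the signed bytes to 32-bit sums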
+            const __m256i qy = _mm256_loadu_si256((const __m256i *) y_ptr[0].qs);
+            const __m256i sm = _mm256_cmpeq_epi8(
+                _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs32[0]), byte_shuf), bit_masks), zero);
+            const __m256i sy  = _mm256_sub_epi8(_mm256_xor_si256(qy, sm), sm);
+            const __m256i s32 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy), ones_16);
+            acc_block = _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[0].d)), _mm256_cvtepi32_ps(s32));
+        }
+#define Q1_AVX2_BLOCK(K) \
+        { \
+            const __m256i qy = _mm256_loadu_si256((const __m256i *) y_ptr[K].qs); \
+            const __m256i sm = _mm256_cmpeq_epi8( \
+                _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs32[K]), byte_shuf), bit_masks), zero); \
+            const __m256i sy  = _mm256_sub_epi8(_mm256_xor_si256(qy, sm), sm); \
+            const __m256i s32 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy), ones_16); \
+            acc_block = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[K].d)), _mm256_cvtepi32_ps(s32), acc_block); \
+        }
+        Q1_AVX2_BLOCK(1)
+        Q1_AVX2_BLOCK(2)
+        Q1_AVX2_BLOCK(3)
+#undef Q1_AVX2_BLOCK
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d0), acc_block, acc);
+    }
+
+    *s = hsum_float_8(acc);
+#elif defined(__AVX__)
+    const __m128i ones_8  = _mm_set1_epi8(1);
+    const __m128i ones_16 = _mm_set1_epi16(1);
+    const __m128i zero = _mm_setzero_si128();
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int ib = 0; ib < nb; ++ib) {
+        const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d);
+        const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib*4];
+        __m256 acc_block = _mm256_setzero_ps();
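+        // bytes_from_bits_32 expands 32 packed bits into a 32-byte 0xFF/0x00 mask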
+#define Q1_AVX_BLOCK(K) \
+        { \
+            const __m256i bit_mask = bytes_from_bits_32(&x[ib].qs[(K)*4]); \
+            const __m128i bit_mask_0 = _mm256_castsi256_si128(bit_mask); \
+            const __m128i bit_mask_1 = _mm256_extractf128_si256(bit_mask, 1); \
+            const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[0]); \
+            const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[16]); \
+            const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); \
+            const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); \
+            const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); \
+            const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); \
+            const __m128i sum16_0 = _mm_maddubs_epi16(ones_8, sy_0); \
+            const __m128i sum16_1 = _mm_maddubs_epi16(ones_8, sy_1); \
+            const __m128i sum32_0 = _mm_madd_epi16(sum16_0, ones_16); \
+            const __m128i sum32_1 = _mm_madd_epi16(sum16_1, ones_16); \
+            const __m256 q = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_1, sum32_0)); \
+            acc_block = _mm256_add_ps(acc_block, _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[(K)].d)), q)); \
+        }
+        Q1_AVX_BLOCK(0)
+        Q1_AVX_BLOCK(1)
+        Q1_AVX_BLOCK(2)
+        Q1_AVX_BLOCK(3)
+#undef Q1_AVX_BLOCK
+
+        acc = _mm256_add_ps(acc, _mm256_mul_ps(_mm256_set1_ps(d0), acc_block));
+    }
+
+    *s = hsum_float_8(acc);
+#elif defined(__SSSE3__)
+    const __m128i ones_8  = _mm_set1_epi8(1);
+    const __m128i ones_16 = _mm_set1_epi16(1);
+    const __m128i zero = _mm_setzero_si128();
+    __m128 acc_0 = _mm_setzero_ps();
+    __m128 acc_1 = _mm_setzero_ps();
+    __m128 acc_2 = _mm_setzero_ps();
+    __m128 acc_3 = _mm_setzero_ps();
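+    // one accumulator per q8_0 sub-block; hsum_float_4x4 reduces all four at the end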
+
+    for (int ib = 0; ib < nb; ++ib) {
+        const __m128 d0 = _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d));
+        const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib*4];
+
+#define Q1_SSSE3_BLOCK(QS_OFF, Y_IDX, ACC) \
+        { \
+            const __m128i bit_mask_0 = bytes_from_bits_16(&x[ib].qs[(QS_OFF) + 0]); \
+            const __m128i bit_mask_1 = bytes_from_bits_16(&x[ib].qs[(QS_OFF) + 2]); \
+            const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[(Y_IDX)].qs[0]); \
+            const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[(Y_IDX)].qs[16]); \
+            const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); \
+            const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); \
+            const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); \
+            const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); \
+            const __m128i sum_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_0), ones_16); \
+            const __m128i sum_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_1), ones_16); \
+            const __m128 q = _mm_cvtepi32_ps(_mm_add_epi32(sum_0, sum_1)); \
+            (ACC) = _mm_add_ps((ACC), _mm_mul_ps(_mm_mul_ps(d0, _mm_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[(Y_IDX)].d))), q)); \
+        }
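+        // each expansion consumes 4 bytes (32 sign bits) of qs and one q8_0 sub-block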
+        Q1_SSSE3_BLOCK( 0, 0, acc_0)
+        Q1_SSSE3_BLOCK( 4, 1, acc_1)
+        Q1_SSSE3_BLOCK( 8, 2, acc_2)
+        Q1_SSSE3_BLOCK(12, 3, acc_3)
+#undef Q1_SSSE3_BLOCK
+    }
+
+    *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
+#else
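+    // scalar reference path: each packed bit selects the sign applied to the matching q8 value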
+    float sumf = 0.0f;
+
+    for (int ib = 0; ib < nb; ++ib) {
+        const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d);
+        float sumi = 0.0f;
+
+        for (int k = 0; k < 4; k++) {
+            const block_q8_0 * GGML_RESTRICT yb = &y[ib*4 + k];
+            const float d1 = GGML_CPU_FP16_TO_FP32(yb->d);
+            int sumi_block = 0;
+
+            const uint8_t * GGML_RESTRICT bits = &x[ib].qs[k*4];
+            const int8_t  * GGML_RESTRICT qy   = yb->qs;
+
+            for (int b = 0; b < 4; ++b, qy += 8) {
+                const unsigned mask = bits[b];
+                sumi_block += ((mask & 0x01) ? qy[0] : -qy[0])
+                            + ((mask & 0x02) ? qy[1] : -qy[1])
+                            + ((mask & 0x04) ? qy[2] : -qy[2])
+                            + ((mask & 0x08) ? qy[3] : -qy[3])
+                            + ((mask & 0x10) ? qy[4] : -qy[4])
+                            + ((mask & 0x20) ? qy[5] : -qy[5])
+                            + ((mask & 0x40) ? qy[6] : -qy[6])
+                            + ((mask & 0x80) ? qy[7] : -qy[7]);
+            }
+
+            sumi += d1 * sumi_block;
+        }
+
+        sumf += d0 * sumi;
+    }
+
+    *s = sumf;
+#endif
+}
+
 void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;