@@ -274,6 +274,23 @@ static inline __m256 quad_mx_delta_float(const uint8_t x0, const float y0, const
 }
 #endif
 #elif defined(__SSSE3__)
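+// Expand 16 packed bits into 16 byte masks: each output byte becomes 0xFF
+// where the corresponding input bit is set and 0x00 where it is clear. pshufb
+// replicates byte 0 across lanes 0-7 and byte 1 across lanes 8-15; OR-ing with
+// a per-lane mask that has all bits set except the bit under test yields 0xFF
+// exactly when that bit was set, which the final pcmpeqb turns into the mask.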
+static inline __m128i bytes_from_bits_16(const uint8_t * x) {
+    uint16_t x16;
+    memcpy(&x16, x, sizeof(uint16_t));
+
+    const __m128i shuf_mask = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
+    __m128i bytes = _mm_shuffle_epi8(_mm_set1_epi16((short) x16), shuf_mask);
+    const __m128i bit_mask = _mm_set_epi64x(0x7fbfdfeff7fbfdfe, 0x7fbfdfeff7fbfdfe);
+    bytes = _mm_or_si128(bytes, bit_mask);
+
+    return _mm_cmpeq_epi8(bytes, _mm_set1_epi64x(-1));
+}
+
 // horizontally add 4x4 floats
 static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
     __m128 res_0 = _mm_hadd_ps(a, b);
@@ -540,6 +557,181 @@ static inline __m128i get_scale_shuffle(int i) {
 }
 #endif
 
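+// Dot product of a q1_0 row with a q8_0 row. Each q1_0 block stores one sign
+// bit per weight (bit set -> +1, bit clear -> -1) plus a shared fp16 scale d,
+// and covers four q8_0 sub-blocks of QK8_0 = 32 values each (i.e. QK1_0 is
+// assumed to be 128 here). The SIMD paths below expand the sign bits into byte
+// masks and conditionally negate the q8 values instead of materializing +-1.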
+void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    const int qk = QK1_0;
+    const int nb = n / qk;
+
+    assert(n % qk == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q1_0 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
+
+#if defined(__AVX2__)
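+    // byte_shuf replicates each of the 4 sign bytes from a 32-bit load across
+    // 8 lanes; AND-ing with bit_masks (1, 2, 4, ..., 128 per lane) leaves a
+    // nonzero byte exactly where the corresponding sign bit is set.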
+    const __m256i ones_8  = _mm256_set1_epi8(1);
+    const __m256i ones_16 = _mm256_set1_epi16(1);
+    const __m256i byte_shuf = _mm256_setr_epi8(
+        0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
+        2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3);
+    const __m256i bit_masks = _mm256_setr_epi8(
+        1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128,
+        1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128);
+    const __m256i zero = _mm256_setzero_si256();
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int ib = 0; ib < nb; ++ib) {
+        const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d);
+        const uint32_t * GGML_RESTRICT qs32 = (const uint32_t *) x[ib].qs;
+        const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib*4];
+
+        __m256 acc_block;
+        {
+            const __m256i qy = _mm256_loadu_si256((const __m256i *) y_ptr[0].qs);
+            const __m256i sm = _mm256_cmpeq_epi8(
+                _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs32[0]), byte_shuf), bit_masks), zero);
+            const __m256i sy = _mm256_sub_epi8(_mm256_xor_si256(qy, sm), sm);
+            const __m256i s32 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy), ones_16);
+            acc_block = _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[0].d)), _mm256_cvtepi32_ps(s32));
+        }
+#define Q1_AVX2_BLOCK(K) \
+        { \
+            const __m256i qy = _mm256_loadu_si256((const __m256i *) y_ptr[K].qs); \
+            const __m256i sm = _mm256_cmpeq_epi8( \
+                _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs32[K]), byte_shuf), bit_masks), zero); \
+            const __m256i sy = _mm256_sub_epi8(_mm256_xor_si256(qy, sm), sm); \
+            const __m256i s32 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy), ones_16); \
+            acc_block = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[K].d)), _mm256_cvtepi32_ps(s32), acc_block); \
+        }
+        Q1_AVX2_BLOCK(1)
+        Q1_AVX2_BLOCK(2)
+        Q1_AVX2_BLOCK(3)
+#undef Q1_AVX2_BLOCK
+        acc = _mm256_fmadd_ps(_mm256_set1_ps(d0), acc_block, acc);
+    }
+
+    *s = hsum_float_8(acc);
+#elif defined(__AVX__)
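+    // Without AVX2 there are no 256-bit integer instructions, so the sign
+    // masks and integer sums are computed in 128-bit halves and only the float
+    // accumulation is 256-bit wide. bytes_from_bits_32 expands 32 packed sign
+    // bits into 32 byte masks, analogous to bytes_from_bits_16 above.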
+    const __m128i ones_8  = _mm_set1_epi8(1);
+    const __m128i ones_16 = _mm_set1_epi16(1);
+    const __m128i zero = _mm_setzero_si128();
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int ib = 0; ib < nb; ++ib) {
+        const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d);
+        const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib*4];
+        __m256 acc_block;
+#define Q1_AVX_BLOCK(K) \
+        { \
+            const __m256i bit_mask = bytes_from_bits_32(&x[ib].qs[(K) * 4]); \
+            const __m128i bit_mask_0 = _mm256_castsi256_si128(bit_mask); \
+            const __m128i bit_mask_1 = _mm256_extractf128_si256(bit_mask, 1); \
+            const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[0]); \
+            const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[16]); \
+            const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); \
+            const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); \
+            const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); \
+            const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); \
+            const __m128i sum16_0 = _mm_maddubs_epi16(ones_8, sy_0); \
+            const __m128i sum16_1 = _mm_maddubs_epi16(ones_8, sy_1); \
+            const __m128i sum32_0 = _mm_madd_epi16(sum16_0, ones_16); \
+            const __m128i sum32_1 = _mm_madd_epi16(sum16_1, ones_16); \
+            const __m256 q = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_1, sum32_0)); \
+            acc_block = _mm256_add_ps(acc_block, _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[(K)].d)), q)); \
+        }
+        {
+            const __m256i bit_mask = bytes_from_bits_32(&x[ib].qs[0]);
+            const __m128i bit_mask_0 = _mm256_castsi256_si128(bit_mask);
+            const __m128i bit_mask_1 = _mm256_extractf128_si256(bit_mask, 1);
+            const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[0].qs[0]);
+            const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[0].qs[16]);
+            const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero);
+            const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero);
+            const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0);
+            const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1);
+            const __m128i sum16_0 = _mm_maddubs_epi16(ones_8, sy_0);
+            const __m128i sum16_1 = _mm_maddubs_epi16(ones_8, sy_1);
+            const __m128i sum32_0 = _mm_madd_epi16(sum16_0, ones_16);
+            const __m128i sum32_1 = _mm_madd_epi16(sum16_1, ones_16);
+            const __m256 q = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_1, sum32_0));
+            acc_block = _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[0].d)), q);
+        }
+        Q1_AVX_BLOCK(1)
+        Q1_AVX_BLOCK(2)
+        Q1_AVX_BLOCK(3)
+#undef Q1_AVX_BLOCK
+
+        acc = _mm256_add_ps(acc, _mm256_mul_ps(_mm256_set1_ps(d0), acc_block));
+    }
+
+    *s = hsum_float_8(acc);
+#elif defined(__SSSE3__)
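+    // SSSE3 path: expand 16 sign bits at a time with bytes_from_bits_16 and
+    // keep four independent accumulators, one per q8_0 sub-block, which are
+    // reduced once at the end by hsum_float_4x4.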
+    const __m128i ones_8  = _mm_set1_epi8(1);
+    const __m128i ones_16 = _mm_set1_epi16(1);
+    const __m128i zero = _mm_setzero_si128();
+    __m128 acc_0 = _mm_setzero_ps();
+    __m128 acc_1 = _mm_setzero_ps();
+    __m128 acc_2 = _mm_setzero_ps();
+    __m128 acc_3 = _mm_setzero_ps();
+
+    for (int ib = 0; ib < nb; ++ib) {
+        const __m128 d0 = _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d));
+        const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib*4];
+
+#define Q1_SSSE3_BLOCK(QS_OFF, Y_IDX, ACC) \
+        { \
+            const __m128i bit_mask_0 = bytes_from_bits_16(&x[ib].qs[(QS_OFF) + 0]); \
+            const __m128i bit_mask_1 = bytes_from_bits_16(&x[ib].qs[(QS_OFF) + 2]); \
+            const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[(Y_IDX)].qs[0]); \
+            const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[(Y_IDX)].qs[16]); \
+            const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); \
+            const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); \
+            const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); \
+            const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); \
+            const __m128i sum_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_0), ones_16); \
+            const __m128i sum_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_1), ones_16); \
+            const __m128 q = _mm_cvtepi32_ps(_mm_add_epi32(sum_0, sum_1)); \
+            (ACC) = _mm_add_ps((ACC), _mm_mul_ps(_mm_mul_ps(d0, _mm_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[(Y_IDX)].d))), q)); \
+        }
+        Q1_SSSE3_BLOCK( 0, 0, acc_0)
+        Q1_SSSE3_BLOCK( 4, 1, acc_1)
+        Q1_SSSE3_BLOCK( 8, 2, acc_2)
+        Q1_SSSE3_BLOCK(12, 3, acc_3)
+#undef Q1_SSSE3_BLOCK
+    }
+
+    *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
+#else
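+    // Scalar fallback for targets without SSSE3.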
+    UNUSED(nb);
+    UNUSED(x);
+    UNUSED(y);
+    ggml_vec_dot_q1_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
 void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;