@@ -274,6 +274,18 @@ static inline __m256 quad_mx_delta_float(const uint8_t x0, const float y0, const
274274}
275275#endif
276276#elif defined(__SSSE3__ )
277+ static inline __m128i bytes_from_bits_16 (const uint8_t * x ) {
278+ uint16_t x16 ;
279+ memcpy (& x16 , x , sizeof (uint16_t ));
280+
281+ const __m128i shuf_mask = _mm_set_epi64x (0x0101010101010101 , 0x0000000000000000 );
282+ __m128i bytes = _mm_shuffle_epi8 (_mm_set1_epi16 ((short ) x16 ), shuf_mask );
283+ const __m128i bit_mask = _mm_set_epi64x (0x7fbfdfeff7fbfdfe , 0x7fbfdfeff7fbfdfe );
284+ bytes = _mm_or_si128 (bytes , bit_mask );
285+
286+ return _mm_cmpeq_epi8 (bytes , _mm_set1_epi64x (-1 ));
287+ }
288+
277289// horizontally add 4x4 floats
278290static inline float hsum_float_4x4 (const __m128 a , const __m128 b , const __m128 c , const __m128 d ) {
279291 __m128 res_0 = _mm_hadd_ps (a , b );
@@ -540,6 +552,152 @@ static inline __m128i get_scale_shuffle(int i) {
540552}
541553#endif
542554
555+ void ggml_vec_dot_q1_0_q8_0 (int n , float * GGML_RESTRICT s , size_t bs , const void * GGML_RESTRICT vx , size_t bx , const void * GGML_RESTRICT vy , size_t by , int nrc ) {
556+ const int qk = QK1_0 ;
557+ const int nb = n / qk ;
558+
559+ assert (n % qk == 0 );
560+ assert (nrc == 1 );
561+ UNUSED (nrc );
562+ UNUSED (bx );
563+ UNUSED (by );
564+ UNUSED (bs );
565+
566+ const block_q1_0 * GGML_RESTRICT x = vx ;
567+ const block_q8_0 * GGML_RESTRICT y = vy ;
568+
569+ #if defined(__AVX2__ )
570+ const __m256i ones_8 = _mm256_set1_epi8 (1 );
571+ const __m256i ones_16 = _mm256_set1_epi16 (1 );
572+ const __m256i byte_shuf = _mm256_setr_epi8 (
573+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ,
574+ 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 3 , 3 , 3 , 3 , 3 , 3 , 3 , 3 );
575+ const __m256i bit_masks = _mm256_setr_epi8 (
576+ 1 , 2 , 4 , 8 , 16 , 32 , 64 , (char ) -128 , 1 , 2 , 4 , 8 , 16 , 32 , 64 , (char ) -128 ,
577+ 1 , 2 , 4 , 8 , 16 , 32 , 64 , (char ) -128 , 1 , 2 , 4 , 8 , 16 , 32 , 64 , (char ) -128 );
578+ const __m256i zero = _mm256_setzero_si256 ();
579+ __m256 acc = _mm256_setzero_ps ();
580+
581+ for (int ib = 0 ; ib < nb ; ++ ib ) {
582+ const float d0 = GGML_CPU_FP16_TO_FP32 (x [ib ].d );
583+ const uint32_t * GGML_RESTRICT qs32 = (const uint32_t * ) x [ib ].qs ;
584+ const block_q8_0 * GGML_RESTRICT y_ptr = & y [ib * 4 ];
585+
586+ __m256 acc_block ;
587+ {
588+ const __m256i qy = _mm256_loadu_si256 ((const __m256i * ) y_ptr [0 ].qs );
589+ const __m256i sm = _mm256_cmpeq_epi8 (
590+ _mm256_and_si256 (_mm256_shuffle_epi8 (_mm256_set1_epi32 ((int ) qs32 [0 ]), byte_shuf ), bit_masks ), zero );
591+ const __m256i sy = _mm256_sub_epi8 (_mm256_xor_si256 (qy , sm ), sm );
592+ const __m256i s32 = _mm256_madd_epi16 (_mm256_maddubs_epi16 (ones_8 , sy ), ones_16 );
593+ acc_block = _mm256_mul_ps (_mm256_set1_ps (GGML_CPU_FP16_TO_FP32 (y_ptr [0 ].d )), _mm256_cvtepi32_ps (s32 ));
594+ }
595+ for (int K = 1 ; K < 4 ; ++ K ) {
596+ const __m256i qy = _mm256_loadu_si256 ((const __m256i * ) y_ptr [K ].qs );
597+ const __m256i sm = _mm256_cmpeq_epi8 (
598+ _mm256_and_si256 (_mm256_shuffle_epi8 (_mm256_set1_epi32 ((int ) qs32 [K ]), byte_shuf ), bit_masks ), zero );
599+ const __m256i sy = _mm256_sub_epi8 (_mm256_xor_si256 (qy , sm ), sm );
600+ const __m256i s32 = _mm256_madd_epi16 (_mm256_maddubs_epi16 (ones_8 , sy ), ones_16 );
601+ acc_block = _mm256_fmadd_ps (_mm256_set1_ps (GGML_CPU_FP16_TO_FP32 (y_ptr [K ].d )), _mm256_cvtepi32_ps (s32 ), acc_block );
602+ }
603+ acc = _mm256_fmadd_ps (_mm256_set1_ps (d0 ), acc_block , acc );
604+ }
605+
606+ * s = hsum_float_8 (acc );
607+ #elif defined(__AVX__ )
608+ const __m128i ones_8 = _mm_set1_epi8 (1 );
609+ const __m128i ones_16 = _mm_set1_epi16 (1 );
610+ const __m128i zero = _mm_setzero_si128 ();
611+ __m256 acc = _mm256_setzero_ps ();
612+
613+ for (int ib = 0 ; ib < nb ; ++ ib ) {
614+ const float d0 = GGML_CPU_FP16_TO_FP32 (x [ib ].d );
615+ const block_q8_0 * GGML_RESTRICT y_ptr = & y [ib * 4 ];
616+ __m256 acc_block ;
617+ {
618+ const __m256i bit_mask = bytes_from_bits_32 (& x [ib ].qs [0 ]);
619+ const __m128i bit_mask_0 = _mm256_castsi256_si128 (bit_mask );
620+ const __m128i bit_mask_1 = _mm256_extractf128_si256 (bit_mask , 1 );
621+ const __m128i qy_0 = _mm_loadu_si128 ((const __m128i * ) & y_ptr [0 ].qs [0 ]);
622+ const __m128i qy_1 = _mm_loadu_si128 ((const __m128i * ) & y_ptr [0 ].qs [16 ]);
623+ const __m128i sign_mask_0 = _mm_cmpeq_epi8 (bit_mask_0 , zero );
624+ const __m128i sign_mask_1 = _mm_cmpeq_epi8 (bit_mask_1 , zero );
625+ const __m128i sy_0 = _mm_sub_epi8 (_mm_xor_si128 (qy_0 , sign_mask_0 ), sign_mask_0 );
626+ const __m128i sy_1 = _mm_sub_epi8 (_mm_xor_si128 (qy_1 , sign_mask_1 ), sign_mask_1 );
627+ const __m128i sum16_0 = _mm_maddubs_epi16 (ones_8 , sy_0 );
628+ const __m128i sum16_1 = _mm_maddubs_epi16 (ones_8 , sy_1 );
629+ const __m128i sum32_0 = _mm_madd_epi16 (sum16_0 , ones_16 );
630+ const __m128i sum32_1 = _mm_madd_epi16 (sum16_1 , ones_16 );
631+ const __m256 q = _mm256_cvtepi32_ps (MM256_SET_M128I (sum32_1 , sum32_0 ));
632+ acc_block = _mm256_mul_ps (_mm256_set1_ps (GGML_CPU_FP16_TO_FP32 (y_ptr [0 ].d )), q );
633+ }
634+ for (int K = 1 ; K < 4 ; ++ K ) {
635+ const __m256i bit_mask = bytes_from_bits_32 (& x [ib ].qs [(K ) * 4 ]);
636+ const __m128i bit_mask_0 = _mm256_castsi256_si128 (bit_mask );
637+ const __m128i bit_mask_1 = _mm256_extractf128_si256 (bit_mask , 1 );
638+ const __m128i qy_0 = _mm_loadu_si128 ((const __m128i * ) & y_ptr [(K )].qs [0 ]);
639+ const __m128i qy_1 = _mm_loadu_si128 ((const __m128i * ) & y_ptr [(K )].qs [16 ]);
640+ const __m128i sign_mask_0 = _mm_cmpeq_epi8 (bit_mask_0 , zero );
641+ const __m128i sign_mask_1 = _mm_cmpeq_epi8 (bit_mask_1 , zero );
642+ const __m128i sy_0 = _mm_sub_epi8 (_mm_xor_si128 (qy_0 , sign_mask_0 ), sign_mask_0 );
643+ const __m128i sy_1 = _mm_sub_epi8 (_mm_xor_si128 (qy_1 , sign_mask_1 ), sign_mask_1 );
644+ const __m128i sum16_0 = _mm_maddubs_epi16 (ones_8 , sy_0 );
645+ const __m128i sum16_1 = _mm_maddubs_epi16 (ones_8 , sy_1 );
646+ const __m128i sum32_0 = _mm_madd_epi16 (sum16_0 , ones_16 );
647+ const __m128i sum32_1 = _mm_madd_epi16 (sum16_1 , ones_16 );
648+ const __m256 q = _mm256_cvtepi32_ps (MM256_SET_M128I (sum32_1 , sum32_0 ));
649+ acc_block = _mm256_add_ps (acc_block , _mm256_mul_ps (_mm256_set1_ps (GGML_CPU_FP16_TO_FP32 (y_ptr [(K )].d )), q ));
650+ }
651+ #undef Q1_AVX_BLOCK
652+
653+ acc = _mm256_add_ps (acc , _mm256_mul_ps (_mm256_set1_ps (d0 ), acc_block ));
654+ }
655+
656+ * s = hsum_float_8 (acc );
657+ #elif defined(__SSSE3__ )
658+ const __m128i ones_8 = _mm_set1_epi8 (1 );
659+ const __m128i ones_16 = _mm_set1_epi16 (1 );
660+ const __m128i zero = _mm_setzero_si128 ();
661+ __m128 acc_0 = _mm_setzero_ps ();
662+ __m128 acc_1 = _mm_setzero_ps ();
663+ __m128 acc_2 = _mm_setzero_ps ();
664+ __m128 acc_3 = _mm_setzero_ps ();
665+
666+ for (int ib = 0 ; ib < nb ; ++ ib ) {
667+ const __m128 d0 = _mm_set1_ps (GGML_CPU_FP16_TO_FP32 (x [ib ].d ));
668+ const block_q8_0 * GGML_RESTRICT y_ptr = & y [ib * 4 ];
669+
670+ #define Q1_SSSE3_BLOCK (QS_OFF , Y_IDX , ACC ) \
671+ { \
672+ const __m128i bit_mask_0 = bytes_from_bits_16(&x[ib].qs[(QS_OFF) + 0]); \
673+ const __m128i bit_mask_1 = bytes_from_bits_16(&x[ib].qs[(QS_OFF) + 2]); \
674+ const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[(Y_IDX)].qs[0]); \
675+ const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[(Y_IDX)].qs[16]); \
676+ const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); \
677+ const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); \
678+ const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); \
679+ const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); \
680+ const __m128i sum_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_0), ones_16); \
681+ const __m128i sum_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_1), ones_16); \
682+ const __m128 q = _mm_cvtepi32_ps(_mm_add_epi32(sum_0, sum_1)); \
683+ (ACC) = _mm_add_ps((ACC), _mm_mul_ps(_mm_mul_ps(d0, _mm_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[(Y_IDX)].d))), q)); \
684+ }
685+ Q1_SSSE3_BLOCK (0 , 0 , acc_0 )
686+ Q1_SSSE3_BLOCK (4 , 1 , acc_1 )
687+ Q1_SSSE3_BLOCK (8 , 2 , acc_2 )
688+ Q1_SSSE3_BLOCK (12 , 3 , acc_3 )
689+ #undef Q1_SSSE3_BLOCK
690+ }
691+
692+ * s = hsum_float_4x4 (acc_0 , acc_1 , acc_2 , acc_3 );
693+ #else
694+ UNUSED (nb );
695+ UNUSED (x );
696+ UNUSED (y );
697+ ggml_vec_dot_q1_0_q8_0_generic (n , s , bs , vx , bx , vy , by , nrc );
698+ #endif
699+ }
700+
543701void ggml_vec_dot_q4_0_q8_0 (int n , float * GGML_RESTRICT s , size_t bs , const void * GGML_RESTRICT vx , size_t bx , const void * GGML_RESTRICT vy , size_t by , int nrc ) {
544702 const int qk = QK8_0 ;
545703 const int nb = n / qk ;
0 commit comments