@@ -541,162 +541,11 @@ static inline __m128i get_scale_shuffle(int i) {
541541#endif
542542
543543void ggml_vec_dot_q1_0_q8_0 (int n , float * GGML_RESTRICT s , size_t bs , const void * GGML_RESTRICT vx , size_t bx , const void * GGML_RESTRICT vy , size_t by , int nrc ) {
544- const int qk = QK8_0 ;
545- const int nb = n / qk ;
546-
547- assert (n % qk == 0 );
548- assert (nrc == 1 );
549- UNUSED (nrc );
550- UNUSED (bx );
551- UNUSED (by );
552- UNUSED (bs );
553-
554- const block_q1_0 * GGML_RESTRICT x = vx ;
555- const block_q8_0 * GGML_RESTRICT y = vy ;
556-
557- int ib = 0 ;
558- float sumf = 0 ;
559-
560- #if defined(__AVX2__ )
561- // Initialize accumulator with zeros
562- __m256 acc = _mm256_setzero_ps ();
563-
564- // Main loop - compute dot product for each block
565- for (; ib < nb ; ++ ib ) {
566- // Compute combined scale for the block
567- const __m256 d = _mm256_set1_ps (GGML_CPU_FP16_TO_FP32 (x [ib ].d ) * GGML_CPU_FP16_TO_FP32 (y [ib ].d ));
568-
569- // Load Q1_0 bits (4 bytes = 32 bits)
570- const uint32_t qbits32 = * (const uint32_t * )x [ib ].qs ;
571-
572- // Load Q8_0 values (32 bytes)
573- const __m256i qy = _mm256_loadu_si256 ((const __m256i * )y [ib ].qs );
574-
575- // Expand 32 bits to 32 bytes (each bit becomes ±1)
576- // We need to place the right byte in each 8-byte group and mask the right bit
577- __m256i qx ;
578- {
579- // Create a vector with each of the 4 bytes replicated 8 times
580- // Byte 0 in positions 0-7, byte 1 in positions 8-15, byte 2 in positions 16-23, byte 3 in positions 24-31
581- const __m256i shuffle_mask = _mm256_set_epi8 (
582- 3 , 3 , 3 , 3 , 3 , 3 , 3 , 3 , // byte 3 (bits 24-31) replicated
583- 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , // byte 2 (bits 16-23) replicated
584- 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , // byte 1 (bits 8-15) replicated
585- 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 // byte 0 (bits 0-7) replicated
586- );
587-
588- // Broadcast the 4 bytes across the 128-bit lanes
589- const __m128i qbits_128 = _mm_set1_epi32 (qbits32 );
590- const __m256i qbits_256 = _mm256_broadcastsi128_si256 (qbits_128 );
591-
592- // Shuffle to replicate bytes
593- const __m256i qbits_shuffled = _mm256_shuffle_epi8 (qbits_256 , shuffle_mask );
594-
595- // Create bit masks for each position within a byte
596- const __m256i bit_mask = _mm256_set_epi8 (
597- (char )0x80 , 0x40 , 0x20 , 0x10 , 0x08 , 0x04 , 0x02 , 0x01 , // masks for byte 3
598- (char )0x80 , 0x40 , 0x20 , 0x10 , 0x08 , 0x04 , 0x02 , 0x01 , // masks for byte 2
599- (char )0x80 , 0x40 , 0x20 , 0x10 , 0x08 , 0x04 , 0x02 , 0x01 , // masks for byte 1
600- (char )0x80 , 0x40 , 0x20 , 0x10 , 0x08 , 0x04 , 0x02 , 0x01 // masks for byte 0
601- );
602-
603- // Test each bit: AND with mask, compare to mask
604- // Result is 0xFF if bit is set, 0x00 if not
605- const __m256i bit_test = _mm256_and_si256 (qbits_shuffled , bit_mask );
606- const __m256i is_set = _mm256_cmpeq_epi8 (bit_test , bit_mask );
607-
608- // Convert 0xFF -> +1, 0x00 -> -1
609- // is_set is 0xFF (all bits set) if bit is 1, or 0x00 if bit is 0
610- // We want: +1 if bit is 1, -1 if bit is 0
611- // Method: (is_set & 1) gives 1 or 0, then (value << 1) - 1 gives +1 or -1
612- const __m256i ones = _mm256_set1_epi8 (1 );
613- const __m256i bit_value = _mm256_and_si256 (is_set , ones ); // 0x01 or 0x00
614- const __m256i bit_doubled = _mm256_add_epi8 (bit_value , bit_value ); // 0x02 or 0x00
615- qx = _mm256_sub_epi8 (bit_doubled , ones ); // 0x01 or 0xFF (-1)
616- }
617-
618- // Multiply and accumulate using the same pattern as Q4_0
619- const __m256 q = mul_sum_i8_pairs_float (qx , qy );
620-
621- // Multiply q with scale and accumulate
622- acc = _mm256_fmadd_ps (d , q , acc );
623- }
624-
625- sumf = hsum_float_8 (acc );
626-
627- #endif
628- // Fallback scalar loop for remaining blocks
629- for (; ib < nb ; ++ ib ) {
630- const uint8_t * qbits = x [ib ].qs ;
631- int sumi = 0 ;
632-
633- // Optimized scalar processing for QK1_0 bits
634- for (int byte_idx = 0 ; byte_idx < QK1_0 /8 ; ++ byte_idx ) {
635- const uint8_t bits8 = qbits [byte_idx ];
636- const int base_idx = byte_idx * 8 ;
637-
638- // Process each bit
639- for (int bit_idx = 0 ; bit_idx < 8 ; ++ bit_idx ) {
640- const int xi = (bits8 & (1U << bit_idx )) ? 1 : -1 ;
641- sumi += xi * y [ib ].qs [base_idx + bit_idx ];
642- }
643- }
644-
645- sumf += sumi * GGML_CPU_FP16_TO_FP32 (x [ib ].d ) * GGML_CPU_FP16_TO_FP32 (y [ib ].d );
646- }
647-
648- * s = sumf ;
544+ ggml_vec_dot_q1_0_q8_0_generic (n , s , bs , vx , bx , vy , by , nrc );
649545}
650546
651547void ggml_vec_dot_q1_0_g128_q8_0 (int n , float * GGML_RESTRICT s , size_t bs , const void * GGML_RESTRICT vx , size_t bx , const void * GGML_RESTRICT vy , size_t by , int nrc ) {
652- const int qk = QK1_0_g128 ;
653- const int nb = n / qk ;
654-
655- assert (n % qk == 0 );
656- assert (nrc == 1 );
657- UNUSED (nrc );
658- UNUSED (bx );
659- UNUSED (by );
660- UNUSED (bs );
661-
662- const block_q1_0_g128 * GGML_RESTRICT x = vx ;
663- const block_q8_0 * GGML_RESTRICT y = vy ;
664-
665- float sumf = 0 ;
666-
667- // Each Q1_0_g128 block has 128 elements
668- // Each Q8_0 block has 32 elements
669- // So we need 4 Q8_0 blocks per Q1_0_g128 block
670- for (int ib = 0 ; ib < nb ; ++ ib ) {
671- const float d0 = GGML_CPU_FP16_TO_FP32 (x [ib ].d );
672-
673- int sumi = 0 ;
674-
675- // Process 4 Q8_0 blocks (4 * 32 = 128 elements)
676- for (int k = 0 ; k < 4 ; k ++ ) {
677- const float d1 = GGML_CPU_FP16_TO_FP32 (y [ib * 4 + k ].d );
678-
679- int sumi_block = 0 ;
680-
681- for (int j = 0 ; j < QK8_0 ; j ++ ) {
682- const int bit_index = k * QK8_0 + j ;
683- const int byte_index = bit_index / 8 ;
684- const int bit_offset = bit_index % 8 ;
685-
686- // Extract bit: 1 = +1, 0 = -1
687- const int xi = ((x [ib ].qs [byte_index ] >> bit_offset ) & 1 ) ? 1 : -1 ;
688- const int yi = y [ib * 4 + k ].qs [j ];
689-
690- sumi_block += xi * yi ;
691- }
692-
693- sumi += d1 * sumi_block ;
694- }
695-
696- sumf += d0 * sumi ;
697- }
698-
699- * s = sumf ;
548+ ggml_vec_dot_q1_0_g128_q8_0_generic (n , s , bs , vx , bx , vy , by , nrc );
700549}
701550
702551void ggml_vec_dot_q4_0_q8_0 (int n , float * GGML_RESTRICT s , size_t bs , const void * GGML_RESTRICT vx , size_t bx , const void * GGML_RESTRICT vy , size_t by , int nrc ) {
0 commit comments