Skip to content

Commit 6df15a6

Browse files
committed
Merge branch 'pr-21636' into prism-new
2 parents a2504b3 + 7f82cf0 commit 6df15a6

File tree

3 files changed

+183
-11
lines changed

3 files changed

+183
-11
lines changed

ggml/src/ggml-cpu/arch-fallback.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@
8383
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
8484
// quants.c
8585
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
86-
#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0
8786
// repack.cpp
8887
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
8988
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4

ggml/src/ggml-cpu/arch/x86/quants.c

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,18 @@ static inline __m256 quad_mx_delta_float(const uint8_t x0, const float y0, const
274274
}
275275
#endif
276276
#elif defined(__SSSE3__)
277+
// Expand 16 packed bits into 16 bytes: result byte i is 0xFF when bit i of the
// 16-bit little-endian word at x is set, 0x00 otherwise.
static inline __m128i bytes_from_bits_16(const uint8_t * x) {
    uint16_t bits;
    memcpy(&bits, x, sizeof(bits));  // unaligned-safe load of the 16 bits

    // Broadcast the word, then replicate its low byte into lanes 0-7 and its
    // high byte into lanes 8-15.
    const __m128i dup = _mm_shuffle_epi8(
        _mm_set1_epi16((short) bits),
        _mm_set_epi64x(0x0101010101010101, 0x0000000000000000));

    // OR-in the complement of each lane's probe bit: a lane becomes all-ones
    // exactly when its bit was set. Compare against all-ones to build the mask.
    const __m128i inv_probe = _mm_set_epi64x(0x7fbfdfeff7fbfdfe, 0x7fbfdfeff7fbfdfe);
    return _mm_cmpeq_epi8(_mm_or_si128(dup, inv_probe), _mm_set1_epi64x(-1));
}
288+
277289
// horizontally add 4x4 floats
278290
static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
279291
__m128 res_0 =_mm_hadd_ps(a, b);
@@ -540,6 +552,161 @@ static inline __m128i get_scale_shuffle(int i) {
540552
}
541553
#endif
542554

555+
// Dot product of a Q1_0 row with a Q8_0 row.
// Q1_0 stores 1-bit weights: bit set => +1, bit clear => -1 (see the generic
// implementation in quants.c). One Q1_0 block covers QK1_0 weights and is
// paired with 4 consecutive Q8_0 blocks of 32 int8 values each; the per-block
// scales d (x) and d (y) are applied as d0 * d1 * sum(xi * yi).
//
//   n   - number of elements; must be a multiple of QK1_0
//   s   - out: scalar dot-product result
//   vx  - Q1_0 blocks (block_q1_0)
//   vy  - Q8_0 blocks (block_q8_0)
//   nrc - must be 1; bs/bx/by are unused in this single-row path
void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK1_0;
    const int nb = n / qk;

    assert(n % qk == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q1_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

#if defined(__AVX2__)
    const __m256i ones_8  = _mm256_set1_epi8(1);
    const __m256i ones_16 = _mm256_set1_epi16(1);
    // byte_shuf spreads the 4 mask bytes of a 32-bit word so that lanes 0-7 get
    // byte 0, lanes 8-15 byte 1, lanes 16-23 byte 2, lanes 24-31 byte 3
    // (_mm256_shuffle_epi8 indexes within each 128-bit half of the broadcast).
    const __m256i byte_shuf = _mm256_setr_epi8(
        0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
        2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3);
    // bit_masks probes bit (i % 8) in lane i, so combined with byte_shuf lane i
    // tests bit i of the 32-bit sign word.
    const __m256i bit_masks = _mm256_setr_epi8(
        1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128,
        1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128);
    const __m256i zero = _mm256_setzero_si256();
    __m256 acc = _mm256_setzero_ps();

    for (int ib = 0; ib < nb; ++ib) {
        const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d);
        // Reinterpret the packed sign bits as 4 x 32-bit words, one per Q8_0
        // sub-block. NOTE(review): assumes x[ib].qs is 4-byte aligned — the
        // block layout appears to guarantee this, confirm against block_q1_0.
        const uint32_t * GGML_RESTRICT qs32 = (const uint32_t *) x[ib].qs;
        const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4];

        __m256 acc_block;
        // First sub-block initializes acc_block with mul (no fmadd on garbage).
        {
            const __m256i qy = _mm256_loadu_si256((const __m256i *) y_ptr[0].qs);
            // sm: all-ones in lanes whose sign bit is CLEAR (weight = -1).
            const __m256i sm = _mm256_cmpeq_epi8(
                _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs32[0]), byte_shuf), bit_masks), zero);
            // Conditional two's-complement negate: (qy ^ sm) - sm == -qy where
            // sm is all-ones, qy elsewhere.
            const __m256i sy = _mm256_sub_epi8(_mm256_xor_si256(qy, sm), sm);
            // maddubs with ones_8 as the unsigned operand just widens/pairs the
            // signed sy values (no saturation: pair sums fit int16); madd with
            // ones_16 reduces to 8 x int32 partial sums.
            const __m256i s32 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy), ones_16);
            acc_block = _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[0].d)), _mm256_cvtepi32_ps(s32));
        }
// Same as the block above for sub-block K, but accumulating into acc_block.
#define Q1_AVX2_BLOCK(K) \
        { \
            const __m256i qy = _mm256_loadu_si256((const __m256i *) y_ptr[K].qs); \
            const __m256i sm = _mm256_cmpeq_epi8( \
                _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs32[K]), byte_shuf), bit_masks), zero); \
            const __m256i sy = _mm256_sub_epi8(_mm256_xor_si256(qy, sm), sm); \
            const __m256i s32 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy), ones_16); \
            acc_block = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[K].d)), _mm256_cvtepi32_ps(s32), acc_block); \
        }
        Q1_AVX2_BLOCK(1)
        Q1_AVX2_BLOCK(2)
        Q1_AVX2_BLOCK(3)
#undef Q1_AVX2_BLOCK
        // The Q1_0 scale d0 is common to all 4 sub-blocks, applied once here.
        acc = _mm256_fmadd_ps(_mm256_set1_ps(d0), acc_block, acc);
    }

    *s = hsum_float_8(acc);
#elif defined(__AVX__)
    const __m128i ones_8 = _mm_set1_epi8(1);
    const __m128i ones_16 = _mm_set1_epi16(1);
    const __m128i zero = _mm_setzero_si128();
    __m256 acc = _mm256_setzero_ps();

    for (int ib = 0; ib < nb; ++ib) {
        const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d);
        const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4];
        __m256 acc_block;
// Sub-block K: bytes_from_bits_32 expands the 4 sign bytes at qs[K*4] into a
// 32-byte mask; work proceeds in two 128-bit halves (AVX has no 256-bit
// integer ops). sign_mask is all-ones where the bit is clear (weight -1);
// (qy ^ mask) - mask conditionally negates, then maddubs/madd reduce to int32.
#define Q1_AVX_BLOCK(K) \
        { \
            const __m256i bit_mask = bytes_from_bits_32(&x[ib].qs[(K) * 4]); \
            const __m128i bit_mask_0 = _mm256_castsi256_si128(bit_mask); \
            const __m128i bit_mask_1 = _mm256_extractf128_si256(bit_mask, 1); \
            const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[0]); \
            const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[16]); \
            const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); \
            const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); \
            const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); \
            const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); \
            const __m128i sum16_0 = _mm_maddubs_epi16(ones_8, sy_0); \
            const __m128i sum16_1 = _mm_maddubs_epi16(ones_8, sy_1); \
            const __m128i sum32_0 = _mm_madd_epi16(sum16_0, ones_16); \
            const __m128i sum32_1 = _mm_madd_epi16(sum16_1, ones_16); \
            const __m256 q = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_1, sum32_0)); \
            acc_block = _mm256_add_ps(acc_block, _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[(K)].d)), q)); \
        }
        // Sub-block 0 written out inline: it initializes acc_block with mul
        // instead of accumulating into an uninitialized value.
        {
            const __m256i bit_mask = bytes_from_bits_32(&x[ib].qs[0]);
            const __m128i bit_mask_0 = _mm256_castsi256_si128(bit_mask);
            const __m128i bit_mask_1 = _mm256_extractf128_si256(bit_mask, 1);
            const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[0].qs[0]);
            const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[0].qs[16]);
            const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero);
            const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero);
            const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0);
            const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1);
            const __m128i sum16_0 = _mm_maddubs_epi16(ones_8, sy_0);
            const __m128i sum16_1 = _mm_maddubs_epi16(ones_8, sy_1);
            const __m128i sum32_0 = _mm_madd_epi16(sum16_0, ones_16);
            const __m128i sum32_1 = _mm_madd_epi16(sum16_1, ones_16);
            const __m256 q = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_1, sum32_0));
            acc_block = _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[0].d)), q);
        }
        Q1_AVX_BLOCK(1)
        Q1_AVX_BLOCK(2)
        Q1_AVX_BLOCK(3)
#undef Q1_AVX_BLOCK

        // Apply the shared Q1_0 scale once per block (no FMA available in AVX).
        acc = _mm256_add_ps(acc, _mm256_mul_ps(_mm256_set1_ps(d0), acc_block));
    }

    *s = hsum_float_8(acc);
#elif defined(__SSSE3__)
    const __m128i ones_8 = _mm_set1_epi8(1);
    const __m128i ones_16 = _mm_set1_epi16(1);
    const __m128i zero = _mm_setzero_si128();
    // Four independent accumulators (one per Q8_0 sub-block position) to break
    // the dependency chain; combined at the end by hsum_float_4x4.
    __m128 acc_0 = _mm_setzero_ps();
    __m128 acc_1 = _mm_setzero_ps();
    __m128 acc_2 = _mm_setzero_ps();
    __m128 acc_3 = _mm_setzero_ps();

    for (int ib = 0; ib < nb; ++ib) {
        const __m128 d0 = _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d));
        const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4];

// Sub-block at byte offset QS_OFF of the sign bits, paired with Q8_0 block
// Y_IDX. The 32 bits are expanded 16 at a time (bytes_from_bits_16 yields
// 0xFF per set bit); cmpeq against zero flips that into the "-1 weight" mask,
// and (qy ^ mask) - mask conditionally negates qy. Unlike the AVX paths, both
// scales d0*d1 are folded in per sub-block here.
#define Q1_SSSE3_BLOCK(QS_OFF, Y_IDX, ACC) \
        { \
            const __m128i bit_mask_0 = bytes_from_bits_16(&x[ib].qs[(QS_OFF) + 0]); \
            const __m128i bit_mask_1 = bytes_from_bits_16(&x[ib].qs[(QS_OFF) + 2]); \
            const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[(Y_IDX)].qs[0]); \
            const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[(Y_IDX)].qs[16]); \
            const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); \
            const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); \
            const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); \
            const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); \
            const __m128i sum_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_0), ones_16); \
            const __m128i sum_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_1), ones_16); \
            const __m128 q = _mm_cvtepi32_ps(_mm_add_epi32(sum_0, sum_1)); \
            (ACC) = _mm_add_ps((ACC), _mm_mul_ps(_mm_mul_ps(d0, _mm_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[(Y_IDX)].d))), q)); \
        }
        Q1_SSSE3_BLOCK(0, 0, acc_0)
        Q1_SSSE3_BLOCK(4, 1, acc_1)
        Q1_SSSE3_BLOCK(8, 2, acc_2)
        Q1_SSSE3_BLOCK(12, 3, acc_3)
#undef Q1_SSSE3_BLOCK
    }

    *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
#else
    // No SIMD path available: fall back to the scalar reference implementation.
    UNUSED(nb);
    UNUSED(x);
    UNUSED(y);
    ggml_vec_dot_q1_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
709+
543710
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
544711
const int qk = QK8_0;
545712
const int nb = n / qk;

ggml/src/ggml-cpu/quants.c

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -137,22 +137,28 @@ void ggml_vec_dot_q1_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, c
137137
float sumf = 0.0;
138138

139139
for (int i = 0; i < nb; i++) {
140-
const float d0 = GGML_FP16_TO_FP32(x[i].d);
140+
const float d0 = GGML_CPU_FP16_TO_FP32(x[i].d);
141141

142142
float sumi = 0.0f;
143143

144144
for (int k = 0; k < 4; k++) {
145-
const float d1 = GGML_FP16_TO_FP32(y[i*4 + k].d);
146-
145+
const block_q8_0 * GGML_RESTRICT yb = &y[i * 4 + k];
146+
const float d1 = GGML_CPU_FP16_TO_FP32(yb->d);
147147
int sumi_block = 0;
148148

149-
for (int j = 0; j < QK8_0; j++) {
150-
const int bit_index = k * QK8_0 + j;
151-
const int byte_index = bit_index / 8;
152-
const int bit_offset = bit_index % 8;
153-
154-
const int xi = ((x[i].qs[byte_index] >> bit_offset) & 1) ? 1 : -1;
155-
sumi_block += xi * y[i*4 + k].qs[j];
149+
const uint8_t * GGML_RESTRICT bits = &x[i].qs[k * 4];
150+
const int8_t * GGML_RESTRICT qy = yb->qs;
151+
152+
for (int b = 0; b < 4; ++b, qy += 8) {
153+
const unsigned mask = bits[b];
154+
sumi_block += ((mask & 0x01) ? qy[0] : -qy[0])
155+
+ ((mask & 0x02) ? qy[1] : -qy[1])
156+
+ ((mask & 0x04) ? qy[2] : -qy[2])
157+
+ ((mask & 0x08) ? qy[3] : -qy[3])
158+
+ ((mask & 0x10) ? qy[4] : -qy[4])
159+
+ ((mask & 0x20) ? qy[5] : -qy[5])
160+
+ ((mask & 0x40) ? qy[6] : -qy[6])
161+
+ ((mask & 0x80) ? qy[7] : -qy[7]);
156162
}
157163

158164
sumi += d1 * sumi_block;

0 commit comments

Comments (0)