@@ -137,6 +137,109 @@ void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
137137
138138//===================================== Dot products =================================
139139
140+ void ggml_vec_dot_q1_0_q8_0 (int n , float * GGML_RESTRICT s , size_t bs , const void * GGML_RESTRICT vx , size_t bx , const void * GGML_RESTRICT vy , size_t by , int nrc ) {
141+ const int qk = QK1_0 ; // 128
142+ const int nb = n / qk ;
143+
144+ assert (n % qk == 0 );
145+ assert (nrc == 1 );
146+ UNUSED (nrc );
147+ UNUSED (bx );
148+ UNUSED (by );
149+ UNUSED (bs );
150+
151+ const block_q1_0 * GGML_RESTRICT x = vx ;
152+ const block_q8_0 * GGML_RESTRICT y = vy ;
153+
154+ float sumf = 0.0f ;
155+
156+ #if defined(__ARM_NEON )
157+ float32x4_t sumv = vdupq_n_f32 (0.0f );
158+
159+ for (int i = 0 ; i < nb ; i ++ ) {
160+ const float d0 = GGML_CPU_FP16_TO_FP32 (x [i ].d );
161+
162+ // Process 4 Q8_0 blocks (each has 32 elements)
163+ for (int k = 0 ; k < 4 ; k ++ ) {
164+ const block_q8_0 * GGML_RESTRICT yb = & y [i * 4 + k ];
165+ const float d1 = GGML_CPU_FP16_TO_FP32 (yb -> d );
166+
167+ // Get the 4 bytes of bits for this Q8_0 block (32 bits = 4 bytes)
168+ // Bits are at offset k*4 bytes in x[i].qs
169+ const uint8_t * bits = & x [i ].qs [k * 4 ];
170+
171+ // Load 32 int8 values from y
172+ const int8x16_t y0 = vld1q_s8 (yb -> qs );
173+ const int8x16_t y1 = vld1q_s8 (yb -> qs + 16 );
174+
175+ // Byte 0-1: bits for y0[0..15]
176+ const uint64_t expand0 = table_b2b_0 [bits [0 ]];
177+ const uint64_t expand1 = table_b2b_0 [bits [1 ]];
178+ // Byte 2-3: bits for y1[0..15]
179+ const uint64_t expand2 = table_b2b_0 [bits [2 ]];
180+ const uint64_t expand3 = table_b2b_0 [bits [3 ]];
181+
182+ // Build the sign vectors by reinterpreting the table values
183+ uint8x8_t e0 = vcreate_u8 (expand0 );
184+ uint8x8_t e1 = vcreate_u8 (expand1 );
185+ uint8x8_t e2 = vcreate_u8 (expand2 );
186+ uint8x8_t e3 = vcreate_u8 (expand3 );
187+
188+ // Shift right by 4 to get 0 or 1
189+ int8x8_t s0 = vreinterpret_s8_u8 (vshr_n_u8 (e0 , 4 ));
190+ int8x8_t s1 = vreinterpret_s8_u8 (vshr_n_u8 (e1 , 4 ));
191+ int8x8_t s2 = vreinterpret_s8_u8 (vshr_n_u8 (e2 , 4 ));
192+ int8x8_t s3 = vreinterpret_s8_u8 (vshr_n_u8 (e3 , 4 ));
193+
194+ // Convert 0/1 to -1/+1: sign = 2*val - 1
195+ int8x8_t one = vdup_n_s8 (1 );
196+ s0 = vsub_s8 (vadd_s8 (s0 , s0 ), one ); // 2*s0 - 1
197+ s1 = vsub_s8 (vadd_s8 (s1 , s1 ), one );
198+ s2 = vsub_s8 (vadd_s8 (s2 , s2 ), one );
199+ s3 = vsub_s8 (vadd_s8 (s3 , s3 ), one );
200+
201+ // Combine into 16-element vectors
202+ int8x16_t signs0 = vcombine_s8 (s0 , s1 );
203+ int8x16_t signs1 = vcombine_s8 (s2 , s3 );
204+
205+ // Multiply signs with y values and accumulate
206+ // dot(signs, y) where signs are +1/-1
207+ int32x4_t p0 = ggml_vdotq_s32 (vdupq_n_s32 (0 ), signs0 , y0 );
208+ int32x4_t p1 = ggml_vdotq_s32 (p0 , signs1 , y1 );
209+
210+ // Scale by d1 and accumulate
211+ sumv = vmlaq_n_f32 (sumv , vcvtq_f32_s32 (p1 ), d0 * d1 );
212+ }
213+ }
214+
215+ sumf = vaddvq_f32 (sumv );
216+ #else
217+ // Scalar fallback
218+ for (int i = 0 ; i < nb ; i ++ ) {
219+ const float d0 = GGML_FP16_TO_FP32 (x [i ].d );
220+
221+ // Process 4 Q8_0 blocks
222+ for (int k = 0 ; k < 4 ; k ++ ) {
223+ const float d1 = GGML_FP16_TO_FP32 (y [i * 4 + k ].d );
224+
225+ int sumi = 0 ;
226+ for (int j = 0 ; j < QK8_0 ; j ++ ) {
227+ const int bit_index = k * QK8_0 + j ;
228+ const int byte_index = bit_index / 8 ;
229+ const int bit_offset = bit_index % 8 ;
230+
231+ const int xi = ((x [i ].qs [byte_index ] >> bit_offset ) & 1 ) ? 1 : -1 ;
232+ sumi += xi * y [i * 4 + k ].qs [j ];
233+ }
234+ sumf += d0 * d1 * sumi ;
235+ }
236+ }
237+ #endif
238+
239+ * s = sumf ;
240+ }
241+
242+
140243void ggml_vec_dot_q4_0_q8_0 (int n , float * GGML_RESTRICT s , size_t bs , const void * GGML_RESTRICT vx , size_t bx , const void * GGML_RESTRICT vy , size_t by , int nrc ) {
141244 const int qk = QK8_0 ;
142245 const int nb = n / qk ;
0 commit comments