@@ -219,6 +219,80 @@ void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
219219#endif
220220}
221221
222+ void ggml_vec_dot_q2_0_q8_0 (int n , float * GGML_RESTRICT s , size_t bs , const void * GGML_RESTRICT vx , size_t bx , const void * GGML_RESTRICT vy , size_t by , int nrc ) {
223+ const int qk = QK2_0 ;
224+ const int nb = n / qk ;
225+
226+ assert (n % qk == 0 );
227+ assert (nrc == 1 );
228+ UNUSED (nrc );
229+ UNUSED (bx );
230+ UNUSED (by );
231+ UNUSED (bs );
232+
233+ const block_q2_0 * GGML_RESTRICT x = vx ;
234+ const block_q8_0 * GGML_RESTRICT y = vy ;
235+
236+ float sumf = 0.0f ;
237+
238+ #if defined(__ARM_NEON )
239+ // Replicate pattern: each byte repeated 4 times
240+ static const uint8_t tbl_idx_lo [16 ] = {0 ,0 ,0 ,0 , 1 ,1 ,1 ,1 , 2 ,2 ,2 ,2 , 3 ,3 ,3 ,3 };
241+ static const uint8_t tbl_idx_hi [16 ] = {4 ,4 ,4 ,4 , 5 ,5 ,5 ,5 , 6 ,6 ,6 ,6 , 7 ,7 ,7 ,7 };
242+ // Right-shift amounts: 0,2,4,6 repeated for each group of 4
243+ static const int8_t shift_vals [16 ] = {0 ,-2 ,-4 ,-6 , 0 ,-2 ,-4 ,-6 , 0 ,-2 ,-4 ,-6 , 0 ,-2 ,-4 ,-6 };
244+
245+ const uint8x16_t idx_lo = vld1q_u8 (tbl_idx_lo );
246+ const uint8x16_t idx_hi = vld1q_u8 (tbl_idx_hi );
247+ const int8x16_t shifts = vld1q_s8 (shift_vals );
248+ const uint8x16_t mask2 = vdupq_n_u8 (0x03 );
249+ const int8x16_t one = vdupq_n_s8 (1 );
250+
251+ float32x4_t sumv = vdupq_n_f32 (0.0f );
252+
253+ for (int i = 0 ; i < nb ; i ++ ) {
254+ const float d0 = GGML_CPU_FP16_TO_FP32 (x [i ].d );
255+
256+ for (int k = 0 ; k < 4 ; k ++ ) {
257+ const block_q8_0 * GGML_RESTRICT yb = & y [i * 4 + k ];
258+ const float d1 = GGML_CPU_FP16_TO_FP32 (yb -> d );
259+
260+ // Load 8 bytes of packed 2-bit values
261+ const uint8x8_t raw = vld1_u8 (& x [i ].qs [k * 8 ]);
262+ const uint8x16_t raw16 = vcombine_u8 (raw , raw );
263+
264+ // First 16 elements: replicate bytes 0-3, shift, mask, subtract 1
265+ uint8x16_t bytes0 = vqtbl1q_u8 (raw16 , idx_lo );
266+ int8x16_t qv0 = vsubq_s8 (
267+ vreinterpretq_s8_u8 (vandq_u8 (vshlq_u8 (bytes0 , shifts ), mask2 )),
268+ one );
269+
270+ // Second 16 elements: replicate bytes 4-7, shift, mask, subtract 1
271+ uint8x16_t bytes1 = vqtbl1q_u8 (raw16 , idx_hi );
272+ int8x16_t qv1 = vsubq_s8 (
273+ vreinterpretq_s8_u8 (vandq_u8 (vshlq_u8 (bytes1 , shifts ), mask2 )),
274+ one );
275+
276+ // Load Q8_0 values and dot product
277+ const int8x16_t y0 = vld1q_s8 (yb -> qs );
278+ const int8x16_t y1 = vld1q_s8 (yb -> qs + 16 );
279+
280+ int32x4_t p0 = ggml_vdotq_s32 (vdupq_n_s32 (0 ), qv0 , y0 );
281+ int32x4_t p1 = ggml_vdotq_s32 (p0 , qv1 , y1 );
282+
283+ sumv = vmlaq_n_f32 (sumv , vcvtq_f32_s32 (p1 ), d0 * d1 );
284+ }
285+ }
286+
287+ sumf = vaddvq_f32 (sumv );
288+ #else
289+ ggml_vec_dot_q2_0_q8_0_generic (n , s , bs , vx , bx , vy , by , nrc );
290+ return ;
291+ #endif
292+
293+ * s = sumf ;
294+ }
295+
222296
223297void ggml_vec_dot_q4_0_q8_0 (int n , float * GGML_RESTRICT s , size_t bs , const void * GGML_RESTRICT vx , size_t bx , const void * GGML_RESTRICT vy , size_t by , int nrc ) {
224298 const int qk = QK8_0 ;
0 commit comments