@@ -221,6 +221,24 @@ static void ggml_vec_dot_tq3_1s_q8_0(int n, float * GGML_RESTRICT s, size_t bs,
221221static void ggml_vec_dot_tq4_1s_q8_0 (int n , float * GGML_RESTRICT s , size_t bs ,
222222 const void * GGML_RESTRICT vx , size_t bx ,
223223 const void * GGML_RESTRICT vy , size_t by , int nrc );
// Forward declarations of the SQ* x F32 dot-product kernels. They are file-local
// (static), are referenced from the type_traits_cpu table below, and are defined
// later in this file.
224+ static void ggml_vec_dot_sq2_0_f32 (int n , float * GGML_RESTRICT s , size_t bs ,
225+ const void * GGML_RESTRICT vx , size_t bx ,
226+ const void * GGML_RESTRICT vy , size_t by , int nrc );
227+ static void ggml_vec_dot_sq3_1s_f32 (int n , float * GGML_RESTRICT s , size_t bs ,
228+ const void * GGML_RESTRICT vx , size_t bx ,
229+ const void * GGML_RESTRICT vy , size_t by , int nrc );
230+ static void ggml_vec_dot_sq4_1s_f32 (int n , float * GGML_RESTRICT s , size_t bs ,
231+ const void * GGML_RESTRICT vx , size_t bx ,
232+ const void * GGML_RESTRICT vy , size_t by , int nrc );
// Forward declarations of the SKV* x F32 dot-product kernels, also referenced from
// the type_traits_cpu table below.
// NOTE(review): these are declared without `static` and no definition is visible in
// this part of the file - presumably they are provided by another translation unit;
// confirm, and consider moving the prototypes to a shared header if so.
233+ void ggml_vec_dot_skv2_0_f32 (int n , float * GGML_RESTRICT s , size_t bs ,
234+ const void * GGML_RESTRICT vx , size_t bx ,
235+ const void * GGML_RESTRICT vy , size_t by , int nrc );
236+ void ggml_vec_dot_skv3_0_f32 (int n , float * GGML_RESTRICT s , size_t bs ,
237+ const void * GGML_RESTRICT vx , size_t bx ,
238+ const void * GGML_RESTRICT vy , size_t by , int nrc );
239+ void ggml_vec_dot_skv4_0_f32 (int n , float * GGML_RESTRICT s , size_t bs ,
240+ const void * GGML_RESTRICT vx , size_t bx ,
241+ const void * GGML_RESTRICT vy , size_t by , int nrc );
224242
225243static const struct ggml_type_traits_cpu type_traits_cpu [GGML_TYPE_COUNT ] = {
226244 [GGML_TYPE_F32 ] = {
@@ -441,6 +459,42 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
441459 .vec_dot_type = GGML_TYPE_Q8_0 ,
442460 .nrows = 1 ,
443461 },
462+ [GGML_TYPE_SQ2_0 ] = {
463+ .from_float = NULL ,
464+ .vec_dot = (ggml_vec_dot_t ) ggml_vec_dot_sq2_0_f32 ,
465+ .vec_dot_type = GGML_TYPE_F32 ,
466+ .nrows = 1 ,
467+ },
468+ [GGML_TYPE_SQ3_1S ] = {
469+ .from_float = NULL ,
470+ .vec_dot = (ggml_vec_dot_t ) ggml_vec_dot_sq3_1s_f32 ,
471+ .vec_dot_type = GGML_TYPE_F32 ,
472+ .nrows = 1 ,
473+ },
474+ [GGML_TYPE_SQ4_1S ] = {
475+ .from_float = NULL ,
476+ .vec_dot = (ggml_vec_dot_t ) ggml_vec_dot_sq4_1s_f32 ,
477+ .vec_dot_type = GGML_TYPE_F32 ,
478+ .nrows = 1 ,
479+ },
480+ [GGML_TYPE_SKV2_0 ] = {
481+ .from_float = (ggml_from_float_t ) quantize_row_skv2_0_ref ,
482+ .vec_dot = (ggml_vec_dot_t ) ggml_vec_dot_skv2_0_f32 ,
483+ .vec_dot_type = GGML_TYPE_F32 ,
484+ .nrows = 1 ,
485+ },
486+ [GGML_TYPE_SKV3_0 ] = {
487+ .from_float = (ggml_from_float_t ) quantize_row_skv3_0_ref ,
488+ .vec_dot = (ggml_vec_dot_t ) ggml_vec_dot_skv3_0_f32 ,
489+ .vec_dot_type = GGML_TYPE_F32 ,
490+ .nrows = 1 ,
491+ },
492+ [GGML_TYPE_SKV4_0 ] = {
493+ .from_float = (ggml_from_float_t ) quantize_row_skv4_0_ref ,
494+ .vec_dot = (ggml_vec_dot_t ) ggml_vec_dot_skv4_0_f32 ,
495+ .vec_dot_type = GGML_TYPE_F32 ,
496+ .nrows = 1 ,
497+ },
444498};
445499
446500const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu (enum ggml_type type ) {
@@ -1296,6 +1350,10 @@ void ggml_compute_forward_mul_mat(
12961350 ggml_from_float_t const from_float = type_traits_cpu [vec_dot_type ].from_float ;
12971351 int64_t const vec_dot_num_rows = type_traits_cpu [src0 -> type ].nrows ;
12981352
1353+ if (ggml_is_spectral_weight_type (src0 -> type )) {
1354+ ggml_sq_vec_cache_reset ();
1355+ }
1356+
12991357 GGML_ASSERT (ne0 == ne01 );
13001358 GGML_ASSERT (ne1 == ne11 );
13011359 GGML_ASSERT (ne2 == ne12 );
@@ -1574,6 +1632,10 @@ static void ggml_compute_forward_mul_mat_id(
15741632 enum ggml_type const vec_dot_type = type_traits_cpu [type ].vec_dot_type ;
15751633 ggml_from_float_t const from_float = type_traits_cpu [vec_dot_type ].from_float ;
15761634
1635+ if (ggml_is_spectral_weight_type (type )) {
1636+ ggml_sq_vec_cache_reset ();
1637+ }
1638+
15771639 // we don't support permuted src0 or src1
15781640 GGML_ASSERT (nb00 == ggml_type_size (type ));
15791641 GGML_ASSERT (nb10 == ggml_type_size (src1 -> type ));
@@ -3483,6 +3545,75 @@ static void ggml_vec_dot_tq4_1s_q8_0(int n, float * GGML_RESTRICT s, size_t bs,
34833545 * s = sum ;
34843546}
34853547
3548+ static void ggml_vec_dot_sq2_0_f32 (int n , float * GGML_RESTRICT s , size_t bs ,
3549+ const void * GGML_RESTRICT vx , size_t bx ,
3550+ const void * GGML_RESTRICT vy , size_t by , int nrc ) {
3551+ GGML_ASSERT (nrc == 1 );
3552+ GGML_UNUSED (bs ); GGML_UNUSED (bx ); GGML_UNUSED (by ); GGML_UNUSED (nrc );
3553+
3554+ if (ggml_sq_vec_dot_f32 (GGML_TYPE_SQ2_0 , vx , (const float * ) vy , n , s )) {
3555+ return ;
3556+ }
3557+
3558+ float * tmp = (float * ) malloc ((size_t ) n * sizeof (float ));
3559+ GGML_ASSERT (tmp != NULL );
3560+ ggml_get_type_traits (GGML_TYPE_SQ2_0 )-> to_float (vx , tmp , n );
3561+
3562+ const float * y = (const float * ) vy ;
3563+ float sum = 0.0f ;
3564+ for (int i = 0 ; i < n ; ++ i ) {
3565+ sum += tmp [i ] * y [i ];
3566+ }
3567+ free (tmp );
3568+ * s = sum ;
3569+ }
3570+
3571+ static void ggml_vec_dot_sq3_1s_f32 (int n , float * GGML_RESTRICT s , size_t bs ,
3572+ const void * GGML_RESTRICT vx , size_t bx ,
3573+ const void * GGML_RESTRICT vy , size_t by , int nrc ) {
3574+ GGML_ASSERT (nrc == 1 );
3575+ GGML_UNUSED (bs ); GGML_UNUSED (bx ); GGML_UNUSED (by ); GGML_UNUSED (nrc );
3576+
3577+ if (ggml_sq_vec_dot_f32 (GGML_TYPE_SQ3_1S , vx , (const float * ) vy , n , s )) {
3578+ return ;
3579+ }
3580+
3581+ float * tmp = (float * ) malloc ((size_t ) n * sizeof (float ));
3582+ GGML_ASSERT (tmp != NULL );
3583+ ggml_get_type_traits (GGML_TYPE_SQ3_1S )-> to_float (vx , tmp , n );
3584+
3585+ const float * y = (const float * ) vy ;
3586+ float sum = 0.0f ;
3587+ for (int i = 0 ; i < n ; ++ i ) {
3588+ sum += tmp [i ] * y [i ];
3589+ }
3590+ free (tmp );
3591+ * s = sum ;
3592+ }
3593+
3594+ static void ggml_vec_dot_sq4_1s_f32 (int n , float * GGML_RESTRICT s , size_t bs ,
3595+ const void * GGML_RESTRICT vx , size_t bx ,
3596+ const void * GGML_RESTRICT vy , size_t by , int nrc ) {
3597+ GGML_ASSERT (nrc == 1 );
3598+ GGML_UNUSED (bs ); GGML_UNUSED (bx ); GGML_UNUSED (by ); GGML_UNUSED (nrc );
3599+
3600+ if (ggml_sq_vec_dot_f32 (GGML_TYPE_SQ4_1S , vx , (const float * ) vy , n , s )) {
3601+ return ;
3602+ }
3603+
3604+ float * tmp = (float * ) malloc ((size_t ) n * sizeof (float ));
3605+ GGML_ASSERT (tmp != NULL );
3606+ ggml_get_type_traits (GGML_TYPE_SQ4_1S )-> to_float (vx , tmp , n );
3607+
3608+ const float * y = (const float * ) vy ;
3609+ float sum = 0.0f ;
3610+ for (int i = 0 ; i < n ; ++ i ) {
3611+ sum += tmp [i ] * y [i ];
3612+ }
3613+ free (tmp );
3614+ * s = sum ;
3615+ }
3616+
// Copies n floats from x to y. Implemented with memcpy, so the two ranges
// must not overlap.
void ggml_cpu_fp32_to_fp32(const float * x, float * y, int64_t n) {
    const size_t nbytes = (size_t) n * sizeof(float);
    memcpy(y, x, nbytes);
}
0 commit comments