@@ -783,6 +783,7 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
783783 const int8x16_t q4_lo_1 = ggml_vqtbl1q_s8 (values , vandq_u8 (q4bits_1 , m4b ));
784784 const int8x16_t q4_hi_1 = ggml_vqtbl1q_s8 (values , vshrq_n_u8 (q4bits_1 , 4 ));
785785
786+ #if defined(__ARM_FEATURE_DOTPROD )
786787 const int8x16_t q8_0a = vld1q_s8 (y [2 * ib ].qs );
787788 const int8x16_t q8_0b = vld1q_s8 (y [2 * ib ].qs + 16 );
788789 const int8x16_t q8_lo_0 = vcombine_s8 (vget_low_s8 (q8_0a ), vget_low_s8 (q8_0b ));
@@ -794,15 +795,40 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
794795 const int8x16_t q8_hi_1 = vcombine_s8 (vget_high_s8 (q8_1a ), vget_high_s8 (q8_1b ));
795796
796797 const int32x4_t p0 = vaddq_s32 (
797- ggml_vdotq_s32 (vdupq_n_s32 (0 ), q4_lo_0 , q8_lo_0 ),
798- ggml_vdotq_s32 (vdupq_n_s32 (0 ), q4_hi_0 , q8_hi_0 ));
798+ vdotq_s32 (vdupq_n_s32 (0 ), q4_lo_0 , q8_lo_0 ),
799+ vdotq_s32 (vdupq_n_s32 (0 ), q4_hi_0 , q8_hi_0 ));
799800 const int32x4_t p1 = vaddq_s32 (
800- ggml_vdotq_s32 (vdupq_n_s32 (0 ), q4_lo_1 , q8_lo_1 ),
801- ggml_vdotq_s32 (vdupq_n_s32 (0 ), q4_hi_1 , q8_hi_1 ));
801+ vdotq_s32 (vdupq_n_s32 (0 ), q4_lo_1 , q8_lo_1 ),
802+ vdotq_s32 (vdupq_n_s32 (0 ), q4_hi_1 , q8_hi_1 ));
802803
803- const int32x4_t sums = vpaddq_s32 (p0 , p1 );
804+ const int32x4_t sumi = vpaddq_s32 (p0 , p1 );
805+ #else
806+ const int8x8_t q4_0_lo = vget_low_s8 (q4_lo_0 );
807+ const int8x8_t q4_0_hi = vget_low_s8 (q4_hi_0 );
808+ const int8x8_t q4_1_lo = vget_high_s8 (q4_lo_0 );
809+ const int8x8_t q4_1_hi = vget_high_s8 (q4_hi_0 );
810+ const int8x8_t q4_2_lo = vget_low_s8 (q4_lo_1 );
811+ const int8x8_t q4_2_hi = vget_low_s8 (q4_hi_1 );
812+ const int8x8_t q4_3_lo = vget_high_s8 (q4_lo_1 );
813+ const int8x8_t q4_3_hi = vget_high_s8 (q4_hi_1 );
814+
815+ const int8x8_t q8_0_lo = vld1_s8 (y [2 * ib ].qs );
816+ const int8x8_t q8_0_hi = vld1_s8 (y [2 * ib ].qs + 8 );
817+ const int8x8_t q8_1_lo = vld1_s8 (y [2 * ib ].qs + 16 );
818+ const int8x8_t q8_1_hi = vld1_s8 (y [2 * ib ].qs + 24 );
819+ const int8x8_t q8_2_lo = vld1_s8 (y [2 * ib + 1 ].qs );
820+ const int8x8_t q8_2_hi = vld1_s8 (y [2 * ib + 1 ].qs + 8 );
821+ const int8x8_t q8_3_lo = vld1_s8 (y [2 * ib + 1 ].qs + 16 );
822+ const int8x8_t q8_3_hi = vld1_s8 (y [2 * ib + 1 ].qs + 24 );
823+
824+ const int32x4_t sumi = (int32x4_t ){
825+ vaddvq_s32 (ggml_nvfp4_dot8 (q4_0_lo , q8_0_lo , q4_0_hi , q8_0_hi )),
826+ vaddvq_s32 (ggml_nvfp4_dot8 (q4_1_lo , q8_1_lo , q4_1_hi , q8_1_hi )),
827+ vaddvq_s32 (ggml_nvfp4_dot8 (q4_2_lo , q8_2_lo , q4_2_hi , q8_2_hi )),
828+ vaddvq_s32 (ggml_nvfp4_dot8 (q4_3_lo , q8_3_lo , q4_3_hi , q8_3_hi )),
829+ };
830+ #endif
804831
805- // Decode 4 UE4M3 scales to f32 and multiply with q8 scales
806832 const float dy0 = GGML_CPU_FP16_TO_FP32 (y [2 * ib ].d );
807833 const float dy1 = GGML_CPU_FP16_TO_FP32 (y [2 * ib + 1 ].d );
808834 const float32x4_t nvsc = {
@@ -813,7 +839,7 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
813839 };
814840 const float32x4_t scales = vmulq_f32 (nvsc , (float32x4_t ){dy0 , dy0 , dy1 , dy1 });
815841
816- acc = vfmaq_f32 (acc , vcvtq_f32_s32 (sums ), scales );
842+ acc = vfmaq_f32 (acc , vcvtq_f32_s32 (sumi ), scales );
817843 }
818844 sumf = vaddvq_f32 (acc );
819845#else
0 commit comments