Skip to content

Commit 2e05f06

Browse files
authored
ggml : fix ARM NEON nvfp4 dot product on non-dotprod targets (ggml-org#21559)
1 parent acc37a4 commit 2e05f06

2 files changed

Lines changed: 43 additions & 7 deletions

File tree

ggml/src/ggml-cpu/arch/arm/quants.c

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -783,6 +783,7 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
783783
const int8x16_t q4_lo_1 = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits_1, m4b));
784784
const int8x16_t q4_hi_1 = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits_1, 4));
785785

786+
#if defined(__ARM_FEATURE_DOTPROD)
786787
const int8x16_t q8_0a = vld1q_s8(y[2*ib].qs);
787788
const int8x16_t q8_0b = vld1q_s8(y[2*ib].qs + 16);
788789
const int8x16_t q8_lo_0 = vcombine_s8(vget_low_s8(q8_0a), vget_low_s8(q8_0b));
@@ -794,15 +795,40 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
794795
const int8x16_t q8_hi_1 = vcombine_s8(vget_high_s8(q8_1a), vget_high_s8(q8_1b));
795796

796797
const int32x4_t p0 = vaddq_s32(
797-
ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_0, q8_lo_0),
798-
ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_0, q8_hi_0));
798+
vdotq_s32(vdupq_n_s32(0), q4_lo_0, q8_lo_0),
799+
vdotq_s32(vdupq_n_s32(0), q4_hi_0, q8_hi_0));
799800
const int32x4_t p1 = vaddq_s32(
800-
ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_1, q8_lo_1),
801-
ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_1, q8_hi_1));
801+
vdotq_s32(vdupq_n_s32(0), q4_lo_1, q8_lo_1),
802+
vdotq_s32(vdupq_n_s32(0), q4_hi_1, q8_hi_1));
802803

803-
const int32x4_t sums = vpaddq_s32(p0, p1);
804+
const int32x4_t sumi = vpaddq_s32(p0, p1);
805+
#else
806+
const int8x8_t q4_0_lo = vget_low_s8(q4_lo_0);
807+
const int8x8_t q4_0_hi = vget_low_s8(q4_hi_0);
808+
const int8x8_t q4_1_lo = vget_high_s8(q4_lo_0);
809+
const int8x8_t q4_1_hi = vget_high_s8(q4_hi_0);
810+
const int8x8_t q4_2_lo = vget_low_s8(q4_lo_1);
811+
const int8x8_t q4_2_hi = vget_low_s8(q4_hi_1);
812+
const int8x8_t q4_3_lo = vget_high_s8(q4_lo_1);
813+
const int8x8_t q4_3_hi = vget_high_s8(q4_hi_1);
814+
815+
const int8x8_t q8_0_lo = vld1_s8(y[2*ib].qs);
816+
const int8x8_t q8_0_hi = vld1_s8(y[2*ib].qs + 8);
817+
const int8x8_t q8_1_lo = vld1_s8(y[2*ib].qs + 16);
818+
const int8x8_t q8_1_hi = vld1_s8(y[2*ib].qs + 24);
819+
const int8x8_t q8_2_lo = vld1_s8(y[2*ib+1].qs);
820+
const int8x8_t q8_2_hi = vld1_s8(y[2*ib+1].qs + 8);
821+
const int8x8_t q8_3_lo = vld1_s8(y[2*ib+1].qs + 16);
822+
const int8x8_t q8_3_hi = vld1_s8(y[2*ib+1].qs + 24);
823+
824+
const int32x4_t sumi = (int32x4_t){
825+
vaddvq_s32(ggml_nvfp4_dot8(q4_0_lo, q8_0_lo, q4_0_hi, q8_0_hi)),
826+
vaddvq_s32(ggml_nvfp4_dot8(q4_1_lo, q8_1_lo, q4_1_hi, q8_1_hi)),
827+
vaddvq_s32(ggml_nvfp4_dot8(q4_2_lo, q8_2_lo, q4_2_hi, q8_2_hi)),
828+
vaddvq_s32(ggml_nvfp4_dot8(q4_3_lo, q8_3_lo, q4_3_hi, q8_3_hi)),
829+
};
830+
#endif
804831

805-
// Decode 4 UE4M3 scales to f32 and multiply with q8 scales
806832
const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d);
807833
const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d);
808834
const float32x4_t nvsc = {
@@ -813,7 +839,7 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
813839
};
814840
const float32x4_t scales = vmulq_f32(nvsc, (float32x4_t){dy0, dy0, dy1, dy1});
815841

816-
acc = vfmaq_f32(acc, vcvtq_f32_s32(sums), scales);
842+
acc = vfmaq_f32(acc, vcvtq_f32_s32(sumi), scales);
817843
}
818844
sumf = vaddvq_f32(acc);
819845
#else

ggml/src/ggml-cpu/ggml-cpu-impl.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
306306

307307
#if !defined(__ARM_FEATURE_DOTPROD)
308308

309+
// NOTE: this fallback produces the same total sum as native vdotq_s32 but with different per-lane grouping — do not use when individual lane values matter.
309310
inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
310311
const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
311312
const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
@@ -319,6 +320,15 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)
319320

320321
#endif // !defined(__ARM_FEATURE_DOTPROD)
321322

323+
// Partial dot product of two 8-element int8 vector pairs.
//
// Computes the element-wise products q4_lo*q8_lo and q4_hi*q8_hi, folds adjacent
// products together, and returns the two folded results added lane by lane.
// Lane j of the result therefore holds
//   q4_lo[2j]*q8_lo[2j] + q4_lo[2j+1]*q8_lo[2j+1] +
//   q4_hi[2j]*q8_hi[2j] + q4_hi[2j+1]*q8_hi[2j+1].
// The result is NOT reduced to a scalar; the caller must sum the four lanes
// (e.g. with vaddvq_s32) to obtain the full 16-term dot product.
static inline int32x4_t ggml_nvfp4_dot8(const int8x8_t q4_lo, const int8x8_t q8_lo,
                                        const int8x8_t q4_hi, const int8x8_t q8_hi) {
    // vmull_s8 widens each product to 16 bits (an int8 x int8 product always fits),
    // vpaddlq_s16 then adds neighbouring products into 32-bit lanes.
    const int32x4_t pairs_lo = vpaddlq_s16(vmull_s8(q4_lo, q8_lo));
    const int32x4_t pairs_hi = vpaddlq_s16(vmull_s8(q4_hi, q8_hi));

    return vaddq_s32(pairs_lo, pairs_hi);
}
331+
322332
#endif // defined(__ARM_NEON)
323333

324334
#ifdef __wasm_simd128__

0 commit comments

Comments
 (0)