Skip to content

Commit 4c51309

Browse files
authored
ggml: vectorize ggml_vec_dot_q4_1_q8_1 with WASM SIMD128 (ggml-org#22209)
* ggml: vectorize ggml_vec_dot_q4_1_q8_1 with WASM SIMD128

  Optimize the inner loop of ggml_vec_dot_q4_1_q8_1_generic using WASM SIMD128 intrinsics, gated behind #ifdef __wasm_simd128__ so non-wasm builds are completely unaffected.

  Approach:
  - single wasm_v128_load covers all 32 packed 4-bit weights
  - nibbles unpacked via AND/SHR into two u8x16 registers
  - widened to i16 before multiply (WASM SIMD has no i8*i8 instruction)
  - 4x wasm_i32x4_dot_i16x8 calls accumulate all 32 element pairs
  - horizontal reduce via 4x wasm_i32x4_extract_lane

  Benchmark (node v25, emcc -O3 -msimd128, 64 blocks x QK8_1=32, 200k iterations):

  | impl   | ns/call | speedup |
  |--------|---------|---------|
  | scalar | 880.7   | 1.00x   |
  | simd   | 257.8   | 3.42x   |

  Correctness verified against scalar reference across 10 random seeds with exact output match.

* ggml: move q4_1_q8_1 WASM SIMD implementation to wasm backend

  Relocate the SIMD128 implementation of ggml_vec_dot_q4_1_q8_1 to ggml/src/ggml-cpu/arch/wasm/quants.c to follow architecture-specific layout. Restore the generic implementation in ggml/src/ggml-cpu/quants.c. Move for loop in the else block.

* ggml: use generic q4_1_q8_1 fallback in wasm backend
1 parent 6f3a9f3 commit 4c51309

1 file changed

Lines changed: 72 additions & 0 deletions

File tree

ggml/src/ggml-cpu/arch/wasm/quants.c

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,78 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
355355
*s = sumf;
356356
}
357357

358+
/*
 * Accumulates, over n / QK8_1 block pairs, the products of the 4-bit values
 * packed in the block_q4_1 array `vx` with the 8-bit values in the
 * block_q8_1 array `vy`, and writes the single float result to *s.
 *
 * For every block pair the integer sum of products is scaled by
 * x->d * y->d, and the term x->m * y->s is added on top.
 *
 *   n          element count; must be a multiple of QK8_1 (asserted)
 *   s          output location for the result
 *   bs, bx, by not used by the SIMD path; forwarded untouched to the
 *              generic implementation otherwise
 *   nrc        must be 1 (asserted)
 *
 * When __wasm_simd128__ is not defined the work is delegated to
 * ggml_vec_dot_q4_1_q8_1_generic.
 */
void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_1;
    const int nb = n / qk;

    assert(n % qk == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

#if defined __wasm_simd128__
    const block_q4_1 * GGML_RESTRICT xb = vx;
    const block_q8_1 * GGML_RESTRICT yb = vy;

    // Selects the low nibble of each packed byte.
    const v128_t nibble_mask = wasm_i8x16_splat(0x0F);

    v128_t scaled_acc = wasm_f32x4_splat(0.0f); // four partial sums of d-scaled dot products
    float  offset_acc = 0.0f;                   // running sum of x->m * y->s

    for (int ib = 0; ib < nb; ++ib, ++xb, ++yb) {
        offset_acc += GGML_CPU_FP16_TO_FP32(xb->m) * GGML_CPU_FP16_TO_FP32(yb->s);

        // One 128-bit load fetches 16 packed bytes; split them into the
        // low-nibble and high-nibble values (each in 0..15).
        const v128_t packed  = wasm_v128_load(xb->qs);
        const v128_t nib_low = wasm_v128_and(packed, nibble_mask);
        const v128_t nib_top = wasm_u8x16_shr(packed, 4);

        // Low nibbles pair with y->qs[0..15], high nibbles with y->qs[16..31].
        const v128_t q8_first  = wasm_v128_load(yb->qs);
        const v128_t q8_second = wasm_v128_load(yb->qs + 16);

        // Widen to 16-bit lanes (zero-extend the nibbles, sign-extend the
        // int8 values) so the i16x8 dot-product instruction can be used.
        const v128_t dot_low = wasm_i32x4_add(
            wasm_i32x4_dot_i16x8(wasm_u16x8_extend_low_u8x16(nib_low),
                                 wasm_i16x8_extend_low_i8x16(q8_first)),
            wasm_i32x4_dot_i16x8(wasm_u16x8_extend_high_u8x16(nib_low),
                                 wasm_i16x8_extend_high_i8x16(q8_first)));
        const v128_t dot_top = wasm_i32x4_add(
            wasm_i32x4_dot_i16x8(wasm_u16x8_extend_low_u8x16(nib_top),
                                 wasm_i16x8_extend_low_i8x16(q8_second)),
            wasm_i32x4_dot_i16x8(wasm_u16x8_extend_high_u8x16(nib_top),
                                 wasm_i16x8_extend_high_i8x16(q8_second)));
        const v128_t dot_i32 = wasm_i32x4_add(dot_low, dot_top);

        const float scale = GGML_CPU_FP16_TO_FP32(xb->d) * GGML_CPU_FP16_TO_FP32(yb->d);

        scaled_acc = wasm_f32x4_add(
            scaled_acc,
            wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dot_i32), wasm_f32x4_splat(scale)));
    }

    // Horizontal reduction of the four lanes, then add the offset term last.
    *s = wasm_f32x4_extract_lane(scaled_acc, 0) + wasm_f32x4_extract_lane(scaled_acc, 1) +
         wasm_f32x4_extract_lane(scaled_acc, 2) + wasm_f32x4_extract_lane(scaled_acc, 3) + offset_acc;
#else
    UNUSED(nb);

    ggml_vec_dot_q4_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
429+
358430
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
359431
const int qk = QK8_0;
360432
const int nb = n / qk;

0 commit comments

Comments
 (0)