Skip to content

Commit bdc9c74

Browse files
hrushitfujitsuVithulep
andauthored
ggml : add sve tuned code for gemm_q8_0_4x8_q8_0() kernel (#21916)
* Added sve tuned code for gemm_q8_0_4x8_q8_0() kernel * Change arrays to static const in repack.cpp --------- Co-authored-by: Vithulep <prashant.vithule@fujitsu.com>
1 parent 739393b commit bdc9c74

1 file changed

Lines changed: 65 additions & 0 deletions

File tree

ggml/src/ggml-cpu/arch/arm/repack.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5023,6 +5023,71 @@ void ggml_gemm_q8_0_4x8_q8_0(int n,
50235023
UNUSED(ncols_interleaved);
50245024
UNUSED(blocklen);
50255025

5026+
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
5027+
if (svcntb() * 8 == 256) {
5028+
const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;
5029+
5030+
static const uint32_t idx_arr[8] = {0, 1, 4, 5, 2, 3, 6, 7};
5031+
svuint32_t idx = svld1(svptrue_b32(), idx_arr);
5032+
static const uint32_t idx_arr1[8] = {0, 1, 2, 3, 1, 2, 3, 0};
5033+
svuint32_t idx_sc1 = svld1(svptrue_b32(), idx_arr1);
5034+
static const uint32_t idx_arr2[8] = {0, 1, 2, 3, 0, 1, 2, 3};
5035+
svuint32_t idx_sc2 = svld1(svptrue_b32(), idx_arr2);
5036+
5037+
for (int y = 0; y < nr; y += 4) {
5038+
const block_q8_0x4 * a_ptr_base = (const block_q8_0x4 *) vy + (y / 4) * nb;
5039+
5040+
for (int x = 0; x < nc; x += ncols_interleaved) {
5041+
const block_q8_0x4 * b_ptr = b_ptr_base + (x / 4) * nb;
5042+
const block_q8_0x4 * a_ptr = a_ptr_base;
5043+
5044+
svfloat32_t acc_f32_01 = svdup_f32(0);
5045+
svfloat32_t acc_f32_23 = svdup_f32(0);
5046+
5047+
for (int b = 0; b < nb; b++) {
5048+
5049+
svint32_t acc_01 = svdup_s32(0);
5050+
svint32_t acc_23 = svdup_s32(0);
5051+
5052+
// Process 4 chunks of 8 positions each
5053+
for (int chunk = 0; chunk < 4; chunk++) {
5054+
svint8_t s_a01 = svld1rq_s8(svptrue_b8(), a_ptr->qs + chunk * 32);
5055+
svint8_t s_a23 = svld1rq_s8(svptrue_b8(), a_ptr->qs + chunk * 32 + 16);
5056+
svint8_t s_b0123 = svld1_s8(svptrue_b8(), b_ptr->qs + chunk * 32);
5057+
5058+
acc_01 = svmmla_s32(acc_01, s_a01, s_b0123);
5059+
acc_23 = svmmla_s32(acc_23, s_a23, s_b0123);
5060+
}
5061+
5062+
// Reorder outputs from 2×2 tiles to row-major
5063+
// acc[01] = [r0c0, r0c1, r1c0, r1c1, r0c2, r0c3, r1c2, r1c3]
5064+
// acc[23] = [r2c0, r2c1, r3c0, r3c1, r2c2, r2c3, r3c2, r3c3]
5065+
5066+
svint32_t row01 = svtbl_s32(acc_01, idx);
5067+
svint32_t row23 = svtbl_s32(acc_23, idx);
5068+
5069+
svfloat16_t temp1 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *) a_ptr->d);
5070+
svfloat16_t temp2 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *) b_ptr->d);
5071+
svfloat32_t sv_a_d = svtbl_f32(svcvt_f32_f16_x(svptrue_b32(), svzip1_f16(temp1, temp1)), idx_sc1);
5072+
svfloat32_t sv_b_d = svtbl_f32(svcvt_f32_f16_x(svptrue_b32(), svzip1_f16(temp2, temp2)), idx_sc2);
5073+
5074+
acc_f32_01 = svmla_f32_x(svptrue_b32(), acc_f32_01, svcvt_f32_s32_x(svptrue_b32(), row01), svmul_lane_f32(sv_b_d, sv_a_d, 0));
5075+
acc_f32_23 = svmla_f32_x(svptrue_b32(), acc_f32_23, svcvt_f32_s32_x(svptrue_b32(), row23), svmul_lane_f32(sv_b_d, sv_a_d, 2));
5076+
a_ptr++;
5077+
b_ptr++;
5078+
}
5079+
5080+
svbool_t pg4 = svptrue_pat_b32(SV_VL4);
5081+
svst1_f32(pg4, s + (y+0) * bs + x, acc_f32_01);
5082+
svst1_f32(pg4, s + (y+1) * bs + x, svext_f32(acc_f32_01, acc_f32_01, 4));
5083+
svst1_f32(pg4, s + (y+2) * bs + x, acc_f32_23);
5084+
svst1_f32(pg4, s + (y+3) * bs + x, svext_f32(acc_f32_23, acc_f32_23, 4));
5085+
}
5086+
}
5087+
return;
5088+
}
5089+
#endif // SVE compile-time end
5090+
50265091
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
50275092
const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;
50285093

0 commit comments

Comments
 (0)