@@ -5023,6 +5023,71 @@ void ggml_gemm_q8_0_4x8_q8_0(int n,
50235023 UNUSED (ncols_interleaved);
50245024 UNUSED (blocklen);
50255025
5026+ #if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
5027+ if (svcntb () * 8 == 256 ) {
5028+ const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;
5029+
5030+ static const uint32_t idx_arr[8 ] = {0 , 1 , 4 , 5 , 2 , 3 , 6 , 7 };
5031+ svuint32_t idx = svld1 (svptrue_b32 (), idx_arr);
5032+ static const uint32_t idx_arr1[8 ] = {0 , 1 , 2 , 3 , 1 , 2 , 3 , 0 };
5033+ svuint32_t idx_sc1 = svld1 (svptrue_b32 (), idx_arr1);
5034+ static const uint32_t idx_arr2[8 ] = {0 , 1 , 2 , 3 , 0 , 1 , 2 , 3 };
5035+ svuint32_t idx_sc2 = svld1 (svptrue_b32 (), idx_arr2);
5036+
5037+ for (int y = 0 ; y < nr; y += 4 ) {
5038+ const block_q8_0x4 * a_ptr_base = (const block_q8_0x4 *) vy + (y / 4 ) * nb;
5039+
5040+ for (int x = 0 ; x < nc; x += ncols_interleaved) {
5041+ const block_q8_0x4 * b_ptr = b_ptr_base + (x / 4 ) * nb;
5042+ const block_q8_0x4 * a_ptr = a_ptr_base;
5043+
5044+ svfloat32_t acc_f32_01 = svdup_f32 (0 );
5045+ svfloat32_t acc_f32_23 = svdup_f32 (0 );
5046+
5047+ for (int b = 0 ; b < nb; b++) {
5048+
5049+ svint32_t acc_01 = svdup_s32 (0 );
5050+ svint32_t acc_23 = svdup_s32 (0 );
5051+
5052+ // Process 4 chunks of 8 positions each
5053+ for (int chunk = 0 ; chunk < 4 ; chunk++) {
5054+ svint8_t s_a01 = svld1rq_s8 (svptrue_b8 (), a_ptr->qs + chunk * 32 );
5055+ svint8_t s_a23 = svld1rq_s8 (svptrue_b8 (), a_ptr->qs + chunk * 32 + 16 );
5056+ svint8_t s_b0123 = svld1_s8 (svptrue_b8 (), b_ptr->qs + chunk * 32 );
5057+
5058+ acc_01 = svmmla_s32 (acc_01, s_a01, s_b0123);
5059+ acc_23 = svmmla_s32 (acc_23, s_a23, s_b0123);
5060+ }
5061+
5062+ // Reorder outputs from 2×2 tiles to row-major
5063+ // acc[01] = [r0c0, r0c1, r1c0, r1c1, r0c2, r0c3, r1c2, r1c3]
5064+ // acc[23] = [r2c0, r2c1, r3c0, r3c1, r2c2, r2c3, r3c2, r3c3]
5065+
5066+ svint32_t row01 = svtbl_s32 (acc_01, idx);
5067+ svint32_t row23 = svtbl_s32 (acc_23, idx);
5068+
5069+ svfloat16_t temp1 = svld1_f16 (svptrue_pat_b16 (SV_VL4), (const __fp16 *) a_ptr->d );
5070+ svfloat16_t temp2 = svld1_f16 (svptrue_pat_b16 (SV_VL4), (const __fp16 *) b_ptr->d );
5071+ svfloat32_t sv_a_d = svtbl_f32 (svcvt_f32_f16_x (svptrue_b32 (), svzip1_f16 (temp1, temp1)), idx_sc1);
5072+ svfloat32_t sv_b_d = svtbl_f32 (svcvt_f32_f16_x (svptrue_b32 (), svzip1_f16 (temp2, temp2)), idx_sc2);
5073+
5074+ acc_f32_01 = svmla_f32_x (svptrue_b32 (), acc_f32_01, svcvt_f32_s32_x (svptrue_b32 (), row01), svmul_lane_f32 (sv_b_d, sv_a_d, 0 ));
5075+ acc_f32_23 = svmla_f32_x (svptrue_b32 (), acc_f32_23, svcvt_f32_s32_x (svptrue_b32 (), row23), svmul_lane_f32 (sv_b_d, sv_a_d, 2 ));
5076+ a_ptr++;
5077+ b_ptr++;
5078+ }
5079+
5080+ svbool_t pg4 = svptrue_pat_b32 (SV_VL4);
5081+ svst1_f32 (pg4, s + (y+0 ) * bs + x, acc_f32_01);
5082+ svst1_f32 (pg4, s + (y+1 ) * bs + x, svext_f32 (acc_f32_01, acc_f32_01, 4 ));
5083+ svst1_f32 (pg4, s + (y+2 ) * bs + x, acc_f32_23);
5084+ svst1_f32 (pg4, s + (y+3 ) * bs + x, svext_f32 (acc_f32_23, acc_f32_23, 4 ));
5085+ }
5086+ }
5087+ return ;
5088+ }
5089+ #endif // SVE compile-time end
5090+
50265091#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
50275092 const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;
50285093
0 commit comments