diff --git a/kernel/arm64/sbgemv_n_neon.c b/kernel/arm64/sbgemv_n_neon.c index 489d4d22cb..ff730407fd 100644 --- a/kernel/arm64/sbgemv_n_neon.c +++ b/kernel/arm64/sbgemv_n_neon.c @@ -69,12 +69,8 @@ static void beta_op(float *x, BLASLONG n, FLOAT beta) { x += 4; } - if (rest_n & 3) { - x[0] *= beta; - if ((rest_n & 3) > 1) - x[1] *= beta; - if ((rest_n & 3) > 2) - x[2] *= beta; + for (BLASLONG i = 0; i < (rest_n & 3); i ++) { + x[i] *= beta; } } return; @@ -88,7 +84,10 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, bfloat16x8_t a0, a1, a2, a3, a4, a5, a6, a7; bfloat16x8_t t0, t1, t2, t3, t4, t5, t6, t7; + bfloat16x8_t x_vec; + bfloat16x4_t x_vecx4; + float32x4_t y1_vec, y2_vec; float32x4_t fp32_low, fp32_high; @@ -106,7 +105,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, if (incx == 1 && incy == 1) { if (beta != 1) { - beta_op(y, n, beta); + beta_op(y, m, beta); } for (i = 0; i < n / 8; i++) { @@ -290,12 +289,9 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, a_ptr += 4 * lda; - bfloat16x4_t x_vecx4 = vld1_bf16(x_ptr); + x_vecx4 = vld1_bf16(x_ptr); if (alpha != 1) { - x_vec = vcombine_bf16(x_vecx4, bf16_zero); - fp32_low = vreinterpretq_f32_u16( - vzip1q_u16(vreinterpretq_u16_bf16(bf16_zero_q), - vreinterpretq_u16_bf16(x_vec))); + fp32_low = vcvt_f32_bf16(x_vecx4); fp32_low = vmulq_n_f32(fp32_low, alpha); x_vecx4 = vcvt_bf16_f32(fp32_low); } @@ -348,15 +344,11 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, y1_vec = vld1q_f32(y_ptr); - a0 = vcombine_bf16(a0x4, bf16_zero); - a1 = vcombine_bf16(a1x4, bf16_zero); - a2 = vcombine_bf16(a2x4, bf16_zero); - a3 = vcombine_bf16(a3x4, bf16_zero); + a0 = vcombine_bf16(a0x4, a2x4); + a1 = vcombine_bf16(a1x4, a3x4); - t0 = vreinterpretq_bf16_u16( - vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1))); - t1 = vreinterpretq_bf16_u16( - vzip1q_u16(vreinterpretq_u16_bf16(a2), vreinterpretq_u16_bf16(a3))); + t0 = vreinterpretq_bf16_u16(vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1))); + t1 = vreinterpretq_bf16_u16(vzip2q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1))); y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0); y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1); @@ -374,10 +366,12 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, } if (rest_m) { - x0 = alpha * vcvtah_f32_bf16(x_ptr[0]); - x1 = alpha * vcvtah_f32_bf16(x_ptr[1]); - x2 = alpha * vcvtah_f32_bf16(x_ptr[2]); - x3 = alpha * vcvtah_f32_bf16(x_ptr[3]); + fp32_low = vcvt_f32_bf16(x_vecx4); + + x0 = vgetq_lane_f32(fp32_low, 0); + x1 = vgetq_lane_f32(fp32_low, 1); + x2 = vgetq_lane_f32(fp32_low, 2); + x3 = vgetq_lane_f32(fp32_low, 3); for (BLASLONG j = 0; j < rest_m; j++) { y_ptr[j] += x0 * vcvtah_f32_bf16(a_ptr0[j]); @@ -396,18 +390,13 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, a_ptr += 2 * lda; - bfloat16_t tmp_buffer[4]; - memset((void*)tmp_buffer, 0, sizeof(bfloat16_t)); - - tmp_buffer[0] = x_ptr[0]; - tmp_buffer[1] = x_ptr[1]; + x_vecx4 = vreinterpret_bf16_u16(vzip1_u16( + vreinterpret_u16_bf16(vdup_n_bf16(x_ptr[0])), + vreinterpret_u16_bf16(vdup_n_bf16(x_ptr[1])) + )); - bfloat16x4_t x_vecx4 = vld1_bf16(tmp_buffer); if (alpha != 1) { - x_vec = vcombine_bf16(x_vecx4, bf16_zero); - fp32_low = vreinterpretq_f32_u16( - vzip1q_u16(vreinterpretq_u16_bf16(bf16_zero_q), - vreinterpretq_u16_bf16(x_vec))); + fp32_low = vcvt_f32_bf16(x_vecx4); fp32_low = vmulq_n_f32(fp32_low, alpha); x_vecx4 = vcvt_bf16_f32(fp32_low); } @@ -422,14 +411,14 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, t0 = vreinterpretq_bf16_u16( vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1))); - t4 = vreinterpretq_bf16_u16( + t1 = vreinterpretq_bf16_u16( vzip2q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1))); y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0); y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1); - y2_vec = vbfmlalbq_lane_f32(y2_vec, t4, x_vecx4, 0); - y2_vec = vbfmlaltq_lane_f32(y2_vec, t4, x_vecx4, 1); + y2_vec = vbfmlalbq_lane_f32(y2_vec, t1, x_vecx4, 0); + y2_vec = vbfmlaltq_lane_f32(y2_vec, t1, x_vecx4, 1); vst1q_f32(y_ptr, y1_vec); vst1q_f32(y_ptr + 4, y2_vec); @@ -449,29 +438,24 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, a0 = vcombine_bf16(a0x4, bf16_zero); a1 = vcombine_bf16(a1x4, bf16_zero); - t0 = vreinterpretq_bf16_u16( - vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1))); - t1 = vreinterpretq_bf16_u16( - vzip1q_u16(vreinterpretq_u16_bf16(a2), vreinterpretq_u16_bf16(a3))); + t0 = vreinterpretq_bf16_u16(vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1))); y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0); y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1); - y1_vec = vbfmlalbq_lane_f32(y1_vec, t1, x_vecx4, 2); - y1_vec = vbfmlaltq_lane_f32(y1_vec, t1, x_vecx4, 3); vst1q_f32(y_ptr, y1_vec); a_ptr0 += 4; a_ptr1 += 4; - a_ptr2 += 4; - a_ptr3 += 4; y_ptr += 4; } if (m & 2) { - x0 = alpha * (vcvtah_f32_bf16(x_ptr[0])); - x1 = alpha * (vcvtah_f32_bf16(x_ptr[1])); + fp32_low = vcvt_f32_bf16(x_vecx4); + x0 = vgetq_lane_f32(fp32_low, 0); + x1 = vgetq_lane_f32(fp32_low, 1); + y_ptr[0] += x0 * vcvtah_f32_bf16(a_ptr0[0]); y_ptr[0] += x1 * vcvtah_f32_bf16(a_ptr1[0]); @@ -485,8 +469,9 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda, } if (m & 1) { - x0 = alpha * vcvtah_f32_bf16(x_ptr[0]); - x1 = alpha * vcvtah_f32_bf16(x_ptr[1]); + fp32_low = vcvt_f32_bf16(x_vecx4); + x0 = vgetq_lane_f32(fp32_low, 0); + x1 = vgetq_lane_f32(fp32_low, 1); y_ptr[0] += x0 * vcvtah_f32_bf16(a_ptr0[0]); y_ptr[0] += x1 * vcvtah_f32_bf16(a_ptr1[0]);