Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 34 additions & 49 deletions kernel/arm64/sbgemv_n_neon.c
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,8 @@ static void beta_op(float *x, BLASLONG n, FLOAT beta) {
x += 4;
}

if (rest_n & 3) {
x[0] *= beta;
if ((rest_n & 3) > 1)
x[1] *= beta;
if ((rest_n & 3) > 2)
x[2] *= beta;
for (BLASLONG i = 0; i < (rest_n & 3); i ++) {
x[i] *= beta;
}
}
return;
Expand All @@ -88,7 +84,10 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,

bfloat16x8_t a0, a1, a2, a3, a4, a5, a6, a7;
bfloat16x8_t t0, t1, t2, t3, t4, t5, t6, t7;

bfloat16x8_t x_vec;
bfloat16x4_t x_vecx4;

float32x4_t y1_vec, y2_vec;
float32x4_t fp32_low, fp32_high;

Expand All @@ -106,7 +105,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,

if (incx == 1 && incy == 1) {
if (beta != 1) {
beta_op(y, n, beta);
beta_op(y, m, beta);
}

for (i = 0; i < n / 8; i++) {
Expand Down Expand Up @@ -290,12 +289,9 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,

a_ptr += 4 * lda;

bfloat16x4_t x_vecx4 = vld1_bf16(x_ptr);
x_vecx4 = vld1_bf16(x_ptr);
if (alpha != 1) {
x_vec = vcombine_bf16(x_vecx4, bf16_zero);
fp32_low = vreinterpretq_f32_u16(
vzip1q_u16(vreinterpretq_u16_bf16(bf16_zero_q),
vreinterpretq_u16_bf16(x_vec)));
fp32_low = vcvt_f32_bf16(x_vecx4);
fp32_low = vmulq_n_f32(fp32_low, alpha);
x_vecx4 = vcvt_bf16_f32(fp32_low);
}
Expand Down Expand Up @@ -348,15 +344,11 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,

y1_vec = vld1q_f32(y_ptr);

a0 = vcombine_bf16(a0x4, bf16_zero);
a1 = vcombine_bf16(a1x4, bf16_zero);
a2 = vcombine_bf16(a2x4, bf16_zero);
a3 = vcombine_bf16(a3x4, bf16_zero);
a0 = vcombine_bf16(a0x4, a2x4);
a1 = vcombine_bf16(a1x4, a3x4);

t0 = vreinterpretq_bf16_u16(
vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
t1 = vreinterpretq_bf16_u16(
vzip1q_u16(vreinterpretq_u16_bf16(a2), vreinterpretq_u16_bf16(a3)));
t0 = vreinterpretq_bf16_u16(vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
t1 = vreinterpretq_bf16_u16(vzip2q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));

y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0);
y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1);
Expand All @@ -374,10 +366,12 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
}

if (rest_m) {
x0 = alpha * vcvtah_f32_bf16(x_ptr[0]);
x1 = alpha * vcvtah_f32_bf16(x_ptr[1]);
x2 = alpha * vcvtah_f32_bf16(x_ptr[2]);
x3 = alpha * vcvtah_f32_bf16(x_ptr[3]);
fp32_low = vcvt_f32_bf16(x_vecx4);

x0 = vgetq_lane_f32(fp32_low, 0);
x1 = vgetq_lane_f32(fp32_low, 1);
x2 = vgetq_lane_f32(fp32_low, 2);
x3 = vgetq_lane_f32(fp32_low, 3);

for (BLASLONG j = 0; j < rest_m; j++) {
y_ptr[j] += x0 * vcvtah_f32_bf16(a_ptr0[j]);
Expand All @@ -396,18 +390,13 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,

a_ptr += 2 * lda;

bfloat16_t tmp_buffer[4];
memset((void*)tmp_buffer, 0, sizeof(bfloat16_t));

tmp_buffer[0] = x_ptr[0];
tmp_buffer[1] = x_ptr[1];
x_vecx4 = vreinterpret_bf16_u16(vzip1_u16(
vreinterpret_u16_bf16(vdup_n_bf16(x_ptr[0])),
vreinterpret_u16_bf16(vdup_n_bf16(x_ptr[1]))
));

bfloat16x4_t x_vecx4 = vld1_bf16(tmp_buffer);
if (alpha != 1) {
x_vec = vcombine_bf16(x_vecx4, bf16_zero);
fp32_low = vreinterpretq_f32_u16(
vzip1q_u16(vreinterpretq_u16_bf16(bf16_zero_q),
vreinterpretq_u16_bf16(x_vec)));
fp32_low = vcvt_f32_bf16(x_vecx4);
fp32_low = vmulq_n_f32(fp32_low, alpha);
x_vecx4 = vcvt_bf16_f32(fp32_low);
}
Expand All @@ -422,14 +411,14 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,

t0 = vreinterpretq_bf16_u16(
vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
t4 = vreinterpretq_bf16_u16(
t1 = vreinterpretq_bf16_u16(
vzip2q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));

y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0);
y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1);

y2_vec = vbfmlalbq_lane_f32(y2_vec, t4, x_vecx4, 0);
y2_vec = vbfmlaltq_lane_f32(y2_vec, t4, x_vecx4, 1);
y2_vec = vbfmlalbq_lane_f32(y2_vec, t1, x_vecx4, 0);
y2_vec = vbfmlaltq_lane_f32(y2_vec, t1, x_vecx4, 1);

vst1q_f32(y_ptr, y1_vec);
vst1q_f32(y_ptr + 4, y2_vec);
Expand All @@ -449,29 +438,24 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
a0 = vcombine_bf16(a0x4, bf16_zero);
a1 = vcombine_bf16(a1x4, bf16_zero);

t0 = vreinterpretq_bf16_u16(
vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
t1 = vreinterpretq_bf16_u16(
vzip1q_u16(vreinterpretq_u16_bf16(a2), vreinterpretq_u16_bf16(a3)));
t0 = vreinterpretq_bf16_u16(vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));

y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0);
y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1);
y1_vec = vbfmlalbq_lane_f32(y1_vec, t1, x_vecx4, 2);
y1_vec = vbfmlaltq_lane_f32(y1_vec, t1, x_vecx4, 3);

vst1q_f32(y_ptr, y1_vec);

a_ptr0 += 4;
a_ptr1 += 4;
a_ptr2 += 4;
a_ptr3 += 4;

y_ptr += 4;
}

if (m & 2) {
x0 = alpha * (vcvtah_f32_bf16(x_ptr[0]));
x1 = alpha * (vcvtah_f32_bf16(x_ptr[1]));
fp32_low = vcvt_f32_bf16(x_vecx4);
x0 = vgetq_lane_f32(fp32_low, 0);
x1 = vgetq_lane_f32(fp32_low, 1);


y_ptr[0] += x0 * vcvtah_f32_bf16(a_ptr0[0]);
y_ptr[0] += x1 * vcvtah_f32_bf16(a_ptr1[0]);
Expand All @@ -485,8 +469,9 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
}

if (m & 1) {
x0 = alpha * vcvtah_f32_bf16(x_ptr[0]);
x1 = alpha * vcvtah_f32_bf16(x_ptr[1]);
fp32_low = vcvt_f32_bf16(x_vecx4);
x0 = vgetq_lane_f32(fp32_low, 0);
x1 = vgetq_lane_f32(fp32_low, 1);

y_ptr[0] += x0 * vcvtah_f32_bf16(a_ptr0[0]);
y_ptr[0] += x1 * vcvtah_f32_bf16(a_ptr1[0]);
Expand Down
Loading