Skip to content

Commit 46096c4

Browse files
Mousiustaoye9
andcommitted
Fix bf16->f32 conversion for NEOVERSEV1 target
This fixes an issue originally introduced with the BGEMM kernel when I was tweaking it. #5287 didn't suffer from this bug. I've updated the tests to run with `beta=1.0` so as to test loading and updating from C. Alongside this, the tests now return sensible return values to reduce the risk of them being ignored. Co-authored-by: Ye Tao <ye.tao@arm.com>
1 parent e939c6c commit 46096c4

5 files changed

Lines changed: 29 additions & 12 deletions

File tree

kernel/arm64/bgemm_kernel_2vlx4_neoversev1_impl.c

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@
4040

4141
#define UPDATE_C(PG, PTR, DST, SRC) \
4242
do { \
43-
DST = svreinterpret_f32_u32(svld1uh_u32((pghalf), (uint16_t*)PTR)); \
43+
svtmp16 = svld1_bf16((pghalf), (PTR)); \
44+
DST = svreinterpret_f32(svzip1_bf16(zeros, svtmp16)); \
4445
DST = svadd_z((PG), SRC, DST); \
4546
svtmp16 = svcvt_bf16_f32_z((PG), DST); \
4647
svtmp16 = svuzp1_bf16(svtmp16, svtmp16); \
@@ -55,7 +56,8 @@
5556

5657
#define UPDATE_C(PG, PTR, DST, SRC) \
5758
do { \
58-
DST = svreinterpret_f32_u32(svld1uh_u32((pghalf), (uint16_t*)PTR)); \
59+
svtmp16 = svld1_bf16((pghalf), (PTR)); \
60+
DST = svreinterpret_f32(svzip1_bf16(zeros, svtmp16)); \
5961
DST = svmad_z((PG), svalpha, SRC, DST); \
6062
svtmp16 = svcvt_bf16_f32_z((PG), DST); \
6163
svtmp16 = svuzp1_bf16(svtmp16, svtmp16); \
@@ -133,6 +135,7 @@ static int bgemm_kernel_neoversev1_alpha(BLASLONG m, BLASLONG n, BLASLONG k,
133135
OUTPUT_FLOAT *ptr_c0, *ptr_c1, *ptr_c2, *ptr_c3;
134136
svfloat32_t tmp0, tmp1, tmp2, tmp3;
135137
#ifdef BGEMM
138+
svbfloat16_t zeros = svdup_n_bf16(TO16(0.0));
136139
svbfloat16_t svtmp16;
137140
#else
138141
float32x2_t tmp4, tmp5, tmp6, tmp7;

test/compare_sgemm_bgemm.c

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ main (int argc, char *argv[])
4444
int ret = 0;
4545
int loop = BGEMM_LARGEST;
4646
char transA = 'N', transB = 'N';
47-
float alpha = 1.0, beta = 0.0;
47+
float alpha = 1.0, beta = 1.0;
4848
bfloat16 alpha_bf16;
4949
sbstobf16_(&one, &alpha, &one, &alpha_bf16, &one);
5050
bfloat16 beta_bf16;
@@ -94,9 +94,15 @@ main (int argc, char *argv[])
9494
transB = 'T';
9595
}
9696

97-
memset(CC, 0, m * n * sizeof(bfloat16));
98-
memset(DD, 0, m * n * sizeof(FLOAT));
99-
memset(C, 0, m * n * sizeof(FLOAT));
97+
for (j = 0; j < m; j++)
98+
{
99+
for (i = 0; i < n; i++)
100+
{
101+
C[j * n + i] = 100.0;
102+
DD[j * n + i] = 100.0;
103+
sbstobf16_(&one, &C[j * n + i], &one, &CC[j * n + i], &one);
104+
}
105+
}
100106

101107
SGEMM (&transA, &transB, &m, &n, &k, &alpha, A,
102108
&m, B, &k, &beta, C, &m);
@@ -152,7 +158,8 @@ main (int argc, char *argv[])
152158
}
153159

154160
if (ret != 0) {
155-
fprintf (stderr, "FATAL ERROR BGEMM - Return code: %d\n", ret);
161+
fprintf(stderr, "BGEMM FAILURES: %d\n", ret);
162+
return 1;
156163
}
157164

158165
return ret;

test/compare_sgemm_sbgemm.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,8 @@ main (int argc, char *argv[])
140140
}
141141

142142
if (ret != 0) {
143-
fprintf (stderr, "FATAL ERROR SBGEMM - Return code: %d\n", ret);
143+
fprintf(stderr, "SBGEMM FAILURES: %d\n", ret);
144+
return 1;
144145
}
145146

146147
return ret;

test/compare_sgemv_bgemv.c

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,10 @@ int main(int argc, char *argv[])
147147
} // alpha
148148
} // beta
149149

150-
if (ret != 0)
151-
fprintf(stderr, "FATAL ERROR BGEMV - Return code: %d\n", ret);
150+
if (ret != 0) {
151+
fprintf(stderr, "BGEMV FAILURES: %d\n", ret);
152+
return 1;
153+
}
154+
152155
return ret;
153156
}

test/compare_sgemv_sbgemv.c

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,10 @@ main (int argc, char *argv[])
122122
} // alpha
123123
} // beta
124124

125-
if (ret != 0)
126-
fprintf (stderr, "FATAL ERROR SBGEMV - Return code: %d\n", ret);
125+
if (ret != 0) {
126+
fprintf(stderr, "SBGEMV FAILURES: %d\n", ret);
127+
return 1;
128+
}
129+
127130
return ret;
128131
}

0 commit comments

Comments
 (0)