@@ -33,16 +33,6 @@ THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3333#include "common.h"
3434#include <arm_neon.h>
3535
36- #if (defined(__GNUC__ ) && __GNUC__ >= 13 )
37- #define BF16_TO_FP32 (bf16 ) ((float)(bf16))
38- #else
39- static inline float bf16_to_fp32 (bfloat16_t bf16 ) {
40- uint32_t fp32 = (uint32_t )(* ((u_int16_t * )(& bf16 ))) << 16 ;
41- return * ((float * )& fp32 );
42- }
43- #define BF16_TO_FP32 (bf16 ) bf16_to_fp32(bf16)
44- #endif
45-
4636static void beta_op (float * x , BLASLONG n , FLOAT beta ) {
4737 if (beta == 0 ) {
4838 memset (x , 0 , n * sizeof (float ));
@@ -268,24 +258,24 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
268258 }
269259
270260 if (rest_m ) {
271- x0 = alpha * BF16_TO_FP32 (x_ptr [0 ]);
272- x1 = alpha * BF16_TO_FP32 (x_ptr [1 ]);
273- x2 = alpha * BF16_TO_FP32 (x_ptr [2 ]);
274- x3 = alpha * BF16_TO_FP32 (x_ptr [3 ]);
275- x4 = alpha * BF16_TO_FP32 (x_ptr [4 ]);
276- x5 = alpha * BF16_TO_FP32 (x_ptr [5 ]);
277- x6 = alpha * BF16_TO_FP32 (x_ptr [6 ]);
278- x7 = alpha * BF16_TO_FP32 (x_ptr [7 ]);
261+ x0 = alpha * vcvtah_f32_bf16 (x_ptr [0 ]);
262+ x1 = alpha * vcvtah_f32_bf16 (x_ptr [1 ]);
263+ x2 = alpha * vcvtah_f32_bf16 (x_ptr [2 ]);
264+ x3 = alpha * vcvtah_f32_bf16 (x_ptr [3 ]);
265+ x4 = alpha * vcvtah_f32_bf16 (x_ptr [4 ]);
266+ x5 = alpha * vcvtah_f32_bf16 (x_ptr [5 ]);
267+ x6 = alpha * vcvtah_f32_bf16 (x_ptr [6 ]);
268+ x7 = alpha * vcvtah_f32_bf16 (x_ptr [7 ]);
279269
280270 for (BLASLONG j = 0 ; j < rest_m ; j ++ ) {
281- y_ptr [j ] += x0 * BF16_TO_FP32 (a_ptr0 [j ]);
282- y_ptr [j ] += x1 * BF16_TO_FP32 (a_ptr1 [j ]);
283- y_ptr [j ] += x2 * BF16_TO_FP32 (a_ptr2 [j ]);
284- y_ptr [j ] += x3 * BF16_TO_FP32 (a_ptr3 [j ]);
285- y_ptr [j ] += x4 * BF16_TO_FP32 (a_ptr4 [j ]);
286- y_ptr [j ] += x5 * BF16_TO_FP32 (a_ptr5 [j ]);
287- y_ptr [j ] += x6 * BF16_TO_FP32 (a_ptr6 [j ]);
288- y_ptr [j ] += x7 * BF16_TO_FP32 (a_ptr7 [j ]);
271+ y_ptr [j ] += x0 * vcvtah_f32_bf16 (a_ptr0 [j ]);
272+ y_ptr [j ] += x1 * vcvtah_f32_bf16 (a_ptr1 [j ]);
273+ y_ptr [j ] += x2 * vcvtah_f32_bf16 (a_ptr2 [j ]);
274+ y_ptr [j ] += x3 * vcvtah_f32_bf16 (a_ptr3 [j ]);
275+ y_ptr [j ] += x4 * vcvtah_f32_bf16 (a_ptr4 [j ]);
276+ y_ptr [j ] += x5 * vcvtah_f32_bf16 (a_ptr5 [j ]);
277+ y_ptr [j ] += x6 * vcvtah_f32_bf16 (a_ptr6 [j ]);
278+ y_ptr [j ] += x7 * vcvtah_f32_bf16 (a_ptr7 [j ]);
289279 }
290280 }
291281
@@ -384,16 +374,16 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
384374 }
385375
386376 if (rest_m ) {
387- x0 = alpha * BF16_TO_FP32 (x_ptr [0 ]);
388- x1 = alpha * BF16_TO_FP32 (x_ptr [1 ]);
389- x2 = alpha * BF16_TO_FP32 (x_ptr [2 ]);
390- x3 = alpha * BF16_TO_FP32 (x_ptr [3 ]);
377+ x0 = alpha * vcvtah_f32_bf16 (x_ptr [0 ]);
378+ x1 = alpha * vcvtah_f32_bf16 (x_ptr [1 ]);
379+ x2 = alpha * vcvtah_f32_bf16 (x_ptr [2 ]);
380+ x3 = alpha * vcvtah_f32_bf16 (x_ptr [3 ]);
391381
392382 for (BLASLONG j = 0 ; j < rest_m ; j ++ ) {
393- y_ptr [j ] += x0 * BF16_TO_FP32 (a_ptr0 [j ]);
394- y_ptr [j ] += x1 * BF16_TO_FP32 (a_ptr1 [j ]);
395- y_ptr [j ] += x2 * BF16_TO_FP32 (a_ptr2 [j ]);
396- y_ptr [j ] += x3 * BF16_TO_FP32 (a_ptr3 [j ]);
383+ y_ptr [j ] += x0 * vcvtah_f32_bf16 (a_ptr0 [j ]);
384+ y_ptr [j ] += x1 * vcvtah_f32_bf16 (a_ptr1 [j ]);
385+ y_ptr [j ] += x2 * vcvtah_f32_bf16 (a_ptr2 [j ]);
386+ y_ptr [j ] += x3 * vcvtah_f32_bf16 (a_ptr3 [j ]);
397387 }
398388 }
399389
@@ -480,13 +470,13 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
480470 }
481471
482472 if (m & 2 ) {
483- x0 = alpha * (BF16_TO_FP32 (x_ptr [0 ]));
484- x1 = alpha * (BF16_TO_FP32 (x_ptr [1 ]));
473+ x0 = alpha * (vcvtah_f32_bf16 (x_ptr [0 ]));
474+ x1 = alpha * (vcvtah_f32_bf16 (x_ptr [1 ]));
485475
486- y_ptr [0 ] += x0 * BF16_TO_FP32 (a_ptr0 [0 ]);
487- y_ptr [0 ] += x1 * BF16_TO_FP32 (a_ptr1 [0 ]);
488- y_ptr [1 ] += x0 * BF16_TO_FP32 (a_ptr0 [1 ]);
489- y_ptr [1 ] += x1 * BF16_TO_FP32 (a_ptr1 [1 ]);
476+ y_ptr [0 ] += x0 * vcvtah_f32_bf16 (a_ptr0 [0 ]);
477+ y_ptr [0 ] += x1 * vcvtah_f32_bf16 (a_ptr1 [0 ]);
478+ y_ptr [1 ] += x0 * vcvtah_f32_bf16 (a_ptr0 [1 ]);
479+ y_ptr [1 ] += x1 * vcvtah_f32_bf16 (a_ptr1 [1 ]);
490480
491481 a_ptr0 += 2 ;
492482 a_ptr1 += 2 ;
@@ -495,23 +485,23 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
495485 }
496486
497487 if (m & 1 ) {
498- x0 = alpha * BF16_TO_FP32 (x_ptr [0 ]);
499- x1 = alpha * BF16_TO_FP32 (x_ptr [1 ]);
488+ x0 = alpha * vcvtah_f32_bf16 (x_ptr [0 ]);
489+ x1 = alpha * vcvtah_f32_bf16 (x_ptr [1 ]);
500490
501- y_ptr [0 ] += x0 * BF16_TO_FP32 (a_ptr0 [0 ]);
502- y_ptr [0 ] += x1 * BF16_TO_FP32 (a_ptr1 [0 ]);
491+ y_ptr [0 ] += x0 * vcvtah_f32_bf16 (a_ptr0 [0 ]);
492+ y_ptr [0 ] += x1 * vcvtah_f32_bf16 (a_ptr1 [0 ]);
503493 }
504494
505495 x_ptr += 2 ;
506496 }
507497
508498 if (n & 1 ) {
509- x0 = BF16_TO_FP32 (x_ptr [0 ]) * alpha ;
499+ x0 = vcvtah_f32_bf16 (x_ptr [0 ]) * alpha ;
510500 y_ptr = y ;
511501 a_ptr0 = a_ptr ;
512502
513503 for (j = 0 ; j < m ; j ++ ) {
514- y_ptr [j ] += x0 * BF16_TO_FP32 (a_ptr0 [j ]);
504+ y_ptr [j ] += x0 * vcvtah_f32_bf16 (a_ptr0 [j ]);
515505 }
516506 }
517507
@@ -525,10 +515,10 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
525515 }
526516
527517 for (j = 0 ; j < n ; j ++ ) {
528- x0 = alpha * BF16_TO_FP32 (* x_ptr );
518+ x0 = alpha * vcvtah_f32_bf16 (* x_ptr );
529519 iy = 0 ;
530520 for (i = 0 ; i < m ; i ++ ) {
531- y [iy ] += x0 * BF16_TO_FP32 (a_ptr [i ]);
521+ y [iy ] += x0 * vcvtah_f32_bf16 (a_ptr [i ]);
532522 iy += incy ;
533523 }
534524
0 commit comments