Skip to content

Commit ef9eb9d

Browse files
authored
armv8.4 bf16 gemm optimization (#6714)
1 parent 5a4a483 commit ef9eb9d

8 files changed

Lines changed: 10201 additions & 4129 deletions

File tree

src/layer/arm/arm_usability.h

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,11 @@ static inline signed char float2int8(float v)
1717

1818
// Converts four fp32 lanes (_v) into four 16-bit values returned as uint16x4_t.
// Two compile-time paths exist, selected by __ARM_FEATURE_BF16_VECTOR_ARITHMETIC.
// NOTE(review): the two paths may not round identically (the fallback simply drops
// the low 16 bits of every lane) - confirm callers tolerate a last-bit difference.
static inline uint16x4_t float2bfloat(float32x4_t _v)
1919
{
20+
#if __ARM_FEATURE_BF16_VECTOR_ARITHMETIC
21+
// bf16 vector arithmetic available: use the vcvt_bf16_f32 intrinsic, then
// cast its result to the uint16x4_t type this function returns.
return (uint16x4_t)vcvt_bf16_f32(_v);
22+
#else
2023
// fallback: view each lane as u32 and keep only its upper 16 bits
// (narrowing right shift by 16).
return vshrn_n_u32(vreinterpretq_u32_f32(_v), 16);
24+
#endif
2125
}
2226
static inline float32x4_t bfloat2float(uint16x4_t _v)
2327
{

src/layer/arm/gemm_arm.cpp

Lines changed: 46 additions & 157 deletions
Large diffs are not rendered by default.

src/layer/arm/gemm_arm_asimdhp.cpp

Lines changed: 16 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@
1313

1414
namespace ncnn {
1515

16-
#include "gemm_bf16s_fp16s.h"
1716
#include "gemm_fp16s.h"
1817

1918
#if NCNN_INT8
@@ -2378,11 +2377,11 @@ static int gemm_arm_fp16sa(const Mat& A, const Mat& B, const Mat& C, Mat& top_bl
23782377

23792378
if (transB)
23802379
{
2381-
pack_B_tile_bf16_fp16(B, BT_tile, j, max_jj, k, max_kk);
2380+
pack_B_tile_fp16(B, BT_tile, j, max_jj, k, max_kk);
23822381
}
23832382
else
23842383
{
2385-
transpose_pack_B_tile_bf16_fp16(B, BT_tile, j, max_jj, k, max_kk);
2384+
transpose_pack_B_tile_fp16(B, BT_tile, j, max_jj, k, max_kk);
23862385
}
23872386
}
23882387

@@ -2415,7 +2414,7 @@ static int gemm_arm_fp16sa(const Mat& A, const Mat& B, const Mat& C, Mat& top_bl
24152414

24162415
if (broadcast_type_C == 3)
24172416
{
2418-
pack_A_tile_bf16_fp16(C, topT_tile, i, max_ii, j, max_jj);
2417+
pack_A_tile_fp16(C, topT_tile, i, max_ii, j, max_jj);
24192418
}
24202419

24212420
const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C;
@@ -2434,11 +2433,11 @@ static int gemm_arm_fp16sa(const Mat& A, const Mat& B, const Mat& C, Mat& top_bl
24342433
{
24352434
if (transA)
24362435
{
2437-
transpose_pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk);
2436+
transpose_pack_A_tile_fp16(A, AT_tile, i, max_ii, k, max_kk);
24382437
}
24392438
else
24402439
{
2441-
pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk);
2440+
pack_A_tile_fp16(A, AT_tile, i, max_ii, k, max_kk);
24422441
}
24432442
}
24442443

@@ -2449,7 +2448,7 @@ static int gemm_arm_fp16sa(const Mat& A, const Mat& B, const Mat& C, Mat& top_bl
24492448

24502449
if (output_transpose)
24512450
{
2452-
transpose_unpack_output_tile_bf16_fp16(topT_tile, top_blob, i, max_ii, j, max_jj);
2451+
transpose_unpack_output_tile_fp16(topT_tile, top_blob, i, max_ii, j, max_jj);
24532452
}
24542453
}
24552454
}
@@ -2495,11 +2494,11 @@ static int gemm_AT_arm_fp16sa(const Mat& AT, const Mat& B, const Mat& C, Mat& to
24952494

24962495
if (transB)
24972496
{
2498-
pack_B_tile_bf16_fp16(B, BT_tile, j, max_jj, k, max_kk);
2497+
pack_B_tile_fp16(B, BT_tile, j, max_jj, k, max_kk);
24992498
}
25002499
else
25012500
{
2502-
transpose_pack_B_tile_bf16_fp16(B, BT_tile, j, max_jj, k, max_kk);
2501+
transpose_pack_B_tile_fp16(B, BT_tile, j, max_jj, k, max_kk);
25032502
}
25042503
}
25052504

@@ -2528,7 +2527,7 @@ static int gemm_AT_arm_fp16sa(const Mat& AT, const Mat& B, const Mat& C, Mat& to
25282527

25292528
if (broadcast_type_C == 3)
25302529
{
2531-
pack_A_tile_bf16_fp16(C, topT_tile, i, max_ii, j, max_jj);
2530+
pack_A_tile_fp16(C, topT_tile, i, max_ii, j, max_jj);
25322531
}
25332532

25342533
const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C;
@@ -2550,7 +2549,7 @@ static int gemm_AT_arm_fp16sa(const Mat& AT, const Mat& B, const Mat& C, Mat& to
25502549

25512550
if (output_transpose)
25522551
{
2553-
transpose_unpack_output_tile_bf16_fp16(topT_tile, top_blob, i, max_ii, j, max_jj);
2552+
transpose_unpack_output_tile_fp16(topT_tile, top_blob, i, max_ii, j, max_jj);
25542553
}
25552554
}
25562555
}
@@ -2605,7 +2604,7 @@ static int gemm_BT_arm_fp16sa(const Mat& A, const Mat& BT, const Mat& C, Mat& to
26052604

26062605
if (broadcast_type_C == 3)
26072606
{
2608-
pack_A_tile_bf16_fp16(C, topT_tile, i, max_ii, j, max_jj);
2607+
pack_A_tile_fp16(C, topT_tile, i, max_ii, j, max_jj);
26092608
}
26102609

26112610
const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C;
@@ -2624,11 +2623,11 @@ static int gemm_BT_arm_fp16sa(const Mat& A, const Mat& BT, const Mat& C, Mat& to
26242623
{
26252624
if (transA)
26262625
{
2627-
transpose_pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk);
2626+
transpose_pack_A_tile_fp16(A, AT_tile, i, max_ii, k, max_kk);
26282627
}
26292628
else
26302629
{
2631-
pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk);
2630+
pack_A_tile_fp16(A, AT_tile, i, max_ii, k, max_kk);
26322631
}
26332632
}
26342633

@@ -2639,7 +2638,7 @@ static int gemm_BT_arm_fp16sa(const Mat& A, const Mat& BT, const Mat& C, Mat& to
26392638

26402639
if (output_transpose)
26412640
{
2642-
transpose_unpack_output_tile_bf16_fp16(topT_tile, top_blob, i, max_ii, j, max_jj);
2641+
transpose_unpack_output_tile_fp16(topT_tile, top_blob, i, max_ii, j, max_jj);
26432642
}
26442643
}
26452644
}
@@ -2684,7 +2683,7 @@ static int gemm_AT_BT_arm_fp16sa(const Mat& AT, const Mat& BT, const Mat& C, Mat
26842683

26852684
if (broadcast_type_C == 3)
26862685
{
2687-
pack_A_tile_bf16_fp16(C, topT_tile, i, max_ii, j, max_jj);
2686+
pack_A_tile_fp16(C, topT_tile, i, max_ii, j, max_jj);
26882687
}
26892688

26902689
const Mat& CT_tile = broadcast_type_C == 3 ? topT_tile : C;
@@ -2706,7 +2705,7 @@ static int gemm_AT_BT_arm_fp16sa(const Mat& AT, const Mat& BT, const Mat& C, Mat
27062705

27072706
if (output_transpose)
27082707
{
2709-
transpose_unpack_output_tile_bf16_fp16(topT_tile, top_blob, i, max_ii, j, max_jj);
2708+
transpose_unpack_output_tile_fp16(topT_tile, top_blob, i, max_ii, j, max_jj);
27102709
}
27112710
}
27122711
}

src/layer/arm/gemm_arm_bf16.cpp

Lines changed: 64 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,64 @@
1+
// Copyright 2026 Tencent
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
#include "cpu.h"
5+
#include "mat.h"
6+
#include "arm_usability.h"
7+
8+
namespace ncnn {
9+
10+
#if NCNN_BF16
11+
#include "gemm_bf16s.h"
12+
13+
// Non-static forwarding wrapper: passes A, AT, i, max_ii, k and max_kk unchanged
// to pack_A_tile_bf16 (presumably provided by the "gemm_bf16s.h" include in this
// file - confirm).
// NOTE(review): looks like this exists so another translation unit can call the
// copy of the routine compiled in this file - verify against the caller.
void pack_A_tile_bf16_bf16(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk)
14+
{
15+
pack_A_tile_bf16(A, AT, i, max_ii, k, max_kk);
16+
}
17+
18+
// Non-static forwarding wrapper: passes A, AT, i, max_ii, k and max_kk unchanged
// to transpose_pack_A_tile_bf16 (presumably provided by the "gemm_bf16s.h"
// include in this file - confirm).
void transpose_pack_A_tile_bf16_bf16(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk)
19+
{
20+
transpose_pack_A_tile_bf16(A, AT, i, max_ii, k, max_kk);
21+
}
22+
23+
// Non-static forwarding wrapper: passes B, BT, j, max_jj, k and max_kk unchanged
// to pack_B_tile_bf16 (presumably provided by the "gemm_bf16s.h" include in this
// file - confirm).
void pack_B_tile_bf16_bf16(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk)
24+
{
25+
pack_B_tile_bf16(B, BT, j, max_jj, k, max_kk);
26+
}
27+
28+
// Non-static forwarding wrapper: passes B, BT, j, max_jj, k and max_kk unchanged
// to transpose_pack_B_tile_bf16 (presumably provided by the "gemm_bf16s.h"
// include in this file - confirm).
void transpose_pack_B_tile_bf16_bf16(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk)
29+
{
30+
transpose_pack_B_tile_bf16(B, BT, j, max_jj, k, max_kk);
31+
}
32+
33+
// Non-static forwarding wrapper: passes A, AT, i, max_ii, k and max_kk unchanged
// to pack_A_tile_fp32_to_bf16 (presumably provided by the "gemm_bf16s.h" include
// in this file - confirm).
// NOTE(review): the name suggests A holds fp32 data that is converted to bf16
// while packing - the conversion itself is not visible here, verify in the callee.
void pack_A_tile_fp32_to_bf16_bf16(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk)
34+
{
35+
pack_A_tile_fp32_to_bf16(A, AT, i, max_ii, k, max_kk);
36+
}
37+
38+
// Non-static forwarding wrapper: passes A, AT, i, max_ii, k and max_kk unchanged
// to transpose_pack_A_tile_fp32_to_bf16 (presumably provided by the
// "gemm_bf16s.h" include in this file - confirm).
void transpose_pack_A_tile_fp32_to_bf16_bf16(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk)
39+
{
40+
transpose_pack_A_tile_fp32_to_bf16(A, AT, i, max_ii, k, max_kk);
41+
}
42+
43+
// Non-static forwarding wrapper: passes B, BT, j, max_jj, k and max_kk unchanged
// to pack_B_tile_fp32_to_bf16 (presumably provided by the "gemm_bf16s.h" include
// in this file - confirm).
void pack_B_tile_fp32_to_bf16_bf16(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk)
44+
{
45+
pack_B_tile_fp32_to_bf16(B, BT, j, max_jj, k, max_kk);
46+
}
47+
48+
// Non-static forwarding wrapper: passes B, BT, j, max_jj, k and max_kk unchanged
// to transpose_pack_B_tile_fp32_to_bf16 (presumably provided by the
// "gemm_bf16s.h" include in this file - confirm).
void transpose_pack_B_tile_fp32_to_bf16_bf16(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk)
49+
{
50+
transpose_pack_B_tile_fp32_to_bf16(B, BT, j, max_jj, k, max_kk);
51+
}
52+
53+
// Non-static forwarding wrapper: passes every argument (topT, C, top_blob,
// broadcast_type_C, i, max_ii, j, max_jj, alpha, beta, output_transpose)
// unchanged and in the same order to unpack_output_tile_fp32_to_bf16
// (presumably provided by the "gemm_bf16s.h" include in this file - confirm).
// top_blob is the only non-const Mat parameter, so it is presumably the output.
void unpack_output_tile_fp32_to_bf16_bf16(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, float alpha, float beta, int output_transpose)
54+
{
55+
unpack_output_tile_fp32_to_bf16(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, alpha, beta, output_transpose);
56+
}
57+
58+
// Non-static forwarding wrapper: passes AT_tile, BT_tile, topT_tile, max_ii,
// max_jj, k and max_kk unchanged to gemm_transB_packed_tile_bf16s (presumably
// provided by the "gemm_bf16s.h" include in this file - confirm).
// topT_tile is the only non-const Mat parameter, so it is presumably the
// destination of the tile computation - verify in the callee.
void gemm_transB_packed_tile_bf16s_bf16(const Mat& AT_tile, const Mat& BT_tile, Mat& topT_tile, int max_ii, int max_jj, int k, int max_kk)
59+
{
60+
gemm_transB_packed_tile_bf16s(AT_tile, BT_tile, topT_tile, max_ii, max_jj, k, max_kk);
61+
}
62+
#endif // NCNN_BF16
63+
64+
} // namespace ncnn

src/layer/arm/gemm_arm_vfpv4.cpp

Lines changed: 14 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@
1313

1414
namespace ncnn {
1515

16-
#include "gemm_bf16s_fp16s.h"
1716
#include "gemm_fp16s.h"
1817

1918
#if NCNN_INT8
@@ -31,7 +30,7 @@ static int gemm_arm_fp16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blo
3130
// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);
3231

3332
int TILE_M, TILE_N, TILE_K;
34-
get_optimal_tile_mnk_bf16s_fp16s(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT);
33+
get_optimal_tile_mnk_fp16s(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT);
3534

3635
// NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K);
3736

@@ -65,11 +64,11 @@ static int gemm_arm_fp16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blo
6564

6665
if (transB)
6766
{
68-
pack_B_tile_bf16_fp16(B, BT_tile, j, max_jj, k, max_kk);
67+
pack_B_tile_fp16(B, BT_tile, j, max_jj, k, max_kk);
6968
}
7069
else
7170
{
72-
transpose_pack_B_tile_bf16_fp16(B, BT_tile, j, max_jj, k, max_kk);
71+
transpose_pack_B_tile_fp16(B, BT_tile, j, max_jj, k, max_kk);
7372
}
7473
}
7574

@@ -121,11 +120,11 @@ static int gemm_arm_fp16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blo
121120
{
122121
if (transA)
123122
{
124-
transpose_pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk);
123+
transpose_pack_A_tile_fp16(A, AT_tile, i, max_ii, k, max_kk);
125124
}
126125
else
127126
{
128-
pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk);
127+
pack_A_tile_fp16(A, AT_tile, i, max_ii, k, max_kk);
129128
}
130129
}
131130

@@ -152,7 +151,7 @@ static int gemm_AT_arm_fp16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top
152151
// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);
153152

154153
int TILE_M, TILE_N, TILE_K;
155-
get_optimal_tile_mnk_bf16s_fp16s(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT);
154+
get_optimal_tile_mnk_fp16s(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT);
156155

157156
// NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K);
158157

@@ -183,11 +182,11 @@ static int gemm_AT_arm_fp16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top
183182

184183
if (transB)
185184
{
186-
pack_B_tile_bf16_fp16(B, BT_tile, j, max_jj, k, max_kk);
185+
pack_B_tile_fp16(B, BT_tile, j, max_jj, k, max_kk);
187186
}
188187
else
189188
{
190-
transpose_pack_B_tile_bf16_fp16(B, BT_tile, j, max_jj, k, max_kk);
189+
transpose_pack_B_tile_fp16(B, BT_tile, j, max_jj, k, max_kk);
191190
}
192191
}
193192

@@ -254,7 +253,7 @@ static int gemm_BT_arm_fp16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top
254253
// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);
255254

256255
int TILE_M, TILE_N, TILE_K;
257-
get_optimal_tile_mnk_bf16s_fp16s(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT);
256+
get_optimal_tile_mnk_fp16s(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT);
258257

259258
// NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K);
260259

@@ -313,11 +312,11 @@ static int gemm_BT_arm_fp16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top
313312
{
314313
if (transA)
315314
{
316-
transpose_pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk);
315+
transpose_pack_A_tile_fp16(A, AT_tile, i, max_ii, k, max_kk);
317316
}
318317
else
319318
{
320-
pack_A_tile_bf16_fp16(A, AT_tile, i, max_ii, k, max_kk);
319+
pack_A_tile_fp16(A, AT_tile, i, max_ii, k, max_kk);
321320
}
322321
}
323322

@@ -342,7 +341,7 @@ static int gemm_AT_BT_arm_fp16s(const Mat& AT, const Mat& BT, const Mat& C, Mat&
342341
// NCNN_LOGE("M/N/K = %d %d %d", M, N, K);
343342

344343
int TILE_M, TILE_N, TILE_K;
345-
get_optimal_tile_mnk_bf16s_fp16s(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT);
344+
get_optimal_tile_mnk_fp16s(M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT);
346345

347346
// NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K);
348347

@@ -413,7 +412,7 @@ int Gemm_arm::create_pipeline_fp16s(const Option& opt)
413412
const int K = constantK;
414413

415414
int TILE_M, TILE_N, TILE_K;
416-
get_optimal_tile_mnk_bf16s_fp16s(M, 0, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, opt.num_threads);
415+
get_optimal_tile_mnk_fp16s(M, 0, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, opt.num_threads);
417416

418417
const int nn_M = (M + TILE_M - 1) / TILE_M;
419418

@@ -454,7 +453,7 @@ int Gemm_arm::create_pipeline_fp16s(const Option& opt)
454453
const int K = constantK;
455454

456455
int TILE_M, TILE_N, TILE_K;
457-
get_optimal_tile_mnk_bf16s_fp16s(0, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, opt.num_threads);
456+
get_optimal_tile_mnk_fp16s(0, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, opt.num_threads);
458457

459458
const int nn_N = (N + TILE_N - 1) / TILE_N;
460459

0 commit comments

Comments
 (0)