1313
1414namespace ncnn {
1515
16- #include " gemm_bf16s_fp16s.h"
1716#include " gemm_fp16s.h"
1817
1918#if NCNN_INT8
@@ -31,7 +30,7 @@ static int gemm_arm_fp16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blo
3130 // NCNN_LOGE("M/N/K = %d %d %d", M, N, K);
3231
3332 int TILE_M, TILE_N, TILE_K;
34- get_optimal_tile_mnk_bf16s_fp16s (M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT);
33+ get_optimal_tile_mnk_fp16s (M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT);
3534
3635 // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K);
3736
@@ -65,11 +64,11 @@ static int gemm_arm_fp16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blo
6564
6665 if (transB)
6766 {
68- pack_B_tile_bf16_fp16 (B, BT_tile, j, max_jj, k, max_kk);
67+ pack_B_tile_fp16 (B, BT_tile, j, max_jj, k, max_kk);
6968 }
7069 else
7170 {
72- transpose_pack_B_tile_bf16_fp16 (B, BT_tile, j, max_jj, k, max_kk);
71+ transpose_pack_B_tile_fp16 (B, BT_tile, j, max_jj, k, max_kk);
7372 }
7473 }
7574
@@ -121,11 +120,11 @@ static int gemm_arm_fp16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blo
121120 {
122121 if (transA)
123122 {
124- transpose_pack_A_tile_bf16_fp16 (A, AT_tile, i, max_ii, k, max_kk);
123+ transpose_pack_A_tile_fp16 (A, AT_tile, i, max_ii, k, max_kk);
125124 }
126125 else
127126 {
128- pack_A_tile_bf16_fp16 (A, AT_tile, i, max_ii, k, max_kk);
127+ pack_A_tile_fp16 (A, AT_tile, i, max_ii, k, max_kk);
129128 }
130129 }
131130
@@ -152,7 +151,7 @@ static int gemm_AT_arm_fp16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top
152151 // NCNN_LOGE("M/N/K = %d %d %d", M, N, K);
153152
154153 int TILE_M, TILE_N, TILE_K;
155- get_optimal_tile_mnk_bf16s_fp16s (M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT);
154+ get_optimal_tile_mnk_fp16s (M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT);
156155
157156 // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K);
158157
@@ -183,11 +182,11 @@ static int gemm_AT_arm_fp16s(const Mat& AT, const Mat& B, const Mat& C, Mat& top
183182
184183 if (transB)
185184 {
186- pack_B_tile_bf16_fp16 (B, BT_tile, j, max_jj, k, max_kk);
185+ pack_B_tile_fp16 (B, BT_tile, j, max_jj, k, max_kk);
187186 }
188187 else
189188 {
190- transpose_pack_B_tile_bf16_fp16 (B, BT_tile, j, max_jj, k, max_kk);
189+ transpose_pack_B_tile_fp16 (B, BT_tile, j, max_jj, k, max_kk);
191190 }
192191 }
193192
@@ -254,7 +253,7 @@ static int gemm_BT_arm_fp16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top
254253 // NCNN_LOGE("M/N/K = %d %d %d", M, N, K);
255254
256255 int TILE_M, TILE_N, TILE_K;
257- get_optimal_tile_mnk_bf16s_fp16s (M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT);
256+ get_optimal_tile_mnk_fp16s (M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT);
258257
259258 // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K);
260259
@@ -313,11 +312,11 @@ static int gemm_BT_arm_fp16s(const Mat& A, const Mat& BT, const Mat& C, Mat& top
313312 {
314313 if (transA)
315314 {
316- transpose_pack_A_tile_bf16_fp16 (A, AT_tile, i, max_ii, k, max_kk);
315+ transpose_pack_A_tile_fp16 (A, AT_tile, i, max_ii, k, max_kk);
317316 }
318317 else
319318 {
320- pack_A_tile_bf16_fp16 (A, AT_tile, i, max_ii, k, max_kk);
319+ pack_A_tile_fp16 (A, AT_tile, i, max_ii, k, max_kk);
321320 }
322321 }
323322
@@ -342,7 +341,7 @@ static int gemm_AT_BT_arm_fp16s(const Mat& AT, const Mat& BT, const Mat& C, Mat&
342341 // NCNN_LOGE("M/N/K = %d %d %d", M, N, K);
343342
344343 int TILE_M, TILE_N, TILE_K;
345- get_optimal_tile_mnk_bf16s_fp16s (M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT);
344+ get_optimal_tile_mnk_fp16s (M, N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, nT);
346345
347346 // NCNN_LOGE("TILE M/N/K = %d %d %d", TILE_M, TILE_N, TILE_K);
348347
@@ -413,7 +412,7 @@ int Gemm_arm::create_pipeline_fp16s(const Option& opt)
413412 const int K = constantK;
414413
415414 int TILE_M, TILE_N, TILE_K;
416- get_optimal_tile_mnk_bf16s_fp16s (M, 0 , K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, opt.num_threads );
415+ get_optimal_tile_mnk_fp16s (M, 0 , K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, opt.num_threads );
417416
418417 const int nn_M = (M + TILE_M - 1 ) / TILE_M;
419418
@@ -454,7 +453,7 @@ int Gemm_arm::create_pipeline_fp16s(const Option& opt)
454453 const int K = constantK;
455454
456455 int TILE_M, TILE_N, TILE_K;
457- get_optimal_tile_mnk_bf16s_fp16s (0 , N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, opt.num_threads );
456+ get_optimal_tile_mnk_fp16s (0 , N, K, constant_TILE_M, constant_TILE_N, constant_TILE_K, TILE_M, TILE_N, TILE_K, opt.num_threads );
458457
459458 const int nn_N = (N + TILE_N - 1 ) / TILE_N;
460459
0 commit comments