Skip to content

Commit 63ce52e

Browse files
committed
Change the data type of bgemm alpha and beta from bfloat16 to fp32, and add Makefile changes for the bgemm interface
1 parent 082a9d2 commit 63ce52e

20 files changed

Lines changed: 337 additions & 282 deletions

CONTRIBUTORS.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ In chronological order:
251251
* Ye Tao <ye.tao@arm.com>
252252
* [2025-02-03] Optimize SBGEMM kernel on NEOVERSEV1
253253
* [2025-02-27] Add sbgemv_n_neon kernel
254+
* [2025-05-17] Impl prototype of BGEMM interface
254255

255256
* Abhishek Kumar <https://github.com/abhishek-iitmadras>
256-
* [2025-04-22] Optimise dot kernel for NEOVERSE V1
257+
* [2025-04-22] Optimise dot kernel for NEOVERSE V1

Makefile.system

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1544,6 +1544,9 @@ ifeq ($(USE_TLS), 1)
15441544
CCOMMON_OPT += -DUSE_TLS
15451545
endif
15461546

1547+
ifeq ($(BUILD_BFLOAT16_ONLY), 1)
1548+
CCOMMON_OPT += -DBUILD_BFLOAT16_ONLY
1549+
endif
15471550
ifeq ($(BUILD_BFLOAT16), 1)
15481551
CCOMMON_OPT += -DBUILD_BFLOAT16
15491552
endif
@@ -1888,6 +1891,7 @@ export FUNCTION_PROFILE
18881891
export TARGET_CORE
18891892
export NO_AVX512
18901893
export NO_AVX2
1894+
export BUILD_BFLOAT16_ONLY
18911895
export BUILD_BFLOAT16
18921896
export NO_LSX
18931897
export NO_LASX
@@ -1912,7 +1916,7 @@ export ZGEMM3M_UNROLL_M
19121916
export ZGEMM3M_UNROLL_N
19131917
export XGEMM3M_UNROLL_M
19141918
export XGEMM3M_UNROLL_N
1915-
1919+
# Todo: add bgemm unroll factors
19161920

19171921
ifdef USE_CUDA
19181922
export CUDADIR

Makefile.tail

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX))
1111

1212
HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX))
1313

14-
BLASOBJS = $(SBEXTOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS)
14+
BLASOBJS = $(SBEXTOBJS) $(BBLASOBJS) $(SBBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS) $(CBAUXOBJS)
1515
BLASOBJS_P = $(SBEXTOBJS_P) $(SBBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P) $(CBAUXOBJS_P)
1616

1717
ifdef EXPRECISION
@@ -24,6 +24,7 @@ BLASOBJS += $(QBLASOBJS) $(XBLASOBJS)
2424
BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P)
2525
endif
2626

27+
$(BBLASOBJS) : override CFLAGS += -DBFLOAT16_ONLY -UDOUBLE -UCOMPLEX -UBFLOAT16 -USMALL_MATRIX_OPT
2728
$(SBBLASOBJS) $(SBBLASOBJS_P) : override CFLAGS += -DBFLOAT16 -UDOUBLE -UCOMPLEX
2829
$(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE -UCOMPLEX
2930
$(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE -UCOMPLEX
@@ -42,6 +43,7 @@ $(ZBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
4243
$(XBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
4344
$(SBEXTOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
4445

46+
4547
libs :: $(BLASOBJS) $(COMMONOBJS)
4648
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^
4749

cblas.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,7 @@ void cblas_sbgemm_batch(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST en
475475
OPENBLAS_CONST float * alpha_array, OPENBLAS_CONST bfloat16 ** A_array, OPENBLAS_CONST blasint * lda_array, OPENBLAS_CONST bfloat16 ** B_array, OPENBLAS_CONST blasint * ldb_array, OPENBLAS_CONST float * beta_array, float ** C_array, OPENBLAS_CONST blasint * ldc_array, OPENBLAS_CONST blasint group_count, OPENBLAS_CONST blasint * group_size);
476476

477477
void cblas_bgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
478-
OPENBLAS_CONST bfloat16 alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST bfloat16 beta, bfloat16 *C, OPENBLAS_CONST blasint ldc);
478+
OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, bfloat16 *C, OPENBLAS_CONST blasint ldc);
479479

480480
#ifdef __cplusplus
481481
}

common_interface.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -481,8 +481,8 @@ void BLASFUNC(xhbmv)(char *, blasint *, blasint *, xdouble *, xdouble *, blasint
481481
xdouble *, blasint *, xdouble *, xdouble *, blasint *);
482482

483483
/* Level 3 routines */
484-
void BLASFUNC(bgemm)(char *, char *, blasint *, blasint *, blasint *, bfloat16 *,
485-
bfloat16 *, blasint *, bfloat16 *, blasint *, bfloat16 *, bfloat16 *, blasint *);
484+
void BLASFUNC(bgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
485+
bfloat16 *, blasint *, bfloat16 *, blasint *, float *, bfloat16 *, blasint *);
486486
void BLASFUNC(sbgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
487487
bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *);
488488
void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *,

common_level3.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ void sgemm_direct(BLASLONG M, BLASLONG N, BLASLONG K,
5454

5555
int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
5656

57-
int bgemm_beta(BLASLONG, BLASLONG, BLASLONG, bfloat16,
57+
int bgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
5858
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG);
5959
int sbgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
6060
bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG);
@@ -513,7 +513,7 @@ int xher2k_kernel_LN(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdoubl
513513
int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag);
514514

515515
// add bgemm kernel
516-
int bgemm_kernel(BLASLONG, BLASLONG, BLASLONG, bfloat16, bfloat16 *, bfloat16 *, bfloat16 *, BLASLONG);
516+
int bgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, bfloat16 *, BLASLONG);
517517
int sbgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG);
518518
int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG);
519519
int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG);

common_param.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ typedef struct {
5454
int bgemm_unroll_m, bgemm_unroll_n, bgemm_unroll_mn;
5555
int bgemm_align_k;
5656

57-
int (*bgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, bfloat16, bfloat16 *, bfloat16 *, bfloat16 *, BLASLONG);
58-
int (*bgemm_beta )(BLASLONG, BLASLONG, BLASLONG, bfloat16, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG);
57+
int (*bgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, bfloat16 *, BLASLONG);
58+
int (*bgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG);
5959

6060
int (*bgemm_incopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *);
6161
int (*bgemm_itcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *);

driver/level3/Makefile

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,13 @@ ifeq ($(BUILD_BFLOAT16),1)
5252
SBBLASOBJS += sbgemm_nn.$(SUFFIX) sbgemm_nt.$(SUFFIX) sbgemm_tn.$(SUFFIX) sbgemm_tt.$(SUFFIX)
5353
endif
5454

55+
ifeq ($(BUILD_BFLOAT16_ONLY),1)
56+
BBLASOBJS += bgemm_nn.$(SUFFIX) bgemm_nt.$(SUFFIX) bgemm_tn.$(SUFFIX) bgemm_tt.$(SUFFIX)
57+
endif
58+
59+
BLASOBJS += \
60+
gemm_nn.$(SUFFIX) gemm_nt.$(SUFFIX) gemm_tn.$(SUFFIX) gemm_tt.$(SUFFIX)
61+
5562
SBLASOBJS += \
5663
sgemm_nn.$(SUFFIX) sgemm_nt.$(SUFFIX) sgemm_tn.$(SUFFIX) sgemm_tt.$(SUFFIX) \
5764
strmm_LNUU.$(SUFFIX) strmm_LNUN.$(SUFFIX) strmm_LNLU.$(SUFFIX) strmm_LNLN.$(SUFFIX) \
@@ -376,6 +383,18 @@ endif
376383

377384
all ::
378385

386+
bgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h
387+
$(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
388+
389+
bgemm_nt.$(SUFFIX) : gemm.c level3.c ../../param.h
390+
$(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
391+
392+
bgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h
393+
$(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
394+
395+
bgemm_tt.$(SUFFIX) : gemm.c level3.c ../../param.h
396+
$(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
397+
379398
sbgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h
380399
$(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
381400

@@ -432,8 +451,8 @@ cgemm_nt.$(SUFFIX) : gemm.c level3.c ../../param.h
432451

433452
cgemm_nr.$(SUFFIX) : gemm.c level3.c ../../param.h
434453
$(CC) $(CFLAGS) $(BLOCKS) -c -UDOUBLE -DCOMPLEX -DNR $< -o $(@F)
435-
436454
cgemm_nc.$(SUFFIX) : gemm.c level3.c ../../param.h
455+
437456
$(CC) $(CFLAGS) $(BLOCKS) -c -UDOUBLE -DCOMPLEX -DNC $< -o $(@F)
438457

439458
cgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h

driver/level3/level3.c

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,11 @@
4343
#if !defined(XDOUBLE) || !defined(QUAD_PRECISION)
4444
#ifndef COMPLEX
4545
#define BETA_OPERATION(M_FROM, M_TO, N_FROM, N_TO, BETA, C, LDC) \
46-
GEMM_BETA((M_TO) - (M_FROM), (N_TO - N_FROM), 0, \
47-
BETA[0], NULL, 0, NULL, 0, \
48-
(FLOAT *)(C) + ((M_FROM) + (N_FROM) * (LDC)) * COMPSIZE, LDC)
46+
do { \
47+
GEMM_BETA((M_TO) - (M_FROM), (N_TO - N_FROM), 0, \
48+
BETA[0], NULL, 0, NULL, 0, \
49+
(FLOAT *)(C) + ((M_FROM) + (N_FROM) * (LDC)) * COMPSIZE, LDC); \
50+
} while (0)
4951
#else
5052
#define BETA_OPERATION(M_FROM, M_TO, N_FROM, N_TO, BETA, C, LDC) \
5153
GEMM_BETA((M_TO) - (M_FROM), (N_TO - N_FROM), 0, \
@@ -189,7 +191,11 @@
189191
int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
190192
XFLOAT *sa, XFLOAT *sb, BLASLONG dummy){
191193
BLASLONG k, lda, ldb, ldc;
194+
#if defined(BUILD_BFLOAT16_ONLY)
195+
float *alpha, *beta;
196+
#else
192197
FLOAT *alpha, *beta;
198+
#endif
193199
IFLOAT *a, *b;
194200
FLOAT *c;
195201
BLASLONG m_from, m_to, n_from, n_to;
@@ -224,8 +230,14 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
224230
ldb = LDB;
225231
ldc = LDC;
226232

233+
#if defined(BUILD_BFLOAT16_ONLY)
234+
alpha = (float *)args -> alpha;
235+
beta = (float *)args -> beta;
236+
#else
227237
alpha = (FLOAT *)args -> alpha;
228238
beta = (FLOAT *)args -> beta;
239+
#endif
240+
229241

230242
m_from = 0;
231243
m_to = M;

driver/level3/level3_thread.c

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,11 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
239239
BLASLONG k, lda, ldb, ldc;
240240
BLASLONG m_from, m_to, n_from, n_to;
241241

242+
#if defined(BUILD_BFLOAT16_ONLY)
243+
float *alpha, *beta;
244+
#else
242245
FLOAT *alpha, *beta;
246+
#endif
243247
IFLOAT *a, *b;
244248
FLOAT *c;
245249
job_t *job = (job_t *)args -> common;
@@ -277,8 +281,14 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
277281
ldb = LDB;
278282
ldc = LDC;
279283

284+
#if defined(BUILD_BFLOAT16_ONLY)
285+
alpha = (float *)args -> alpha;
286+
beta = (float *)args -> beta;
287+
#else
280288
alpha = (FLOAT *)args -> alpha;
281289
beta = (FLOAT *)args -> beta;
290+
#endif
291+
282292

283293
/* Initialize 2D CPU distribution */
284294
nthreads_m = args -> nthreads;

0 commit comments

Comments
 (0)