@@ -18,15 +18,15 @@ static int CeedTensorContractApply_Xsmm(CeedTensorContract contract, CeedInt A,
1818 CeedTransposeMode t_mode , const CeedInt add , const CeedScalar * restrict u , CeedScalar * restrict v ) {
1919 if (C == 1 ) {
2020 // Build or query the required kernel
21- const int flags_t = LIBXSMM_GEMM_FLAGS (!t_mode ? 'T' : 'N' , 'N' );
22- const int flags_ab = (!add ) ? LIBXSMM_GEMM_FLAG_BETA_0 : LIBXSMM_BASIC_GEMM_FLAG_NONE ;
23- const int flags = (flags_t | flags_ab );
24- const libxsmm_gemm_shape gemm_shape = (CEED_SCALAR_TYPE == CEED_SCALAR_FP64 )
25- ? libxsmm_create_gemm_shape (J , A , B , !t_mode ? B : J , B , J , LIBXSMM_DATATYPE_F64 , LIBXSMM_DATATYPE_F64 ,
26- LIBXSMM_DATATYPE_F64 , LIBXSMM_DATATYPE_F64 )
27- : libxsmm_create_gemm_shape (J , A , B , !t_mode ? B : J , B , J , LIBXSMM_DATATYPE_F32 , LIBXSMM_DATATYPE_F32 ,
28- LIBXSMM_DATATYPE_F32 , LIBXSMM_DATATYPE_F32 );
29- const libxsmm_gemmfunction kernel = libxsmm_dispatch_gemm (gemm_shape , (libxsmm_bitfield )(flags ), (libxsmm_bitfield )LIBXSMM_GEMM_PREFETCH_NONE );
21+ const int flags_t = LIBXSMM_GEMM_FLAGS (!t_mode ? 'T' : 'N' , 'N' );
22+ const int flags_ab = (!add ) ? LIBXSMM_GEMM_FLAG_BETA_0 : LIBXSMM_BASIC_GEMM_FLAG_NONE ;
23+ const int flags = (flags_t | flags_ab );
24+ const libxsmm_gemm_shape gemm_shape = (CEED_SCALAR_TYPE == CEED_SCALAR_FP64 )
25+ ? libxsmm_create_gemm_shape (J , A , B , !t_mode ? B : J , B , J , LIBXSMM_DATATYPE_F64 , LIBXSMM_DATATYPE_F64 ,
26+ LIBXSMM_DATATYPE_F64 , LIBXSMM_DATATYPE_F64 )
27+ : libxsmm_create_gemm_shape (J , A , B , !t_mode ? B : J , B , J , LIBXSMM_DATATYPE_F32 , LIBXSMM_DATATYPE_F32 ,
28+ LIBXSMM_DATATYPE_F32 , LIBXSMM_DATATYPE_F32 );
29+ const libxsmm_gemmfunction kernel = libxsmm_dispatch_gemm (gemm_shape , (libxsmm_bitfield )(flags ), (libxsmm_bitfield )LIBXSMM_GEMM_PREFETCH_NONE );
3030 libxsmm_gemm_param gemm_param ;
3131
3232 CeedCheck (kernel , CeedTensorContractReturnCeed (contract ), CEED_ERROR_BACKEND , "LIBXSMM kernel failed to build." );
@@ -38,15 +38,15 @@ static int CeedTensorContractApply_Xsmm(CeedTensorContract contract, CeedInt A,
3838 kernel (& gemm_param );
3939 } else {
4040 // Build or query the required kernel
41- const int flags_t = LIBXSMM_GEMM_FLAGS ('N' , t_mode ? 'T' : 'N' );
42- const int flags_ab = (!add ) ? LIBXSMM_GEMM_FLAG_BETA_0 : LIBXSMM_BASIC_GEMM_FLAG_NONE ;
43- const int flags = (flags_t | flags_ab );
44- const libxsmm_gemm_shape gemm_shape = (CEED_SCALAR_TYPE == CEED_SCALAR_FP64 )
45- ? libxsmm_create_gemm_shape (C , J , B , C , !t_mode ? B : J , C , LIBXSMM_DATATYPE_F64 , LIBXSMM_DATATYPE_F64 ,
46- LIBXSMM_DATATYPE_F64 , LIBXSMM_DATATYPE_F64 )
47- : libxsmm_create_gemm_shape (C , J , B , C , !t_mode ? B : J , C , LIBXSMM_DATATYPE_F32 , LIBXSMM_DATATYPE_F32 ,
48- LIBXSMM_DATATYPE_F32 , LIBXSMM_DATATYPE_F32 );
49- const libxsmm_gemmfunction kernel = libxsmm_dispatch_gemm (gemm_shape , (libxsmm_bitfield )(flags ), (libxsmm_bitfield )LIBXSMM_GEMM_PREFETCH_NONE );
41+ const int flags_t = LIBXSMM_GEMM_FLAGS ('N' , t_mode ? 'T' : 'N' );
42+ const int flags_ab = (!add ) ? LIBXSMM_GEMM_FLAG_BETA_0 : LIBXSMM_BASIC_GEMM_FLAG_NONE ;
43+ const int flags = (flags_t | flags_ab );
44+ const libxsmm_gemm_shape gemm_shape = (CEED_SCALAR_TYPE == CEED_SCALAR_FP64 )
45+ ? libxsmm_create_gemm_shape (C , J , B , C , !t_mode ? B : J , C , LIBXSMM_DATATYPE_F64 , LIBXSMM_DATATYPE_F64 ,
46+ LIBXSMM_DATATYPE_F64 , LIBXSMM_DATATYPE_F64 )
47+ : libxsmm_create_gemm_shape (C , J , B , C , !t_mode ? B : J , C , LIBXSMM_DATATYPE_F32 , LIBXSMM_DATATYPE_F32 ,
48+ LIBXSMM_DATATYPE_F32 , LIBXSMM_DATATYPE_F32 );
49+ const libxsmm_gemmfunction kernel = libxsmm_dispatch_gemm (gemm_shape , (libxsmm_bitfield )(flags ), (libxsmm_bitfield )LIBXSMM_GEMM_PREFETCH_NONE );
5050 libxsmm_gemm_param gemm_param ;
5151
5252 CeedCheck (kernel , CeedTensorContractReturnCeed (contract ), CEED_ERROR_BACKEND , "LIBXSMM kernel failed to build." );
0 commit comments