Skip to content

Commit 7f89c6f

Browse files
authored
smh-based direct sgemm currently requires leading dimensions to be same as matrix dimension
1 parent 1ee8879 commit 7f89c6f

1 file changed

Lines changed: 6 additions & 5 deletions

File tree

interface/gemm.c

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ void NAME(char *TRANSA, char *TRANSB,
266266

267267
int transa, transb, nrowa, nrowb;
268268
blasint info;
269+
int order = -1;
269270

270271
char transA, transB;
271272
IFLOAT *buffer;
@@ -557,15 +558,16 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
557558
if (strcmp(gotoblas_corename(), "armv9sme") == 0 || strcmp(gotoblas_corename(), "vortexm4") == 0)
558559
// if (support_sme1())
559560
#endif
560-
if (order == CblasRowMajor && beta == 0 && alpha == 1.0 && TransA == CblasNoTrans && TransB == CblasNoTrans&& SGEMM_DIRECT_PERFORMANT(m,n,k)) {
561+
if (order == CblasRowMajor && m==lda && n ==ldb && k==ldc && beta == 0 && alpha == 1.0 && TransA == CblasNoTrans && TransB == CblasNoTrans&& SGEMM_DIRECT_PERFORMANT(m,n,k)) {
561562
SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc);
562563
return;
563564
}
564565
else
565-
if (order == CblasRowMajor && beta != 0. && (!(alpha==1.&&beta==1.)) && TransA == CblasNoTrans && TransB == CblasNoTrans&& SGEMM_DIRECT_PERFORMANT(m,n,k)) {
566+
if (order == CblasRowMajor && m==lda && n==ldb && k==ldc && TransA == CblasNoTrans && TransB == CblasNoTrans&& SGEMM_DIRECT_PERFORMANT(m,n,k)) {
566567
SGEMM_DIRECT_ALPHA_BETA(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
567568
return;
568569
}
570+
569571
#endif
570572
#endif
571573

@@ -587,9 +589,6 @@ else
587589

588590
if ((args.m == 0) || (args.n == 0)) return;
589591

590-
591-
592-
593592
#if 0
594593
fprintf(stderr, "m = %4d n = %d k = %d lda = %4d ldb = %4d ldc = %4d\n",
595594
args.m, args.n, args.k, args.lda, args.ldb, args.ldc);
@@ -626,6 +625,7 @@ else
626625
}
627626
bool is_efficient_gemv = have_tuned_gemv || ((NT == 'N') || (NT == 'T' && inc_x == 1));
628627
if (is_efficient_gemv) {
628+
fprintf(stderr,"gemv_forwarding\n");
629629
GEMV(&NT, &m, &n, args.alpha, args.a, &lda, args.b, &inc_x, args.beta, args.c, &inc_y);
630630
return;
631631
}
@@ -649,6 +649,7 @@ else
649649
}
650650
bool is_efficient_gemv = have_tuned_gemv || ((NT == 'N' && inc_y == 1) || (NT == 'T' && inc_x == 1));
651651
if (is_efficient_gemv) {
652+
fprintf(stderr,"gemv_forwarding\n");
652653
GEMV(&NT, &m, &n, args.alpha, args.b, &ldb, args.a, &inc_x, args.beta, args.c, &inc_y);
653654
return;
654655
}

0 commit comments

Comments
 (0)