Skip to content

Commit 010f24f

Browse files
committed
Better K.
1 parent b0ee407 commit 010f24f

1 file changed

Lines changed: 23 additions & 29 deletions

File tree

kernel/riscv64/sgemm_kernel_16x8_zvl256b.c

Lines changed: 23 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -590,15 +590,10 @@ static void FORCEINLINE M_TAIL_ONE(BLASLONG K, const BLASLONG M, const BLASLONG
590590
A3 += 1;
591591
}
592592
#endif
593+
K--;
593594
}
594595

595-
#ifdef GEMM_RIGHT_CHUNK
596-
BLASLONG k = (M <= 8) ? 0 : 1;
597-
#else
598-
BLASLONG k = 1;
599-
#endif
600-
601-
for (; k < K; k++) {
596+
while (K--) {
602597
if (!S2) {
603598
B0 = __riscv_vle32_v_f32m1(B, N);
604599
}
@@ -987,7 +982,7 @@ static void FORCEINLINE M_TAIL_ONE(BLASLONG K, const BLASLONG M, const BLASLONG
987982
}
988983
#endif
989984

990-
for (BLASLONG k = 1; k < K; k++) {
985+
while (--K) {
991986
if (S2 || S3) {
992987
result03 = __riscv_vle32_v_f32m1(A0, 8);
993988
}
@@ -1506,12 +1501,21 @@ static void FORCEINLINE N_TAIL_ONE(BLASLONG K, BLASLONG M, const BLASLONG N, FLO
15061501
B04 = B + ((N & 6) * K);
15071502
}
15081503
#endif
1509-
#ifdef GEMM_BOTTOM_CHUNK
15101504
FLOAT K2;
1511-
if (N <= 4) {
1512-
K2 = K;
1513-
}
1505+
#ifdef GEMM_BOTTOM_CHUNK
1506+
FLOAT K3;
1507+
if (N == 1) {
1508+
K3 = (K / 8);
1509+
K &= 7;
1510+
} else if (N <= 4) {
1511+
K3 = (K / 2);
1512+
K &= 1;
1513+
} else
15141514
#endif
1515+
{
1516+
K--;
1517+
}
1518+
K2 = K;
15151519
do {
15161520
FLOAT B0, B1, B2, B3, B4, B5, B6;
15171521
#ifdef GEMM_NEW_PACKING
@@ -1538,12 +1542,9 @@ static void FORCEINLINE N_TAIL_ONE(BLASLONG K, BLASLONG M, const BLASLONG N, FLO
15381542
vfloat32m1_t A2, A3, A4, A5, A6, A7;
15391543
vfloat32m1_t resultE, resultF;
15401544
FLOAT B7;
1541-
if (N <= 4) {
1542-
K = K2;
1543-
}
15441545

15451546
if (N == 1) {
1546-
if (K >= 8) {
1547+
if (K3) {
15471548
vfloat32m8_t A01 = __riscv_vle32_v_f32m8(A, 8 * 8);
15481549
A0 = __riscv_vget_v_f32m8_f32m1(A01, 0);
15491550
A1 = __riscv_vget_v_f32m8_f32m1(A01, 1);
@@ -1605,7 +1606,7 @@ static void FORCEINLINE N_TAIL_ONE(BLASLONG K, BLASLONG M, const BLASLONG N, FLO
16051606
resultE = __riscv_vfmul_vf_f32m1(A6, B7, 8);
16061607
resultF = __riscv_vfmul_vf_f32m1(A7, B7, 8);
16071608

1608-
for (BLASLONG k = (K / 8); --k; ) {
1609+
for (BLASLONG k = K3; --k; ) {
16091610
A01 = __riscv_vle32_v_f32m8(A, 8 * 8);
16101611
A0 = __riscv_vget_v_f32m8_f32m1(A01, 0);
16111612
A1 = __riscv_vget_v_f32m8_f32m1(A01, 1);
@@ -1682,14 +1683,12 @@ static void FORCEINLINE N_TAIL_ONE(BLASLONG K, BLASLONG M, const BLASLONG N, FLO
16821683
result1 = __riscv_vfadd_vv_f32m1(result1, result5, 8);
16831684
resultC = __riscv_vfadd_vv_f32m1(resultC, result0, 8);
16841685
resultD = __riscv_vfadd_vv_f32m1(resultD, result1, 8);
1685-
1686-
K &= 7;
16871686
} else {
16881687
resultC = __riscv_vreinterpret_v_u32m1_f32m1(__riscv_vmv_v_x_u32m1(0, 8));
16891688
resultD = __riscv_vreinterpret_v_u32m1_f32m1(__riscv_vmv_v_x_u32m1(0, 8));
16901689
}
16911690
} else if (N <= 4) {
1692-
if (K >= 2) {
1691+
if (K3) {
16931692
vfloat32m4_t A01 = __riscv_vle32_v_f32m4(A, 4 * 8);
16941693
A0 = __riscv_vget_v_f32m4_f32m1(A01, 0);
16951694
A1 = __riscv_vget_v_f32m4_f32m1(A01, 1);
@@ -1772,7 +1771,7 @@ static void FORCEINLINE N_TAIL_ONE(BLASLONG K, BLASLONG M, const BLASLONG N, FLO
17721771
result5 = __riscv_vfmul_vf_f32m1(A3, B6, 8);
17731772
}
17741773

1775-
for (BLASLONG k = (K / 2); --k; ) {
1774+
for (BLASLONG k = K3; --k; ) {
17761775
A01 = __riscv_vle32_v_f32m4(A, 4 * 8);
17771776
A0 = __riscv_vget_v_f32m4_f32m1(A01, 0);
17781777
A1 = __riscv_vget_v_f32m4_f32m1(A01, 1);
@@ -1876,8 +1875,6 @@ static void FORCEINLINE N_TAIL_ONE(BLASLONG K, BLASLONG M, const BLASLONG N, FLO
18761875
resultC = __riscv_vfadd_vv_f32m1(resultC, result4, 8);
18771876
resultD = __riscv_vfadd_vv_f32m1(resultD, result5, 8);
18781877
}
1879-
1880-
K &= 1;
18811878
} else {
18821879
if (N == 4) {
18831880
result0 = __riscv_vreinterpret_v_u32m1_f32m1(__riscv_vmv_v_x_u32m1(0, 8));
@@ -1961,12 +1958,7 @@ static void FORCEINLINE N_TAIL_ONE(BLASLONG K, BLASLONG M, const BLASLONG N, FLO
19611958
}
19621959
}
19631960

1964-
#ifdef GEMM_BOTTOM_CHUNK
1965-
BLASLONG k = (N <= 4) ? 0 : 1;
1966-
#else
1967-
BLASLONG k = 1;
1968-
#endif
1969-
for (; k < K; k++) {
1961+
while (K--) {
19701962
if (N & 4) {
19711963
B0 = B00[0];
19721964
B1 = B00[1];
@@ -2120,6 +2112,7 @@ static void FORCEINLINE N_TAIL_ONE(BLASLONG K, BLASLONG M, const BLASLONG N, FLO
21202112
}
21212113

21222114
C = C0 + 16;
2115+
K = K2;
21232116
} while (--M);
21242117
}
21252118

@@ -2221,6 +2214,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT* A, FLOAT* B, F
22212214
BLASLONG n_top = 0;
22222215
const BLASLONG m_edge = M & 15;
22232216
const bool S = (M == (ldc & 0xF));
2217+
if (K <= 0) return 0;
22242218

22252219
// -- MAIN PASS
22262220

0 commit comments

Comments
 (0)