Skip to content

Commit 5e71cef

Browse files
committed
Optimize gemv_n_sve kernel
1 parent: ef9e3f7 · commit: 5e71cef

2 files changed

Lines changed: 78 additions & 13 deletions

File tree

kernel/arm64/KERNEL.ARMV8SVE

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11

2+
23
CSUMKERNEL = csum_thunderx2t99.c
34
ZSUMKERNEL = zsum_thunderx2t99.c
45
SAMINKERNEL = ../arm/amin.c
@@ -74,7 +75,7 @@ DSCALKERNEL = scal.S
7475
CSCALKERNEL = zscal.S
7576
ZSCALKERNEL = zscal.S
7677

77-
SGEMVNKERNEL = gemv_n.S
78+
SGEMVNKERNEL = gemv_n_sve.c
7879
DGEMVNKERNEL = gemv_n.S
7980
CGEMVNKERNEL = zgemv_n.S
8081
ZGEMVNKERNEL = zgemv_n.S

kernel/arm64/gemv_n_sve.c

Lines changed: 76 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
/***************************************************************************
2-
Copyright (c) 2024, The OpenBLAS Project
2+
Copyright (c) 2024-2025, The OpenBLAS Project
33
All rights reserved.
44
55
Redistribution and use in source and binary forms, with or without
@@ -57,25 +57,89 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
5757

5858
ix = 0;
5959
a_ptr = a;
60-
6160
if (inc_y == 1) {
61+
BLASLONG width = n / 3;
6262
uint64_t sve_size = SV_COUNT();
63-
for (j = 0; j < n; j++) {
64-
SV_TYPE temp_vec = SV_DUP(alpha * x[ix]);
63+
svbool_t pg_true = SV_TRUE();
64+
svbool_t pg = SV_WHILE(0, m % sve_size);
65+
66+
FLOAT *a0_ptr = a + lda * width * 0;
67+
FLOAT *a1_ptr = a + lda * width * 1;
68+
FLOAT *a2_ptr = a + lda * width * 2;
69+
70+
for (j = 0; j < width; j++) {
6571
i = 0;
66-
svbool_t pg = SV_WHILE(i, m);
67-
while (svptest_any(SV_TRUE(), pg)) {
68-
SV_TYPE a_vec = svld1(pg, a_ptr + i);
72+
while ((i + sve_size * 1 - 1) < m) {
73+
ix = j * inc_x;
74+
75+
SV_TYPE x0_vec = SV_DUP(alpha * x[ix + (inc_x * width * 0)]);
76+
SV_TYPE x1_vec = SV_DUP(alpha * x[ix + (inc_x * width * 1)]);
77+
SV_TYPE x2_vec = SV_DUP(alpha * x[ix + (inc_x * width * 2)]);
78+
79+
SV_TYPE a00_vec = svld1(pg_true, a0_ptr + i);
80+
SV_TYPE a01_vec = svld1(pg_true, a1_ptr + i);
81+
SV_TYPE a02_vec = svld1(pg_true, a2_ptr + i);
82+
83+
SV_TYPE y_vec = svld1(pg_true, y + i);
84+
y_vec = svmla_lane(y_vec, a00_vec, x0_vec, 0);
85+
y_vec = svmla_lane(y_vec, a01_vec, x1_vec, 0);
86+
y_vec = svmla_lane(y_vec, a02_vec, x2_vec, 0);
87+
88+
svst1(pg_true, y + i, y_vec);
89+
90+
i += sve_size * 1;
91+
}
92+
93+
if (i < m) {
94+
SV_TYPE x0_vec = SV_DUP(alpha * x[ix + (inc_x * width * 0)]);
95+
SV_TYPE x1_vec = SV_DUP(alpha * x[ix + (inc_x * width * 1)]);
96+
SV_TYPE x2_vec = SV_DUP(alpha * x[ix + (inc_x * width * 2)]);
97+
98+
SV_TYPE a00_vec = svld1(pg, a0_ptr + i);
99+
SV_TYPE a01_vec = svld1(pg, a1_ptr + i);
100+
SV_TYPE a02_vec = svld1(pg, a2_ptr + i);
101+
69102
SV_TYPE y_vec = svld1(pg, y + i);
70-
y_vec = svmla_x(pg, y_vec, temp_vec, a_vec);
103+
y_vec = svmla_m(pg, y_vec, a00_vec, x0_vec);
104+
y_vec = svmla_m(pg, y_vec, a01_vec, x1_vec);
105+
y_vec = svmla_m(pg, y_vec, a02_vec, x2_vec);
106+
107+
ix += inc_x;
108+
71109
svst1(pg, y + i, y_vec);
72-
i += sve_size;
73-
pg = SV_WHILE(i, m);
74110
}
111+
112+
a0_ptr += lda;
113+
a1_ptr += lda;
114+
a2_ptr += lda;
115+
}
116+
117+
a_ptr = a2_ptr;
118+
for (j = width * 3; j < n; j++) {
119+
ix = j * inc_x;
120+
i = 0;
121+
while ((i + sve_size * 1 - 1) < m) {
122+
SV_TYPE y_vec = svld1(pg_true, y + i);
123+
SV_TYPE x_vec = SV_DUP(alpha * x[(ix)]);
124+
SV_TYPE a_vec = svld1(pg_true, a_ptr + i);
125+
y_vec = svmla_x(pg_true, y_vec, a_vec, x_vec);
126+
svst1(pg_true, y + i, y_vec);
127+
i += sve_size * 1;
128+
}
129+
130+
if (i < m) {
131+
SV_TYPE y_vec = svld1(pg, y + i);
132+
SV_TYPE x_vec = SV_DUP(alpha * x[(ix)]);
133+
SV_TYPE a_vec = svld1(pg, a_ptr + i);
134+
y_vec = svmla_m(pg, y_vec, a_vec, x_vec);
135+
svst1(pg, y + i, y_vec);
136+
}
137+
75138
a_ptr += lda;
76139
ix += inc_x;
77140
}
78-
return(0);
141+
142+
return (0);
79143
}
80144

81145
for (j = 0; j < n; j++) {
@@ -89,4 +153,4 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
89153
ix += inc_x;
90154
}
91155
return (0);
92-
}
156+
}

0 commit comments

Comments
 (0)