Skip to content

Commit 835dd68

Browse files
authored
Replace the interleaved gemvn_sve with a sequential version
1 parent: 93515c2 · commit: 835dd68

1 file changed

Lines changed: 60 additions & 62 deletions

File tree

kernel/arm64/gemv_n_sve_v1x3.c

Lines changed: 60 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -13,9 +13,9 @@ modification, are permitted provided that the following conditions are
1313
notice, this list of conditions and the following disclaimer in
1414
the documentation and/or other materials provided with the
1515
distribution.
16-
3. Neither the name of the OpenBLAS project nor the names of
17-
its contributors may be used to endorse or promote products
18-
derived from this software without specific prior written
16+
3. Neither the name of the OpenBLAS project nor the names of
17+
its contributors may be used to endorse or promote products
18+
derived from this software without specific prior written
1919
permission.
2020
2121
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
@@ -52,96 +52,93 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a,
5252
BLASLONG lda, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y,
5353
FLOAT *buffer)
5454
{
55-
BLASLONG i, j;
56-
BLASLONG ix = 0;
57-
BLASLONG iy;
58-
FLOAT *a_ptr = a;
55+
BLASLONG i;
56+
BLASLONG ix, iy;
57+
BLASLONG j;
58+
FLOAT *a_ptr;
5959
FLOAT temp;
6060

61+
a_ptr = a;
62+
ix = 0;
63+
6164
if (inc_y == 1) {
62-
BLASLONG width = n / 3; // Only process full 3-column blocks
6365
BLASLONG sve_size = SV_COUNT();
64-
svbool_t pg_full = SV_TRUE();
65-
svbool_t pg_tail = SV_WHILE(0, m % sve_size);
66-
67-
FLOAT *a0_ptr = a_ptr + lda * width * 0;
68-
FLOAT *a1_ptr = a_ptr + lda * width * 1;
69-
FLOAT *a2_ptr = a_ptr + lda * width * 2;
66+
svbool_t pg_true = SV_TRUE();
7067

71-
FLOAT *x0_ptr = x + inc_x * width * 0;
72-
FLOAT *x1_ptr = x + inc_x * width * 1;
73-
FLOAT *x2_ptr = x + inc_x * width * 2;
68+
/* Process 3 consecutive columns at a time: (j, j+1, j+2) */
69+
for (j = 0; j + 2 < n; j += 3) {
70+
SV_TYPE temp0_vec = SV_DUP(alpha * x[ix]);
71+
SV_TYPE temp1_vec = SV_DUP(alpha * x[ix + inc_x]);
72+
SV_TYPE temp2_vec = SV_DUP(alpha * x[ix + inc_x * 2]);
7473

75-
for (j = 0; j < width; j++) {
76-
SV_TYPE temp0_vec = SV_DUP(alpha * x0_ptr[ix]);
77-
SV_TYPE temp1_vec = SV_DUP(alpha * x1_ptr[ix]);
78-
SV_TYPE temp2_vec = SV_DUP(alpha * x2_ptr[ix]);
74+
FLOAT *a0 = a_ptr;
75+
FLOAT *a1 = a_ptr + lda;
76+
FLOAT *a2 = a_ptr + lda * 2;
7977

8078
i = 0;
8179
while ((i + sve_size - 1) < m) {
82-
SV_TYPE y0_vec = svld1(pg_full, y + i);
80+
SV_TYPE y0_vec = svld1(pg_true, y + i);
8381

84-
SV_TYPE a00_vec = svld1(pg_full, a0_ptr + i);
85-
SV_TYPE a01_vec = svld1(pg_full, a1_ptr + i);
86-
SV_TYPE a02_vec = svld1(pg_full, a2_ptr + i);
82+
SV_TYPE a00_vec = svld1(pg_true, a0 + i);
83+
SV_TYPE a01_vec = svld1(pg_true, a1 + i);
84+
SV_TYPE a02_vec = svld1(pg_true, a2 + i);
8785

88-
y0_vec = svmla_x(pg_full, y0_vec, temp0_vec, a00_vec);
89-
y0_vec = svmla_x(pg_full, y0_vec, temp1_vec, a01_vec);
90-
y0_vec = svmla_x(pg_full, y0_vec, temp2_vec, a02_vec);
86+
y0_vec = svmla_m(pg_true, y0_vec, temp0_vec, a00_vec);
87+
y0_vec = svmla_m(pg_true, y0_vec, temp1_vec, a01_vec);
88+
y0_vec = svmla_m(pg_true, y0_vec, temp2_vec, a02_vec);
9189

92-
svst1(pg_full, y + i, y0_vec);
90+
svst1(pg_true, y + i, y0_vec);
9391
i += sve_size;
9492
}
9593

9694
if (i < m) {
97-
SV_TYPE y0_vec = svld1(pg_tail, y + i);
95+
svbool_t pg = SV_WHILE(i, m);
96+
97+
SV_TYPE y0_vec = svld1(pg, y + i);
9898

99-
SV_TYPE a00_vec = svld1(pg_tail, a0_ptr + i);
100-
SV_TYPE a01_vec = svld1(pg_tail, a1_ptr + i);
101-
SV_TYPE a02_vec = svld1(pg_tail, a2_ptr + i);
99+
SV_TYPE a00_vec = svld1(pg, a0 + i);
100+
SV_TYPE a01_vec = svld1(pg, a1 + i);
101+
SV_TYPE a02_vec = svld1(pg, a2 + i);
102102

103-
y0_vec = svmla_m(pg_tail, y0_vec, temp0_vec, a00_vec);
104-
y0_vec = svmla_m(pg_tail, y0_vec, temp1_vec, a01_vec);
105-
y0_vec = svmla_m(pg_tail, y0_vec, temp2_vec, a02_vec);
103+
y0_vec = svmla_m(pg, y0_vec, temp0_vec, a00_vec);
104+
y0_vec = svmla_m(pg, y0_vec, temp1_vec, a01_vec);
105+
y0_vec = svmla_m(pg, y0_vec, temp2_vec, a02_vec);
106106

107-
svst1(pg_tail, y + i, y0_vec);
107+
svst1(pg, y + i, y0_vec);
108108
}
109-
a0_ptr += lda;
110-
a1_ptr += lda;
111-
a2_ptr += lda;
112-
ix += inc_x;
109+
110+
a_ptr += lda * 3;
111+
ix += inc_x * 3;
113112
}
114-
// Handle remaining n % 3 columns
115-
for (j = width * 3; j < n; j++) {
116-
FLOAT *a_col = a + j * lda;
117-
temp = alpha * x[j * inc_x];
118-
SV_TYPE temp_vec = SV_DUP(temp);
113+
114+
/* Cleanup: remaining 1 or 2 columns */
115+
for (; j < n; j++) {
116+
SV_TYPE temp_vec = SV_DUP(alpha * x[ix]);
119117

120118
i = 0;
121119
while ((i + sve_size - 1) < m) {
122-
SV_TYPE y_vec = svld1(pg_full, y + i);
123-
124-
SV_TYPE a_vec = svld1(pg_full, a_col + i);
125-
126-
y_vec = svmla_x(pg_full, y_vec, temp_vec, a_vec);
127-
128-
svst1(pg_full, y + i, y_vec);
120+
SV_TYPE y_vec = svld1(pg_true, y + i);
121+
SV_TYPE a_vec = svld1(pg_true, a_ptr + i);
122+
y_vec = svmla_m(pg_true, y_vec, temp_vec, a_vec);
123+
svst1(pg_true, y + i, y_vec);
129124
i += sve_size;
130125
}
131-
if (i < m) {
132-
SV_TYPE y_vec = svld1(pg_tail, y + i);
133-
134-
SV_TYPE a_vec = svld1(pg_tail, a_col + i);
135126

136-
y_vec = svmla_m(pg_tail, y_vec, temp_vec, a_vec);
137-
138-
svst1(pg_tail, y + i, y_vec);
127+
if (i < m) {
128+
svbool_t pg = SV_WHILE(i, m);
129+
SV_TYPE y_vec = svld1(pg, y + i);
130+
SV_TYPE a_vec = svld1(pg, a_ptr + i);
131+
y_vec = svmla_m(pg, y_vec, temp_vec, a_vec);
132+
svst1(pg, y + i, y_vec);
139133
}
134+
135+
a_ptr += lda;
136+
ix += inc_x;
140137
}
141-
return(0);
138+
139+
return (0);
142140
}
143141

144-
// Fallback scalar loop
145142
for (j = 0; j < n; j++) {
146143
temp = alpha * x[ix];
147144
iy = 0;
@@ -154,3 +151,4 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a,
154151
}
155152
return (0);
156153
}
154+

0 commit comments

Comments
 (0)