Skip to content

Commit bf2c03e

Browse files
committed
Optimize fp32 RVV distance kernels
Signed-off-by: ihb2032 <hebome@foxmail.com>
1 parent 3b3f6a5 commit bf2c03e

1 file changed

Lines changed: 41 additions & 113 deletions

File tree

src/simd/distances_rvv.cc

Lines changed: 41 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -54,106 +54,49 @@ bf16_float_rvv(vfloat32m1_t f, size_t vl) {
5454
// =================== float distances ===================
5555
float
5656
fvec_inner_product_rvv(const float* x, const float* y, size_t d) {
57-
size_t vlmax = __riscv_vsetvlmax_e32m2(); // Use m2 to support 4-way parallelism
58-
59-
// 4 accumulators
60-
vfloat32m2_t vacc0 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
61-
vfloat32m2_t vacc1 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
62-
vfloat32m2_t vacc2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
63-
vfloat32m2_t vacc3 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
57+
const size_t vlmax = __riscv_vsetvlmax_e32m8();
58+
vfloat32m8_t acc = __riscv_vfmv_v_f_f32m8(0.0f, vlmax);
6459

6560
size_t offset = 0;
61+
while (offset < d) {
62+
const size_t vl = __riscv_vsetvl_e32m8(d - offset);
6663

67-
// 4-way unrolled loop
68-
while (d >= 4 * vlmax) {
69-
size_t vl = vlmax;
70-
71-
vfloat32m2_t vx0 = __riscv_vle32_v_f32m2(x + offset, vl);
72-
vfloat32m2_t vy0 = __riscv_vle32_v_f32m2(y + offset, vl);
73-
vfloat32m2_t vx1 = __riscv_vle32_v_f32m2(x + offset + vl, vl);
74-
vfloat32m2_t vy1 = __riscv_vle32_v_f32m2(y + offset + vl, vl);
75-
vfloat32m2_t vx2 = __riscv_vle32_v_f32m2(x + offset + 2 * vl, vl);
76-
vfloat32m2_t vy2 = __riscv_vle32_v_f32m2(y + offset + 2 * vl, vl);
77-
vfloat32m2_t vx3 = __riscv_vle32_v_f32m2(x + offset + 3 * vl, vl);
78-
vfloat32m2_t vy3 = __riscv_vle32_v_f32m2(y + offset + 3 * vl, vl);
79-
80-
// Parallel FMACC operations
81-
vacc0 = __riscv_vfmacc_vv_f32m2_tu(vacc0, vx0, vy0, vl);
82-
vacc1 = __riscv_vfmacc_vv_f32m2_tu(vacc1, vx1, vy1, vl);
83-
vacc2 = __riscv_vfmacc_vv_f32m2_tu(vacc2, vx2, vy2, vl);
84-
vacc3 = __riscv_vfmacc_vv_f32m2_tu(vacc3, vx3, vy3, vl);
64+
vfloat32m8_t vx = __riscv_vle32_v_f32m8(x + offset, vl);
65+
vfloat32m8_t vy = __riscv_vle32_v_f32m8(y + offset, vl);
8566

86-
offset += 4 * vl;
87-
d -= 4 * vl;
88-
}
89-
90-
// Merge accumulators
91-
vacc0 = __riscv_vfadd_vv_f32m2(vacc0, vacc1, vlmax);
92-
vacc2 = __riscv_vfadd_vv_f32m2(vacc2, vacc3, vlmax);
93-
vacc0 = __riscv_vfadd_vv_f32m2(vacc0, vacc2, vlmax);
94-
95-
// Handle remaining elements
96-
while (d > 0) {
97-
size_t vl = __riscv_vsetvl_e32m2(d);
98-
vfloat32m2_t vx = __riscv_vle32_v_f32m2(x + offset, vl);
99-
vfloat32m2_t vy = __riscv_vle32_v_f32m2(y + offset, vl);
100-
vacc0 = __riscv_vfmacc_vv_f32m2_tu(vacc0, vx, vy, vl);
67+
acc = __riscv_vfmacc_vv_f32m8_tu(acc, vx, vy, vl);
10168

10269
offset += vl;
103-
d -= vl;
10470
}
10571

106-
// Final reduction
107-
vfloat32m1_t sum_scalar = __riscv_vfmv_s_f_f32m1(0.0f, 1);
108-
sum_scalar = __riscv_vfredusum_vs_f32m2_f32m1(vacc0, sum_scalar, vlmax);
72+
vfloat32m1_t sum = __riscv_vfmv_s_f_f32m1(0.0f, 1);
73+
sum = __riscv_vfredusum_vs_f32m8_f32m1(acc, sum, vlmax);
10974

110-
return __riscv_vfmv_f_s_f32m1_f32(sum_scalar);
75+
return __riscv_vfmv_f_s_f32m1_f32(sum);
11176
}
11277

11378
float
11479
fvec_L2sqr_rvv(const float* x, const float* y, size_t d) {
115-
size_t vlmax = __riscv_vsetvlmax_e32m2();
116-
vfloat32m2_t vacc0 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
117-
vfloat32m2_t vacc1 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
118-
vfloat32m2_t vacc2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
119-
vfloat32m2_t vacc3 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
80+
const size_t vlmax = __riscv_vsetvlmax_e32m8();
81+
vfloat32m8_t acc = __riscv_vfmv_v_f_f32m8(0.0f, vlmax);
82+
12083
size_t offset = 0;
121-
while (d >= 4 * vlmax) {
122-
size_t vl = vlmax;
123-
vfloat32m2_t vx0 = __riscv_vle32_v_f32m2(x + offset, vl);
124-
vfloat32m2_t vy0 = __riscv_vle32_v_f32m2(y + offset, vl);
125-
vfloat32m2_t vx1 = __riscv_vle32_v_f32m2(x + offset + vl, vl);
126-
vfloat32m2_t vy1 = __riscv_vle32_v_f32m2(y + offset + vl, vl);
127-
vfloat32m2_t vx2 = __riscv_vle32_v_f32m2(x + offset + 2 * vl, vl);
128-
vfloat32m2_t vy2 = __riscv_vle32_v_f32m2(y + offset + 2 * vl, vl);
129-
vfloat32m2_t vx3 = __riscv_vle32_v_f32m2(x + offset + 3 * vl, vl);
130-
vfloat32m2_t vy3 = __riscv_vle32_v_f32m2(y + offset + 3 * vl, vl);
131-
vfloat32m2_t vtmp0 = __riscv_vfsub_vv_f32m2(vx0, vy0, vl);
132-
vfloat32m2_t vtmp1 = __riscv_vfsub_vv_f32m2(vx1, vy1, vl);
133-
vfloat32m2_t vtmp2 = __riscv_vfsub_vv_f32m2(vx2, vy2, vl);
134-
vfloat32m2_t vtmp3 = __riscv_vfsub_vv_f32m2(vx3, vy3, vl);
135-
vacc0 = __riscv_vfmacc_vv_f32m2_tu(vacc0, vtmp0, vtmp0, vl);
136-
vacc1 = __riscv_vfmacc_vv_f32m2_tu(vacc1, vtmp1, vtmp1, vl);
137-
vacc2 = __riscv_vfmacc_vv_f32m2_tu(vacc2, vtmp2, vtmp2, vl);
138-
vacc3 = __riscv_vfmacc_vv_f32m2_tu(vacc3, vtmp3, vtmp3, vl);
139-
offset += 4 * vl;
140-
d -= 4 * vl;
141-
}
142-
vacc0 = __riscv_vfadd_vv_f32m2(vacc0, vacc1, vlmax);
143-
vacc2 = __riscv_vfadd_vv_f32m2(vacc2, vacc3, vlmax);
144-
vacc0 = __riscv_vfadd_vv_f32m2(vacc0, vacc2, vlmax);
145-
while (d > 0) {
146-
size_t vl = __riscv_vsetvl_e32m2(d);
147-
vfloat32m2_t vx = __riscv_vle32_v_f32m2(x + offset, vl);
148-
vfloat32m2_t vy = __riscv_vle32_v_f32m2(y + offset, vl);
149-
vfloat32m2_t vtmp = __riscv_vfsub_vv_f32m2(vx, vy, vl);
150-
vacc0 = __riscv_vfmacc_vv_f32m2_tu(vacc0, vtmp, vtmp, vl);
84+
while (offset < d) {
85+
const size_t vl = __riscv_vsetvl_e32m8(d - offset);
86+
87+
vfloat32m8_t vx = __riscv_vle32_v_f32m8(x + offset, vl);
88+
vfloat32m8_t vy = __riscv_vle32_v_f32m8(y + offset, vl);
89+
90+
vx = __riscv_vfsub_vv_f32m8(vx, vy, vl);
91+
acc = __riscv_vfmacc_vv_f32m8_tu(acc, vx, vx, vl);
92+
15193
offset += vl;
152-
d -= vl;
15394
}
154-
vfloat32m1_t sum_scalar = __riscv_vfmv_s_f_f32m1(0.0f, 1);
155-
sum_scalar = __riscv_vfredusum_vs_f32m2_f32m1(vacc0, sum_scalar, vlmax);
156-
return __riscv_vfmv_f_s_f32m1_f32(sum_scalar);
95+
96+
vfloat32m1_t sum = __riscv_vfmv_s_f_f32m1(0.0f, 1);
97+
sum = __riscv_vfredusum_vs_f32m8_f32m1(acc, sum, vlmax);
98+
99+
return __riscv_vfmv_f_s_f32m1_f32(sum);
157100
}
158101

159102
float
@@ -250,38 +193,23 @@ fvec_Linf_rvv(const float* x, const float* y, size_t d) {
250193

251194
float
252195
fvec_norm_L2sqr_rvv(const float* x, size_t d) {
253-
size_t vlmax = __riscv_vsetvlmax_e32m2();
254-
vfloat32m2_t vacc0 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
255-
vfloat32m2_t vacc1 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
256-
vfloat32m2_t vacc2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
257-
vfloat32m2_t vacc3 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
196+
const size_t vlmax = __riscv_vsetvlmax_e32m8();
197+
vfloat32m8_t acc = __riscv_vfmv_v_f_f32m8(0.0f, vlmax);
198+
258199
size_t offset = 0;
259-
while (d >= 4 * vlmax) {
260-
size_t vl = vlmax;
261-
vfloat32m2_t vx0 = __riscv_vle32_v_f32m2(x + offset, vl);
262-
vfloat32m2_t vx1 = __riscv_vle32_v_f32m2(x + offset + vl, vl);
263-
vfloat32m2_t vx2 = __riscv_vle32_v_f32m2(x + offset + 2 * vl, vl);
264-
vfloat32m2_t vx3 = __riscv_vle32_v_f32m2(x + offset + 3 * vl, vl);
265-
vacc0 = __riscv_vfmacc_vv_f32m2_tu(vacc0, vx0, vx0, vl);
266-
vacc1 = __riscv_vfmacc_vv_f32m2_tu(vacc1, vx1, vx1, vl);
267-
vacc2 = __riscv_vfmacc_vv_f32m2_tu(vacc2, vx2, vx2, vl);
268-
vacc3 = __riscv_vfmacc_vv_f32m2_tu(vacc3, vx3, vx3, vl);
269-
offset += 4 * vl;
270-
d -= 4 * vl;
271-
}
272-
vacc0 = __riscv_vfadd_vv_f32m2(vacc0, vacc1, vlmax);
273-
vacc2 = __riscv_vfadd_vv_f32m2(vacc2, vacc3, vlmax);
274-
vacc0 = __riscv_vfadd_vv_f32m2(vacc0, vacc2, vlmax);
275-
while (d > 0) {
276-
size_t vl = __riscv_vsetvl_e32m2(d);
277-
vfloat32m2_t vx = __riscv_vle32_v_f32m2(x + offset, vl);
278-
vacc0 = __riscv_vfmacc_vv_f32m2_tu(vacc0, vx, vx, vl);
200+
while (offset < d) {
201+
const size_t vl = __riscv_vsetvl_e32m8(d - offset);
202+
203+
vfloat32m8_t vx = __riscv_vle32_v_f32m8(x + offset, vl);
204+
acc = __riscv_vfmacc_vv_f32m8_tu(acc, vx, vx, vl);
205+
279206
offset += vl;
280-
d -= vl;
281207
}
282-
vfloat32m1_t sum_scalar = __riscv_vfmv_s_f_f32m1(0.0f, 1);
283-
sum_scalar = __riscv_vfredusum_vs_f32m2_f32m1(vacc0, sum_scalar, vlmax);
284-
return __riscv_vfmv_f_s_f32m1_f32(sum_scalar);
208+
209+
vfloat32m1_t sum = __riscv_vfmv_s_f_f32m1(0.0f, 1);
210+
sum = __riscv_vfredusum_vs_f32m8_f32m1(acc, sum, vlmax);
211+
212+
return __riscv_vfmv_f_s_f32m1_f32(sum);
285213
}
286214

287215
void

0 commit comments

Comments
 (0)