Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 41 additions & 113 deletions src/simd/distances_rvv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,106 +54,49 @@ bf16_float_rvv(vfloat32m1_t f, size_t vl) {
// =================== float distances ===================
float
fvec_inner_product_rvv(const float* x, const float* y, size_t d) {
size_t vlmax = __riscv_vsetvlmax_e32m2(); // Use m2 to support 4-way parallelism

// 4 accumulators
vfloat32m2_t vacc0 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
vfloat32m2_t vacc1 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
vfloat32m2_t vacc2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
vfloat32m2_t vacc3 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
const size_t vlmax = __riscv_vsetvlmax_e32m8();
vfloat32m8_t acc = __riscv_vfmv_v_f_f32m8(0.0f, vlmax);

size_t offset = 0;
while (offset < d) {
const size_t vl = __riscv_vsetvl_e32m8(d - offset);

// 4-way unrolled loop
while (d >= 4 * vlmax) {
size_t vl = vlmax;

vfloat32m2_t vx0 = __riscv_vle32_v_f32m2(x + offset, vl);
vfloat32m2_t vy0 = __riscv_vle32_v_f32m2(y + offset, vl);
vfloat32m2_t vx1 = __riscv_vle32_v_f32m2(x + offset + vl, vl);
vfloat32m2_t vy1 = __riscv_vle32_v_f32m2(y + offset + vl, vl);
vfloat32m2_t vx2 = __riscv_vle32_v_f32m2(x + offset + 2 * vl, vl);
vfloat32m2_t vy2 = __riscv_vle32_v_f32m2(y + offset + 2 * vl, vl);
vfloat32m2_t vx3 = __riscv_vle32_v_f32m2(x + offset + 3 * vl, vl);
vfloat32m2_t vy3 = __riscv_vle32_v_f32m2(y + offset + 3 * vl, vl);

// Parallel FMACC operations
vacc0 = __riscv_vfmacc_vv_f32m2_tu(vacc0, vx0, vy0, vl);
vacc1 = __riscv_vfmacc_vv_f32m2_tu(vacc1, vx1, vy1, vl);
vacc2 = __riscv_vfmacc_vv_f32m2_tu(vacc2, vx2, vy2, vl);
vacc3 = __riscv_vfmacc_vv_f32m2_tu(vacc3, vx3, vy3, vl);
vfloat32m8_t vx = __riscv_vle32_v_f32m8(x + offset, vl);
vfloat32m8_t vy = __riscv_vle32_v_f32m8(y + offset, vl);

offset += 4 * vl;
d -= 4 * vl;
}

// Merge accumulators
vacc0 = __riscv_vfadd_vv_f32m2(vacc0, vacc1, vlmax);
vacc2 = __riscv_vfadd_vv_f32m2(vacc2, vacc3, vlmax);
vacc0 = __riscv_vfadd_vv_f32m2(vacc0, vacc2, vlmax);

// Handle remaining elements
while (d > 0) {
size_t vl = __riscv_vsetvl_e32m2(d);
vfloat32m2_t vx = __riscv_vle32_v_f32m2(x + offset, vl);
vfloat32m2_t vy = __riscv_vle32_v_f32m2(y + offset, vl);
vacc0 = __riscv_vfmacc_vv_f32m2_tu(vacc0, vx, vy, vl);
acc = __riscv_vfmacc_vv_f32m8_tu(acc, vx, vy, vl);

offset += vl;
d -= vl;
}

// Final reduction
vfloat32m1_t sum_scalar = __riscv_vfmv_s_f_f32m1(0.0f, 1);
sum_scalar = __riscv_vfredusum_vs_f32m2_f32m1(vacc0, sum_scalar, vlmax);
vfloat32m1_t sum = __riscv_vfmv_s_f_f32m1(0.0f, 1);
sum = __riscv_vfredusum_vs_f32m8_f32m1(acc, sum, vlmax);

return __riscv_vfmv_f_s_f32m1_f32(sum_scalar);
return __riscv_vfmv_f_s_f32m1_f32(sum);
}

float
fvec_L2sqr_rvv(const float* x, const float* y, size_t d) {
size_t vlmax = __riscv_vsetvlmax_e32m2();
vfloat32m2_t vacc0 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
vfloat32m2_t vacc1 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
vfloat32m2_t vacc2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
vfloat32m2_t vacc3 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
const size_t vlmax = __riscv_vsetvlmax_e32m8();
vfloat32m8_t acc = __riscv_vfmv_v_f_f32m8(0.0f, vlmax);

size_t offset = 0;
while (d >= 4 * vlmax) {
size_t vl = vlmax;
vfloat32m2_t vx0 = __riscv_vle32_v_f32m2(x + offset, vl);
vfloat32m2_t vy0 = __riscv_vle32_v_f32m2(y + offset, vl);
vfloat32m2_t vx1 = __riscv_vle32_v_f32m2(x + offset + vl, vl);
vfloat32m2_t vy1 = __riscv_vle32_v_f32m2(y + offset + vl, vl);
vfloat32m2_t vx2 = __riscv_vle32_v_f32m2(x + offset + 2 * vl, vl);
vfloat32m2_t vy2 = __riscv_vle32_v_f32m2(y + offset + 2 * vl, vl);
vfloat32m2_t vx3 = __riscv_vle32_v_f32m2(x + offset + 3 * vl, vl);
vfloat32m2_t vy3 = __riscv_vle32_v_f32m2(y + offset + 3 * vl, vl);
vfloat32m2_t vtmp0 = __riscv_vfsub_vv_f32m2(vx0, vy0, vl);
vfloat32m2_t vtmp1 = __riscv_vfsub_vv_f32m2(vx1, vy1, vl);
vfloat32m2_t vtmp2 = __riscv_vfsub_vv_f32m2(vx2, vy2, vl);
vfloat32m2_t vtmp3 = __riscv_vfsub_vv_f32m2(vx3, vy3, vl);
vacc0 = __riscv_vfmacc_vv_f32m2_tu(vacc0, vtmp0, vtmp0, vl);
vacc1 = __riscv_vfmacc_vv_f32m2_tu(vacc1, vtmp1, vtmp1, vl);
vacc2 = __riscv_vfmacc_vv_f32m2_tu(vacc2, vtmp2, vtmp2, vl);
vacc3 = __riscv_vfmacc_vv_f32m2_tu(vacc3, vtmp3, vtmp3, vl);
offset += 4 * vl;
d -= 4 * vl;
}
vacc0 = __riscv_vfadd_vv_f32m2(vacc0, vacc1, vlmax);
vacc2 = __riscv_vfadd_vv_f32m2(vacc2, vacc3, vlmax);
vacc0 = __riscv_vfadd_vv_f32m2(vacc0, vacc2, vlmax);
while (d > 0) {
size_t vl = __riscv_vsetvl_e32m2(d);
vfloat32m2_t vx = __riscv_vle32_v_f32m2(x + offset, vl);
vfloat32m2_t vy = __riscv_vle32_v_f32m2(y + offset, vl);
vfloat32m2_t vtmp = __riscv_vfsub_vv_f32m2(vx, vy, vl);
vacc0 = __riscv_vfmacc_vv_f32m2_tu(vacc0, vtmp, vtmp, vl);
while (offset < d) {
const size_t vl = __riscv_vsetvl_e32m8(d - offset);

vfloat32m8_t vx = __riscv_vle32_v_f32m8(x + offset, vl);
vfloat32m8_t vy = __riscv_vle32_v_f32m8(y + offset, vl);

vx = __riscv_vfsub_vv_f32m8(vx, vy, vl);
acc = __riscv_vfmacc_vv_f32m8_tu(acc, vx, vx, vl);

offset += vl;
d -= vl;
}
vfloat32m1_t sum_scalar = __riscv_vfmv_s_f_f32m1(0.0f, 1);
sum_scalar = __riscv_vfredusum_vs_f32m2_f32m1(vacc0, sum_scalar, vlmax);
return __riscv_vfmv_f_s_f32m1_f32(sum_scalar);

vfloat32m1_t sum = __riscv_vfmv_s_f_f32m1(0.0f, 1);
sum = __riscv_vfredusum_vs_f32m8_f32m1(acc, sum, vlmax);

return __riscv_vfmv_f_s_f32m1_f32(sum);
}

float
Expand Down Expand Up @@ -250,38 +193,23 @@ fvec_Linf_rvv(const float* x, const float* y, size_t d) {

float
fvec_norm_L2sqr_rvv(const float* x, size_t d) {
size_t vlmax = __riscv_vsetvlmax_e32m2();
vfloat32m2_t vacc0 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
vfloat32m2_t vacc1 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
vfloat32m2_t vacc2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
vfloat32m2_t vacc3 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
const size_t vlmax = __riscv_vsetvlmax_e32m8();
vfloat32m8_t acc = __riscv_vfmv_v_f_f32m8(0.0f, vlmax);

size_t offset = 0;
while (d >= 4 * vlmax) {
size_t vl = vlmax;
vfloat32m2_t vx0 = __riscv_vle32_v_f32m2(x + offset, vl);
vfloat32m2_t vx1 = __riscv_vle32_v_f32m2(x + offset + vl, vl);
vfloat32m2_t vx2 = __riscv_vle32_v_f32m2(x + offset + 2 * vl, vl);
vfloat32m2_t vx3 = __riscv_vle32_v_f32m2(x + offset + 3 * vl, vl);
vacc0 = __riscv_vfmacc_vv_f32m2_tu(vacc0, vx0, vx0, vl);
vacc1 = __riscv_vfmacc_vv_f32m2_tu(vacc1, vx1, vx1, vl);
vacc2 = __riscv_vfmacc_vv_f32m2_tu(vacc2, vx2, vx2, vl);
vacc3 = __riscv_vfmacc_vv_f32m2_tu(vacc3, vx3, vx3, vl);
offset += 4 * vl;
d -= 4 * vl;
}
vacc0 = __riscv_vfadd_vv_f32m2(vacc0, vacc1, vlmax);
vacc2 = __riscv_vfadd_vv_f32m2(vacc2, vacc3, vlmax);
vacc0 = __riscv_vfadd_vv_f32m2(vacc0, vacc2, vlmax);
while (d > 0) {
size_t vl = __riscv_vsetvl_e32m2(d);
vfloat32m2_t vx = __riscv_vle32_v_f32m2(x + offset, vl);
vacc0 = __riscv_vfmacc_vv_f32m2_tu(vacc0, vx, vx, vl);
while (offset < d) {
const size_t vl = __riscv_vsetvl_e32m8(d - offset);

vfloat32m8_t vx = __riscv_vle32_v_f32m8(x + offset, vl);
acc = __riscv_vfmacc_vv_f32m8_tu(acc, vx, vx, vl);

offset += vl;
d -= vl;
}
vfloat32m1_t sum_scalar = __riscv_vfmv_s_f_f32m1(0.0f, 1);
sum_scalar = __riscv_vfredusum_vs_f32m2_f32m1(vacc0, sum_scalar, vlmax);
return __riscv_vfmv_f_s_f32m1_f32(sum_scalar);

vfloat32m1_t sum = __riscv_vfmv_s_f_f32m1(0.0f, 1);
sum = __riscv_vfredusum_vs_f32m8_f32m1(acc, sum, vlmax);

return __riscv_vfmv_f_s_f32m1_f32(sum);
}

void
Expand Down
Loading