|
| 1 | +#include "index/sparse/sindi_simd.h" |
| 2 | + |
| 3 | +#if defined(__x86_64__) |
| 4 | +#include <immintrin.h> |
| 5 | + |
| 6 | +namespace knowhere::sparse::inverted::sindi { |
| 7 | + |
| 8 | +void |
| 9 | +ip_scatter_avx2_fp16(float qval, const knowhere::fp16* vals, const uint16_t* ids, int32_t num, float* out) { |
| 10 | + int32_t i = 0; |
| 11 | + const __m256 vq = _mm256_set1_ps(qval); |
| 12 | + for (; i + 8 <= num; i += 8) { |
| 13 | + const uint16_t* hptr = reinterpret_cast<const uint16_t*>(vals + i); |
| 14 | + __m128i h = _mm_loadu_si128(reinterpret_cast<const __m128i*>(hptr)); |
| 15 | + __m256 v_vals = _mm256_cvtph_ps(h); |
| 16 | + __m256 v_mul = _mm256_mul_ps(v_vals, vq); |
| 17 | + |
| 18 | + __m128i idx16 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(ids + i)); |
| 19 | + __m256i v_idx = _mm256_cvtepu16_epi32(idx16); |
| 20 | + __m256 v_old = _mm256_i32gather_ps(out, v_idx, 4); |
| 21 | + __m256 v_sum = _mm256_add_ps(v_old, v_mul); |
| 22 | + |
| 23 | + alignas(32) uint32_t tmp_idx[8]; |
| 24 | + alignas(32) float tmp_sum[8]; |
| 25 | + _mm256_store_si256(reinterpret_cast<__m256i*>(tmp_idx), v_idx); |
| 26 | + _mm256_store_ps(tmp_sum, v_sum); |
| 27 | + out[tmp_idx[0]] = tmp_sum[0]; |
| 28 | + out[tmp_idx[1]] = tmp_sum[1]; |
| 29 | + out[tmp_idx[2]] = tmp_sum[2]; |
| 30 | + out[tmp_idx[3]] = tmp_sum[3]; |
| 31 | + out[tmp_idx[4]] = tmp_sum[4]; |
| 32 | + out[tmp_idx[5]] = tmp_sum[5]; |
| 33 | + out[tmp_idx[6]] = tmp_sum[6]; |
| 34 | + out[tmp_idx[7]] = tmp_sum[7]; |
| 35 | + } |
| 36 | + for (; i < num; ++i) { |
| 37 | + out[ids[i]] += qval * static_cast<float>(vals[i]); |
| 38 | + } |
| 39 | +} |
| 40 | + |
| 41 | +void |
| 42 | +bm25_scatter_avx2_u16(float qval, const uint16_t* vals, const uint16_t* ids, int32_t num, float* out, float k1, float b, |
| 43 | + float avgdl, const float* row_sums) { |
| 44 | + const float p1 = k1 + 1.0f; |
| 45 | + const float p2 = k1 * (1.0f - b); |
| 46 | + const float p3 = k1 * b / avgdl; |
| 47 | + |
| 48 | + int32_t i = 0; |
| 49 | + const __m256 vqval = _mm256_set1_ps(qval); |
| 50 | + const __m256 vp1 = _mm256_set1_ps(p1); |
| 51 | + const __m256 vp2 = _mm256_set1_ps(p2); |
| 52 | + const __m256 vp3 = _mm256_set1_ps(p3); |
| 53 | + |
| 54 | + for (; i + 8 <= num; i += 8) { |
| 55 | + const uint16_t* hptr = vals + i; |
| 56 | + __m128i h = _mm_loadu_si128(reinterpret_cast<const __m128i*>(hptr)); |
| 57 | + __m256i w = _mm256_cvtepu16_epi32(h); |
| 58 | + __m256 tf_vec = _mm256_cvtepi32_ps(w); |
| 59 | + |
| 60 | + __m128i idx16 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(ids + i)); |
| 61 | + __m256i v_idx = _mm256_cvtepu16_epi32(idx16); |
| 62 | + __m256 dl_vec = _mm256_i32gather_ps(row_sums, v_idx, 4); |
| 63 | + |
| 64 | + __m256 numerator = _mm256_mul_ps(tf_vec, vp1); |
| 65 | + numerator = _mm256_mul_ps(numerator, vqval); |
| 66 | + |
| 67 | + __m256 denominator = _mm256_fmadd_ps(dl_vec, vp3, vp2); |
| 68 | + denominator = _mm256_add_ps(tf_vec, denominator); |
| 69 | + |
| 70 | + __m256 bm25_vec = _mm256_div_ps(numerator, denominator); |
| 71 | + |
| 72 | + __m256 v_old = _mm256_i32gather_ps(out, v_idx, 4); |
| 73 | + __m256 v_sum = _mm256_add_ps(v_old, bm25_vec); |
| 74 | + |
| 75 | + alignas(32) uint32_t tmp_idx[8]; |
| 76 | + alignas(32) float tmp_sum[8]; |
| 77 | + _mm256_store_si256(reinterpret_cast<__m256i*>(tmp_idx), v_idx); |
| 78 | + _mm256_store_ps(tmp_sum, v_sum); |
| 79 | + out[tmp_idx[0]] = tmp_sum[0]; |
| 80 | + out[tmp_idx[1]] = tmp_sum[1]; |
| 81 | + out[tmp_idx[2]] = tmp_sum[2]; |
| 82 | + out[tmp_idx[3]] = tmp_sum[3]; |
| 83 | + out[tmp_idx[4]] = tmp_sum[4]; |
| 84 | + out[tmp_idx[5]] = tmp_sum[5]; |
| 85 | + out[tmp_idx[6]] = tmp_sum[6]; |
| 86 | + out[tmp_idx[7]] = tmp_sum[7]; |
| 87 | + } |
| 88 | + |
| 89 | + for (; i < num; ++i) { |
| 90 | + float tf = static_cast<float>(vals[i]); |
| 91 | + uint16_t docid = ids[i]; |
| 92 | + float dl = row_sums[docid]; |
| 93 | + float bm25_score = qval * p1 * tf / (tf + p2 + p3 * dl); |
| 94 | + out[docid] += bm25_score; |
| 95 | + } |
| 96 | +} |
| 97 | + |
| 98 | +void |
| 99 | +batch_insert_avx2(const float* scores, size_t docid_start, size_t count, |
| 100 | + knowhere::ResultMinHeap<float, uint32_t>& topk_q, float& threshold, const BitsetView& bitset) { |
| 101 | + size_t i = 0; |
| 102 | + __m256 vthr = _mm256_set1_ps(threshold); |
| 103 | + for (; i + 8 <= count; i += 8) { |
| 104 | + _mm_prefetch(reinterpret_cast<const char*>(scores + i + 32), _MM_HINT_T0); |
| 105 | + __m256 v = _mm256_loadu_ps(scores + i); |
| 106 | + __m256 cmp = _mm256_cmp_ps(v, vthr, _CMP_GT_OQ); |
| 107 | + int mm = _mm256_movemask_ps(cmp); |
| 108 | + while (mm != 0) { |
| 109 | + unsigned bit = __builtin_ctz(static_cast<unsigned>(mm)); |
| 110 | + mm &= (mm - 1); |
| 111 | + size_t idx = i + bit; |
| 112 | + if (!bitset.empty() && bitset.test(static_cast<int64_t>(docid_start + idx))) { |
| 113 | + continue; |
| 114 | + } |
| 115 | + float s = scores[idx]; |
| 116 | + if (topk_q.Push(s, static_cast<uint32_t>(docid_start + idx))) { |
| 117 | + if (topk_q.Full()) { |
| 118 | + threshold = topk_q.Threshold(); |
| 119 | + vthr = _mm256_set1_ps(threshold); |
| 120 | + } |
| 121 | + } |
| 122 | + } |
| 123 | + } |
| 124 | + for (; i < count; ++i) { |
| 125 | + float s = scores[i]; |
| 126 | + if (s <= threshold) { |
| 127 | + continue; |
| 128 | + } |
| 129 | + if (!bitset.empty() && bitset.test(static_cast<int64_t>(docid_start + i))) { |
| 130 | + continue; |
| 131 | + } |
| 132 | + if (topk_q.Push(s, static_cast<uint32_t>(docid_start + i))) { |
| 133 | + if (topk_q.Full()) { |
| 134 | + threshold = topk_q.Threshold(); |
| 135 | + } |
| 136 | + } |
| 137 | + } |
| 138 | +} |
| 139 | + |
| 140 | +} // namespace knowhere::sparse::inverted::sindi |
| 141 | + |
| 142 | +#endif // __x86_64__ |
0 commit comments