Skip to content

Commit a48ebd0

Browse files
committed
optimize inner product and update boilerplate
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 1a56560 commit a48ebd0

9 files changed

Lines changed: 392 additions & 252 deletions

File tree

Cargo.lock

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vortex-tensor/benches/similarity_search.rs

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ mod common;
2626

2727
use common::Variant;
2828
use common::build_similarity_search_tree;
29+
use common::build_variant;
2930
use common::extract_row_as_query;
3031
use common::generate_random_vectors;
3132

@@ -64,14 +65,8 @@ fn bench_variant(bencher: Bencher<'_, '_>, variant: Variant) {
6465
// the query identical across all three variants.
6566
let raw = generate_random_vectors(NUM_ROWS, DIM, SEED);
6667
let query = extract_row_as_query(&raw, 0, DIM);
67-
let data = match variant {
68-
Variant::Uncompressed => raw,
69-
Variant::DefaultCompression => {
70-
common::compress_default(raw).vortex_expect("default compression succeeds")
71-
}
72-
Variant::TurboQuant => common::compress_turboquant(raw, &mut ctx)
73-
.vortex_expect("turboquant compression succeeds"),
74-
};
68+
let data = build_variant(variant, NUM_ROWS, DIM, SEED, &mut ctx)
69+
.vortex_expect("variant build succeeds");
7570

7671
// println!(
7772
// "\n\n{}: {}\n\n",

vortex-tensor/benches/similarity_search_common/mod.rs

Lines changed: 26 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -30,36 +30,30 @@ use vortex_array::ArrayRef;
3030
use vortex_array::ExecutionCtx;
3131
use vortex_array::IntoArray;
3232
use vortex_array::VortexSessionExecute;
33-
use vortex_array::arrays::ConstantArray;
3433
use vortex_array::arrays::Extension;
3534
use vortex_array::arrays::ExtensionArray;
3635
use vortex_array::arrays::FixedSizeListArray;
3736
use vortex_array::arrays::PrimitiveArray;
3837
use vortex_array::arrays::extension::ExtensionArrayExt;
3938
use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt;
4039
use vortex_array::arrays::scalar_fn::ScalarFnArrayExt;
41-
use vortex_array::builtins::ArrayBuiltins;
42-
use vortex_array::dtype::DType;
43-
use vortex_array::dtype::Nullability;
44-
use vortex_array::dtype::PType;
4540
use vortex_array::dtype::extension::ExtDType;
4641
use vortex_array::extension::EmptyMetadata;
47-
use vortex_array::scalar::Scalar;
48-
use vortex_array::scalar_fn::fns::operators::Operator;
4942
use vortex_array::session::ArraySession;
5043
use vortex_array::validity::Validity;
5144
use vortex_btrblocks::BtrBlocksCompressor;
5245
use vortex_buffer::BufferMut;
5346
use vortex_error::VortexExpect;
5447
use vortex_error::VortexResult;
48+
use vortex_error::vortex_bail;
5549
use vortex_error::vortex_panic;
5650
use vortex_session::VortexSession;
5751
use vortex_tensor::encodings::turboquant::TurboQuantConfig;
5852
use vortex_tensor::encodings::turboquant::turboquant_encode_unchecked;
59-
use vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity;
6053
use vortex_tensor::scalar_fns::l2_denorm::L2Denorm;
6154
use vortex_tensor::scalar_fns::l2_denorm::normalize_as_l2_denorm;
6255
use vortex_tensor::vector::Vector;
56+
pub use vortex_tensor::vector_search::build_similarity_search_tree;
6357

6458
/// A shared [`VortexSession`] pre-loaded with the builtin [`ArraySession`] so both bench and
6559
/// example can create execution contexts cheaply.
@@ -146,65 +140,16 @@ pub fn extract_row_as_query(vectors: &ArrayRef, row: usize, dim: u32) -> Vec<f32
146140
slice[start..start + dim_usize].to_vec()
147141
}
148142

149-
/// Build a `Vector<dim, f32>` extension array whose storage is a [`ConstantArray`] broadcasting a
150-
/// single query vector across `num_rows` rows. This is how we hand a single query vector to
151-
/// `CosineSimilarity` on the `rhs` side -- `ScalarFnArray` requires both children to have the
152-
/// same length, so we broadcast the query instead of hand-rolling a 1-row input.
153-
fn build_constant_query_vector(query: &[f32], num_rows: usize) -> VortexResult<ArrayRef> {
154-
let element_dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
155-
156-
let children: Vec<Scalar> = query
157-
.iter()
158-
.map(|&v| Scalar::primitive(v, Nullability::NonNullable))
159-
.collect();
160-
let storage_scalar = Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable);
161-
162-
let storage = ConstantArray::new(storage_scalar, num_rows).into_array();
163-
164-
let ext_dtype = ExtDType::<Vector>::try_new(EmptyMetadata, storage.dtype().clone())?.erased();
165-
Ok(ExtensionArray::new(ext_dtype, storage).into_array())
166-
}
167-
168-
/// Compresses a raw `Vector<dim, f32>` array with the default BtrBlocks pipeline.
169-
///
170-
/// [`BtrBlocksCompressor`] walks into the extension array and recursively compresses the
171-
/// underlying FSL storage child. TurboQuant is *not* exercised by this path -- it is not
172-
/// registered in the default scheme set -- so this measures "generic" lossless compression
173-
/// applied to float vectors.
174-
pub fn compress_default(data: ArrayRef) -> VortexResult<ArrayRef> {
175-
BtrBlocksCompressor::default().compress(&data)
176-
}
177-
178-
/// Compresses a raw `Vector<dim, f32>` array with the TurboQuant pipeline by hand, producing the
179-
/// same tree shape that
180-
/// [`vortex_tensor::encodings::turboquant::TurboQuantScheme`] would:
181-
///
182-
/// ```text
183-
/// L2Denorm(SorfTransform(FSL(Dict(codes, centroids))), norms)
184-
/// ```
185-
///
186-
/// Calling the encode helpers directly (instead of going through
187-
/// `BtrBlocksCompressorBuilder::with_turboquant()`) lets this example avoid depending on the
188-
/// `unstable_encodings` feature flag.
189-
///
190-
/// See `vortex-tensor/src/encodings/turboquant/tests/mod.rs::normalize_and_encode` for the same
191-
/// canonical recipe.
192-
pub fn compress_turboquant(data: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
143+
fn normalize_vectors(
144+
data: ArrayRef,
145+
ctx: &mut ExecutionCtx,
146+
) -> VortexResult<(ArrayRef, ArrayRef, usize)> {
193147
let l2_denorm = normalize_as_l2_denorm(data, ctx)?;
194148
let normalized = l2_denorm.child_at(0).clone();
195149
let norms = l2_denorm.child_at(1).clone();
196150
let num_rows = l2_denorm.len();
197151

198-
let normalized_ext = normalized
199-
.as_opt::<Extension>()
200-
.vortex_expect("normalized child should be an Extension array");
201-
202-
let config = TurboQuantConfig::default();
203-
// SAFETY: `normalize_as_l2_denorm` guarantees every row is unit-norm (or zero), which is the
204-
// invariant `turboquant_encode_unchecked` expects.
205-
let tq = unsafe { turboquant_encode_unchecked(normalized_ext, &config, ctx) }?;
206-
207-
Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array())
152+
Ok((normalized, norms, num_rows))
208153
}
209154

210155
/// Dispatch helper that builds the data array for the requested [`Variant`], starting from a
@@ -220,37 +165,24 @@ pub fn build_variant(
220165
let raw = generate_random_vectors(num_rows, dim, seed);
221166
match variant {
222167
Variant::Uncompressed => Ok(raw),
223-
Variant::DefaultCompression => compress_default(raw),
224-
Variant::TurboQuant => compress_turboquant(raw, ctx),
225-
}
226-
}
227-
228-
/// Build the lazy similarity-search array tree for a prepared data array and a single query
229-
/// vector. The returned tree is a boolean array of length `data.len()` where position `i` is
230-
/// `true` iff `cosine_similarity(data[i], query) > threshold`.
231-
///
232-
/// The tree shape is:
233-
///
234-
/// ```text
235-
/// Binary(Gt, [
236-
/// CosineSimilarity([data, ConstantArray(query_vec, n)]),
237-
/// ConstantArray(threshold, n),
238-
/// ])
239-
/// ```
240-
///
241-
/// This function does no execution; it is safe to call inside a benchmark setup closure.
242-
pub fn build_similarity_search_tree(
243-
data: ArrayRef,
244-
query: &[f32],
245-
threshold: f32,
246-
) -> VortexResult<ArrayRef> {
247-
let num_rows = data.len();
248-
let query_vec = build_constant_query_vector(query, num_rows)?;
249-
250-
let cosine = CosineSimilarity::try_new_array(data, query_vec, num_rows)?.into_array();
251-
252-
let threshold_scalar = Scalar::primitive(threshold, Nullability::NonNullable);
253-
let threshold_array = ConstantArray::new(threshold_scalar, num_rows).into_array();
168+
Variant::DefaultCompression => {
169+
let (normalized, norms, num_rows) = normalize_vectors(raw, ctx)?;
170+
let compressed = BtrBlocksCompressor::default().compress(&normalized)?;
254171

255-
cosine.binary(threshold_array, Operator::Gt)
172+
Ok(unsafe { L2Denorm::new_array_unchecked(compressed, norms, num_rows) }?.into_array())
173+
}
174+
Variant::TurboQuant => {
175+
let (normalized, norms, num_rows) = normalize_vectors(raw, ctx)?;
176+
let Some(normalized_ext) = normalized.as_opt::<Extension>() else {
177+
vortex_bail!("normalize_as_l2_denorm must produce an Extension array child");
178+
};
179+
180+
let config = TurboQuantConfig::default();
181+
// SAFETY: `normalize_as_l2_denorm` guarantees every row is unit-norm (or zero),
182+
// which is the invariant `turboquant_encode_unchecked` expects.
183+
let tq = unsafe { turboquant_encode_unchecked(normalized_ext, &config, ctx) }?;
184+
185+
Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array())
186+
}
187+
}
256188
}

vortex-tensor/public-api.lock

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,4 +550,12 @@ impl core::marker::Copy for vortex_tensor::vector::VectorMatcherMetadata
550550

551551
impl core::marker::StructuralPartialEq for vortex_tensor::vector::VectorMatcherMetadata
552552

553+
pub mod vortex_tensor::vector_search
554+
555+
pub fn vortex_tensor::vector_search::build_constant_query_vector<T: vortex_array::dtype::ptype::NativePType + core::convert::Into<vortex_array::scalar::typed_view::primitive::pvalue::PValue>>(query: &[T], num_rows: usize) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
556+
557+
pub fn vortex_tensor::vector_search::build_similarity_search_tree<T: vortex_array::dtype::ptype::NativePType + core::convert::Into<vortex_array::scalar::typed_view::primitive::pvalue::PValue>>(data: vortex_array::array::erased::ArrayRef, query: &[T], threshold: T) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
558+
559+
pub fn vortex_tensor::vector_search::compress_turboquant(data: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
560+
553561
pub fn vortex_tensor::initialize(session: &vortex_session::VortexSession)

vortex-tensor/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ pub mod vector;
2525

2626
pub mod encodings;
2727

28+
pub mod vector_search;
29+
2830
mod utils;
2931

3032
/// Initialize the Vortex tensor library with a Vortex session.

vortex-tensor/src/scalar_fns/inner_product.rs

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -527,18 +527,9 @@ impl InnerProduct {
527527
let values: &[f32] = values_prim.as_slice::<f32>();
528528
debug_assert_eq!(codes.len(), len * padded_dim);
529529

530-
// Direct codebook lookup in the hot loop. See the function doc comment for why this
531-
// beats an explicit product table here.
532-
let mut out = BufferMut::<f32>::with_capacity(len);
533-
for row in 0..len {
534-
let row_codes = &codes[row * padded_dim..(row + 1) * padded_dim];
535-
let mut acc = 0.0f32;
536-
for j in 0..padded_dim {
537-
acc += q[j] * values[row_codes[j] as usize];
538-
}
539-
// SAFETY: we reserved `len` slots above and push exactly once per row.
540-
unsafe { out.push_unchecked(acc) };
541-
}
530+
// The hot loop is extracted into [`execute_dict_constant_inner_product`], which uses
531+
// `chunks_exact` and four independent accumulators so the inner gather-accumulate pipelines.
532+
let out = execute_dict_constant_inner_product(q, values, codes, len, padded_dim);
542533

543534
// SAFETY: the buffer length equals `len`, which matches the validity length.
544535
let result = unsafe { PrimitiveArray::new_unchecked(out.freeze(), validity) }.into_array();
@@ -556,6 +547,51 @@ fn inner_product_row<T: Float + NativePType>(a: &[T], b: &[T]) -> T {
556547
.fold(T::zero(), |acc, v| acc + v)
557548
}
558549

550+
/// Compute inner products between a constant query vector and dictionary-encoded rows.
551+
///
552+
/// For each row, computes `sum(q[j] * values[codes[row * dim + j]])` using the codebook
553+
/// `values` directly instead of decoding the dictionary into dense vectors.
554+
///
555+
/// The inner loop uses four independent accumulators so the CPU can pipeline FP additions
556+
/// instead of waiting for each `fadd` to retire before starting the next.
557+
fn execute_dict_constant_inner_product(
558+
q: &[f32],
559+
values: &[f32],
560+
codes: &[u8],
561+
num_rows: usize,
562+
dim: usize,
563+
) -> BufferMut<f32> {
564+
let mut out = BufferMut::<f32>::with_capacity(num_rows);
565+
566+
for row_codes in codes.chunks_exact(dim) {
567+
let mut acc0 = 0.0f32;
568+
let mut acc1 = 0.0f32;
569+
let mut acc2 = 0.0f32;
570+
let mut acc3 = 0.0f32;
571+
572+
let code_chunks = row_codes.chunks_exact(4);
573+
let q_chunks = q.chunks_exact(4);
574+
let code_rem = code_chunks.remainder();
575+
let q_rem = q_chunks.remainder();
576+
577+
for (cc, qc) in code_chunks.zip(q_chunks) {
578+
acc0 += qc[0] * values[cc[0] as usize];
579+
acc1 += qc[1] * values[cc[1] as usize];
580+
acc2 += qc[2] * values[cc[2] as usize];
581+
acc3 += qc[3] * values[cc[3] as usize];
582+
}
583+
584+
for (&code, &q_val) in code_rem.iter().zip(q_rem.iter()) {
585+
acc0 += q_val * values[code as usize];
586+
}
587+
588+
// SAFETY: we reserved `num_rows` slots and push exactly once per row.
589+
unsafe { out.push_unchecked(acc0 + acc1 + acc2 + acc3) };
590+
}
591+
592+
out
593+
}
594+
559595
#[cfg(test)]
560596
mod tests {
561597
use std::sync::LazyLock;

0 commit comments

Comments
 (0)