Skip to content

Commit 3cd74b9

Browse files
authored
Fix scalar_at in example (#7471)
## Summary Semantic merge. ## Testing N/A Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent d7f07de commit 3cd74b9

1 file changed

Lines changed: 12 additions & 8 deletions

File tree

vortex/examples/turboquant_vector_search.rs

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ use vortex_tensor::vector::Vector;
7777

7878
/// Cosine threshold for the demo filter. The query comes from the test split, so it may or may not
7979
/// have nearby rows in the train split.
80-
const COSINE_THRESHOLD: f32 = 0.90;
80+
const COSINE_THRESHOLD: f32 = 0.85;
8181

8282
/// Slack for checking decoded rows against a predicate that was evaluated on TurboQuant's lossy
8383
/// readthrough representation.
@@ -99,13 +99,14 @@ async fn main() -> Result<()> {
9999

100100
let session = VortexSession::default().with_tokio();
101101
vortex_tensor::initialize(&session);
102+
let mut ctx = session.create_execution_ctx();
102103
println!("session initialized with tensor plugins");
103104

104105
let dataset = VectorDataset::CohereSmall100k; // This is one of the smaller datasets.
105106

106107
// Download the source parquet files.
107108
let dataset_paths = vector_dataset::download(dataset, TrainLayout::Single).await?;
108-
let (_id, query_vector) = get_query_vector(dataset_paths.test).await?;
109+
let (_id, query_vector) = get_query_vector(dataset_paths.test, &mut ctx).await?;
109110
println!(
110111
"query vector selected (id = {_id}, dim = {})",
111112
query_vector.len()
@@ -142,18 +143,21 @@ async fn main() -> Result<()> {
142143
verify_roundtrip(&session, &bytes, struct_array.clone()).await?;
143144

144145
println!("verifying filter pushdown with cosine similarity...");
145-
verify_filter_pushdown(&session, &bytes, &query_vector, struct_array).await?;
146+
verify_filter_pushdown(&session, &bytes, &query_vector, struct_array, &mut ctx).await?;
146147

147148
println!("all checks passed!");
148149
Ok(())
149150
}
150151

151-
async fn get_query_vector(query_vectors_path: PathBuf) -> Result<(usize, Vec<f32>)> {
152+
async fn get_query_vector(
153+
query_vectors_path: PathBuf,
154+
ctx: &mut ExecutionCtx,
155+
) -> Result<(usize, Vec<f32>)> {
152156
let test_vectors = parquet_to_vortex_chunks(query_vectors_path).await?;
153157

154158
// Get a random query vector.
155159
let idx = rand::random_range(0..test_vectors.len());
156-
let struct_scalar = test_vectors.scalar_at(idx)?;
160+
let struct_scalar = test_vectors.execute_scalar(idx, ctx)?;
157161
let id_scalar = struct_scalar
158162
.as_struct()
159163
.field("id")
@@ -260,6 +264,7 @@ async fn verify_filter_pushdown(
260264
bytes: &ByteBuffer,
261265
query: &[f32],
262266
original: ArrayRef,
267+
ctx: &mut ExecutionCtx,
263268
) -> Result<()> {
264269
// Build the filter as `cosine_similarity(emb, <query>) > threshold`. The RHS of
265270
// `CosineSimilarity` is a `lit(...)` wrapping a `Vector<f32, DIM>` scalar; during scan
@@ -297,13 +302,12 @@ async fn verify_filter_pushdown(
297302
// Materialize the matching rows and dump each `emb` vector so the reader can see what the
298303
// pushed-down filter actually selected. Vectors are truncated to the first few elements since
299304
// DIM is typically large.
300-
let mut ctx = session.create_execution_ctx();
301305
let filtered: StructArray = ChunkedArray::try_new(chunks, original.dtype().clone())?
302306
.into_array()
303-
.execute(&mut ctx)?;
307+
.execute(ctx)?;
304308

305309
let emb = filtered.unmasked_field_by_name("emb")?.clone();
306-
let flat = flatten_vector_column(emb, &mut ctx)?;
310+
let flat = flatten_vector_column(emb, ctx)?;
307311

308312
let dim = query.len();
309313
for (i, row) in flat.chunks_exact(dim).enumerate() {

0 commit comments

Comments
 (0)