Skip to content

Commit 7822f7d

Browse files
committed
fix: use f64 precision for UDTF query vectors, add UDTF tests
Parse query vectors as f64 to match the optimizer path's precision, avoiding silent accuracy loss for F64-quantized indexes. Add 5 tests for vector_search_vector: basic happy path, projection pushdown, bad table ref error, registry miss error, k > dataset size.
1 parent d51d8ba commit 7822f7d

2 files changed

Lines changed: 107 additions & 18 deletions

File tree

src/udtf.rs

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ impl TableFunctionImpl for VectorSearchVectorUDTF {
6969

7070
let table_ref = extract_string_literal(&exprs[0])?;
7171
let column = extract_string_literal(&exprs[1])?;
72-
let query_vec = extract_f32_vec(&exprs[2])?;
72+
let query_vec = extract_f64_vec(&exprs[2])?;
7373
let k = extract_usize_literal(&exprs[3])?;
7474

7575
// Build the registry key: "conn::schema::table::column"
@@ -97,7 +97,7 @@ impl TableFunctionImpl for VectorSearchVectorUDTF {
9797
/// fetches full rows via the lookup provider, and appends `_distance`.
9898
struct VectorSearchVectorProvider {
9999
registered: Arc<RegisteredTable>,
100-
query_vec: Vec<f32>,
100+
query_vec: Vec<f64>,
101101
k: usize,
102102
}
103103

@@ -130,10 +130,9 @@ impl TableProvider for VectorSearchVectorProvider {
130130
_limit: Option<usize>,
131131
) -> Result<Arc<dyn ExecutionPlan>> {
132132
// 1. HNSW search
133-
let query_f64: Vec<f64> = self.query_vec.iter().map(|&v| v as f64).collect();
134133
let matches = usearch_search(
135134
&self.registered.index,
136-
&query_f64,
135+
&self.query_vec,
137136
self.k,
138137
self.registered.scalar_kind,
139138
)?;
@@ -311,7 +310,7 @@ fn extract_usize_literal(expr: &Expr) -> Result<usize> {
311310
}
312311
}
313312

314-
fn extract_f32_vec(expr: &Expr) -> Result<Vec<f32>> {
313+
fn extract_f64_vec(expr: &Expr) -> Result<Vec<f64>> {
315314
use arrow_array::{Float32Array, Float64Array};
316315

317316
match expr {
@@ -320,11 +319,11 @@ fn extract_f32_vec(expr: &Expr) -> Result<Vec<f32>> {
320319
return Err(DataFusionError::Execution("Empty query vector".into()));
321320
}
322321
let inner = arr.value(0);
323-
if let Some(f32a) = inner.as_any().downcast_ref::<Float32Array>() {
324-
return Ok(f32a.values().to_vec());
325-
}
326322
if let Some(f64a) = inner.as_any().downcast_ref::<Float64Array>() {
327-
return Ok(f64a.values().iter().map(|&v| v as f32).collect());
323+
return Ok(f64a.values().to_vec());
324+
}
325+
if let Some(f32a) = inner.as_any().downcast_ref::<Float32Array>() {
326+
return Ok(f32a.values().iter().map(|&v| v as f64).collect());
328327
}
329328
Err(DataFusionError::Execution(
330329
"FixedSizeList inner is not Float32/Float64".into(),
@@ -335,11 +334,11 @@ fn extract_f32_vec(expr: &Expr) -> Result<Vec<f32>> {
335334
return Err(DataFusionError::Execution("Empty query vector".into()));
336335
}
337336
let inner = arr.value(0);
338-
if let Some(f32a) = inner.as_any().downcast_ref::<Float32Array>() {
339-
return Ok(f32a.values().to_vec());
340-
}
341337
if let Some(f64a) = inner.as_any().downcast_ref::<Float64Array>() {
342-
return Ok(f64a.values().iter().map(|&v| v as f32).collect());
338+
return Ok(f64a.values().to_vec());
339+
}
340+
if let Some(f32a) = inner.as_any().downcast_ref::<Float32Array>() {
341+
return Ok(f32a.values().iter().map(|&v| v as f64).collect());
343342
}
344343
Err(DataFusionError::Execution(
345344
"List scalar inner is not Float32/Float64".into(),
@@ -349,10 +348,10 @@ fn extract_f32_vec(expr: &Expr) -> Result<Vec<f32>> {
349348
let mut result = Vec::with_capacity(sf.args.len());
350349
for arg in &sf.args {
351350
match arg {
352-
Expr::Literal(ScalarValue::Float64(Some(v)), _) => result.push(*v as f32),
353-
Expr::Literal(ScalarValue::Float32(Some(v)), _) => result.push(*v),
354-
Expr::Literal(ScalarValue::Int64(Some(v)), _) => result.push(*v as f32),
355-
Expr::Literal(ScalarValue::Int32(Some(v)), _) => result.push(*v as f32),
351+
Expr::Literal(ScalarValue::Float64(Some(v)), _) => result.push(*v),
352+
Expr::Literal(ScalarValue::Float32(Some(v)), _) => result.push(*v as f64),
353+
Expr::Literal(ScalarValue::Int64(Some(v)), _) => result.push(*v as f64),
354+
Expr::Literal(ScalarValue::Int32(Some(v)), _) => result.push(*v as f64),
356355
other => {
357356
return Err(DataFusionError::Execution(format!(
358357
"Non-literal in ARRAY[...]: {other:?}"
@@ -363,7 +362,7 @@ fn extract_f32_vec(expr: &Expr) -> Result<Vec<f32>> {
363362
Ok(result)
364363
}
365364
other => Err(DataFusionError::Execution(format!(
366-
"Cannot extract f32 vector from: {other:?}"
365+
"Cannot extract f64 vector from: {other:?}"
367366
))),
368367
}
369368
}

tests/execution.rs

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -937,3 +937,93 @@ async fn udf_dimension_mismatch_select_star() {
937937
"expected dimension mismatch error, got: {msg}"
938938
);
939939
}
940+
941+
// ═══════════════════════════════════════════════════════════════════════════════
942+
// vector_search_vector UDTF tests
943+
// ═══════════════════════════════════════════════════════════════════════════════
944+
945+
/// Basic happy path: returns correct rows with _distance column.
946+
#[tokio::test]
947+
async fn udtf_vector_search_vector_basic() {
948+
let ctx = make_exec_ctx("conn::schema::items::vector").await;
949+
let sql = "SELECT id, label, _distance FROM vector_search_vector('conn.schema.items', 'vector', ARRAY[1.0::float, 0.0::float, 0.0::float, 0.0::float], 3) ORDER BY _distance ASC";
950+
let df = ctx.sql(sql).await.expect("sql");
951+
let batches = df.collect().await.expect("collect");
952+
953+
let total: usize = batches.iter().map(|b| b.num_rows()).sum();
954+
assert_eq!(total, 3, "expected 3 results");
955+
956+
// First result should be row 1 (exact match, distance 0)
957+
let ids = batches[0]
958+
.column(0)
959+
.as_any()
960+
.downcast_ref::<UInt64Array>()
961+
.expect("id col");
962+
assert_eq!(ids.value(0), 1, "closest must be row 1");
963+
964+
let dists = batches[0]
965+
.column(2)
966+
.as_any()
967+
.downcast_ref::<Float32Array>()
968+
.expect("_distance col");
969+
assert!(
970+
(dists.value(0) - 0.0).abs() < 1e-6,
971+
"row 1 distance must be 0.0, got {}",
972+
dists.value(0)
973+
);
974+
}
975+
976+
/// Projection pushdown: only requested columns are returned.
977+
#[tokio::test]
978+
async fn udtf_vector_search_vector_projection() {
979+
let ctx = make_exec_ctx("conn::schema::items::vector").await;
980+
let sql = "SELECT id, _distance FROM vector_search_vector('conn.schema.items', 'vector', ARRAY[1.0::float, 0.0::float, 0.0::float, 0.0::float], 2)";
981+
let df = ctx.sql(sql).await.expect("sql");
982+
let batches = df.collect().await.expect("collect");
983+
assert_eq!(
984+
batches[0].num_columns(),
985+
2,
986+
"expected 2 columns (id, _distance), got {}",
987+
batches[0].num_columns()
988+
);
989+
let schema = batches[0].schema();
990+
let col_names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect();
991+
assert_eq!(col_names, vec!["id", "_distance"]);
992+
}
993+
994+
/// parse_dot_table_ref error: fewer than 3 parts.
995+
#[tokio::test]
996+
async fn udtf_vector_search_vector_bad_table_ref() {
997+
let ctx = make_exec_ctx("conn::schema::items::vector").await;
998+
let sql = "SELECT * FROM vector_search_vector('items', 'vector', ARRAY[1.0::float, 0.0::float, 0.0::float, 0.0::float], 3)";
999+
let err = ctx.sql(sql).await.unwrap_err();
1000+
let msg = err.to_string();
1001+
assert!(
1002+
msg.contains("connection.schema.table"),
1003+
"expected table ref format error, got: {msg}"
1004+
);
1005+
}
1006+
1007+
/// Registry miss: column not in registry returns clear error.
1008+
#[tokio::test]
1009+
async fn udtf_vector_search_vector_registry_miss() {
1010+
let ctx = make_exec_ctx("conn::schema::items::vector").await;
1011+
let sql = "SELECT * FROM vector_search_vector('conn.schema.items', 'nonexistent', ARRAY[1.0::float, 0.0::float, 0.0::float, 0.0::float], 3)";
1012+
let err = ctx.sql(sql).await.unwrap_err();
1013+
let msg = err.to_string();
1014+
assert!(
1015+
msg.contains("no loaded vector index"),
1016+
"expected registry miss error, got: {msg}"
1017+
);
1018+
}
1019+
1020+
/// Empty result: search with k larger than dataset returns all rows.
1021+
#[tokio::test]
1022+
async fn udtf_vector_search_vector_k_larger_than_dataset() {
1023+
let ctx = make_exec_ctx("conn::schema::items::vector").await;
1024+
let sql = "SELECT id, _distance FROM vector_search_vector('conn.schema.items', 'vector', ARRAY[1.0::float, 0.0::float, 0.0::float, 0.0::float], 100)";
1025+
let df = ctx.sql(sql).await.expect("sql");
1026+
let batches = df.collect().await.expect("collect");
1027+
let total: usize = batches.iter().map(|b| b.num_rows()).sum();
1028+
assert_eq!(total, 4, "expected all 4 rows when k > dataset size");
1029+
}

0 commit comments

Comments
 (0)