Skip to content

Commit bfc0d10

Browse files
authored
fix(rule): preserve hidden distance for sort (#21)
* fix(rule): preserve hidden distance for sort. Keep _distance in an inner projection when ORDER BY uses a vector distance expression that is not part of the final select list. This fixes split-provider execution for queries like SELECT id ORDER BY l2_distance(vector, ARRAY[...]) LIMIT k while preserving the final output schema. Add an execution test for the direct ORDER BY shape to cover the production case.
* style(rule): format hidden distance rewrite
* test(rule): cover computed sort projections
1 parent 2336126 commit bfc0d10

2 files changed

Lines changed: 164 additions & 19 deletions

File tree

src/rule.rs

Lines changed: 71 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@
1818
//
1919
// Replacement:
2020
//
21-
// Sort(fetch=k) ← kept (sort order)
22-
// Projection([col(a), col(b), col("_distance").alias("dist")])
23-
// USearchNode ← executes ANN
21+
// Projection([final output cols])
22+
// Sort(fetch=k)
23+
// Projection([final output cols + optional hidden _distance])
24+
// USearchNode
2425

2526
use std::collections::HashMap;
2627
use std::sync::Arc;
@@ -151,25 +152,33 @@ impl USearchRule {
151152
node: Arc::new(node) as Arc<dyn UserDefinedLogicalNode>,
152153
}));
153154

154-
// Build Projection over USearchNode matching the original output schema.
155+
// Build the final user-visible projection over USearchNode output.
155156
let dist_alias_str = dist_alias.as_deref().unwrap_or("_distance");
156-
let new_proj_exprs = if proj_exprs_slice.is_empty() {
157+
let final_proj_exprs = if proj_exprs_slice.is_empty() {
157158
passthrough_projection(&vsn_df_schema, &table_ref)
158159
} else {
159160
remap_projections(proj_exprs_slice, dist_alias_str, &table_ref)
160161
};
161-
let new_proj = Projection::try_new(new_proj_exprs, node_plan).ok()?;
162-
163-
// Keep the Sort node so DataFusion handles ordering by _distance / dist.
164-
// USearch returns results in arbitrary (internal) order when the underlying
165-
// data is fetched from the TableProvider.
166-
Some(LogicalPlan::Sort(
167-
datafusion::logical_expr::logical_plan::Sort {
168-
expr: sort.expr.clone(),
169-
input: Arc::new(LogicalPlan::Projection(new_proj)),
170-
fetch: sort.fetch,
171-
},
172-
))
162+
let remapped_sort_exprs = remap_sort_exprs(&sort.expr, dist_alias.as_deref());
163+
let needs_hidden_distance = remapped_sort_exprs.iter().any(
164+
|e| matches!(&e.expr, Expr::Column(c) if c.relation.is_none() && c.name == "_distance"),
165+
) && !projection_exposes_name(&final_proj_exprs, "_distance");
166+
167+
let mut sort_input_exprs = final_proj_exprs.clone();
168+
if needs_hidden_distance {
169+
sort_input_exprs.push(col("_distance"));
170+
}
171+
172+
let sort_input = Projection::try_new(sort_input_exprs, node_plan).ok()?;
173+
let sorted = LogicalPlan::Sort(datafusion::logical_expr::logical_plan::Sort {
174+
expr: remapped_sort_exprs,
175+
input: Arc::new(LogicalPlan::Projection(sort_input)),
176+
fetch: sort.fetch,
177+
});
178+
179+
let outer_proj_exprs = build_outer_projection(&final_proj_exprs);
180+
let outer_proj = Projection::try_new(outer_proj_exprs, Arc::new(sorted)).ok()?;
181+
Some(LogicalPlan::Projection(outer_proj))
173182
}
174183
}
175184

@@ -283,7 +292,11 @@ fn dist_type_matches_metric(dist_type: &DistanceType, metric: MetricKind) -> boo
283292
}
284293

285294
fn is_distance_expr(expr: &Expr) -> bool {
286-
matches!(expr, Expr::ScalarFunction(sf) if is_dist_udf_name(sf.func.name()))
295+
let inner = match expr {
296+
Expr::Alias(a) => a.expr.as_ref(),
297+
other => other,
298+
};
299+
matches!(inner, Expr::ScalarFunction(sf) if is_dist_udf_name(sf.func.name()))
287300
}
288301

289302
fn try_extract_distance(expr: &Expr) -> Option<(String, String, Vec<f64>)> {
@@ -322,6 +335,46 @@ fn remap_projections(
322335
.collect()
323336
}
324337

338+
fn remap_sort_exprs(
339+
sort_exprs: &[datafusion::logical_expr::SortExpr],
340+
dist_alias_name: Option<&str>,
341+
) -> Vec<datafusion::logical_expr::SortExpr> {
342+
sort_exprs
343+
.iter()
344+
.map(|sort_expr| {
345+
let remapped_expr = match &sort_expr.expr {
346+
Expr::Column(c) if Some(c.name.as_str()) == dist_alias_name => col(c.name.as_str()),
347+
expr if is_distance_expr(expr) => col("_distance"),
348+
other => other.clone(),
349+
};
350+
datafusion::logical_expr::SortExpr {
351+
expr: remapped_expr,
352+
asc: sort_expr.asc,
353+
nulls_first: sort_expr.nulls_first,
354+
}
355+
})
356+
.collect()
357+
}
358+
359+
fn projection_exposes_name(exprs: &[Expr], name: &str) -> bool {
360+
exprs.iter().any(|expr| match expr {
361+
Expr::Alias(a) => a.name == name,
362+
Expr::Column(c) => c.name == name,
363+
_ => false,
364+
})
365+
}
366+
367+
fn build_outer_projection(exprs: &[Expr]) -> Vec<Expr> {
368+
exprs
369+
.iter()
370+
.map(|expr| match expr {
371+
Expr::Alias(a) => col(a.name.as_str()),
372+
Expr::Column(c) => Expr::Column(c.clone()),
373+
other => col(other.schema_name().to_string()),
374+
})
375+
.collect()
376+
}
377+
325378
/// Build a passthrough Projection for SELECT * queries (no original Projection node).
326379
/// Projects only the original table columns (not `_distance`) so the output schema
327380
/// matches the original Sort schema. The Sort re-evaluates the distance UDF expression

tests/execution.rs

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
use std::sync::Arc;
2020

2121
use arrow_array::builder::{FixedSizeListBuilder, Float32Builder};
22-
use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch, StringArray, UInt64Array};
22+
use arrow_array::{
23+
FixedSizeListArray, Float32Array, Int64Array, RecordBatch, StringArray, UInt64Array,
24+
};
2325
use arrow_schema::{DataType, Field, Schema};
2426
use datafusion::execution::session_state::SessionStateBuilder;
2527
use datafusion::prelude::SessionContext;
@@ -152,6 +154,60 @@ async fn collect_ids(ctx: &SessionContext, sql: &str) -> Vec<u64> {
152154
ids
153155
}
154156

157+
/// Collect a named integer column from a query result.
158+
async fn collect_i64_column(ctx: &SessionContext, sql: &str, column_name: &str) -> Vec<i64> {
159+
let df = ctx
160+
.sql(sql)
161+
.await
162+
.unwrap_or_else(|e| panic!("sql() failed: {e}\nSQL: {sql}"));
163+
let batches = df
164+
.collect()
165+
.await
166+
.unwrap_or_else(|e| panic!("collect() failed: {e}\nSQL: {sql}"));
167+
168+
let mut values: Vec<i64> = vec![];
169+
for batch in &batches {
170+
let col_idx = batch
171+
.schema()
172+
.index_of(column_name)
173+
.unwrap_or_else(|e| panic!("no '{column_name}' column in result: {e}\nSQL: {sql}"));
174+
let column = batch.column(col_idx);
175+
if let Some(arr) = column.as_any().downcast_ref::<UInt64Array>() {
176+
values.extend(arr.values().iter().map(|v| *v as i64));
177+
} else if let Some(arr) = column.as_any().downcast_ref::<Int64Array>() {
178+
values.extend(arr.values());
179+
} else {
180+
panic!("column '{column_name}' not Int64/UInt64\nSQL: {sql}");
181+
}
182+
}
183+
values
184+
}
185+
186+
/// Collect the first integer column from a query result.
187+
async fn collect_first_i64_column(ctx: &SessionContext, sql: &str) -> Vec<i64> {
188+
let df = ctx
189+
.sql(sql)
190+
.await
191+
.unwrap_or_else(|e| panic!("sql() failed: {e}\nSQL: {sql}"));
192+
let batches = df
193+
.collect()
194+
.await
195+
.unwrap_or_else(|e| panic!("collect() failed: {e}\nSQL: {sql}"));
196+
197+
let mut values: Vec<i64> = vec![];
198+
for batch in &batches {
199+
let column = batch.column(0);
200+
if let Some(arr) = column.as_any().downcast_ref::<UInt64Array>() {
201+
values.extend(arr.values().iter().map(|v| *v as i64));
202+
} else if let Some(arr) = column.as_any().downcast_ref::<Int64Array>() {
203+
values.extend(arr.values());
204+
} else {
205+
panic!("first result column not Int64/UInt64\nSQL: {sql}");
206+
}
207+
}
208+
values
209+
}
210+
155211
const Q: &str = "ARRAY[1.0::float, 0.0::float, 0.0::float, 0.0::float]";
156212

157213
// ═══════════════════════════════════════════════════════════════════════════════
@@ -530,6 +586,42 @@ async fn exec_split_provider_select_specific_columns() {
530586
assert_eq!(ids.len(), 2, "expected 2 results; got {ids:?}");
531587
}
532588

589+
/// SELECT specific columns without projecting the distance expression.
590+
/// This matches the split-provider direct ORDER BY shape used by callers that
591+
/// rewrite higher-level search helpers into the low-level distance UDF.
592+
#[tokio::test]
593+
async fn exec_split_provider_order_by_udf_direct() {
594+
let ctx = make_split_provider_ctx("items::vector").await;
595+
let sql = format!("SELECT id FROM items ORDER BY l2_distance(vector, {Q}) ASC LIMIT 2");
596+
let ids = collect_ids(&ctx, &sql).await;
597+
assert_eq!(ids[0], 1, "closest must be row 1\nids: {ids:?}");
598+
assert_eq!(ids.len(), 2, "expected 2 results; got {ids:?}");
599+
}
600+
601+
/// Direct ORDER BY UDF with an aliased computed projection must preserve the
602+
/// computed output through the rewrite.
603+
#[tokio::test]
604+
async fn exec_split_provider_order_by_udf_with_computed_alias() {
605+
let ctx = make_split_provider_ctx("items::vector").await;
606+
let sql = format!(
607+
"SELECT CAST(id + 1 AS BIGINT) AS id_plus FROM items ORDER BY l2_distance(vector, {Q}) ASC LIMIT 2"
608+
);
609+
let values = collect_i64_column(&ctx, &sql, "id_plus").await;
610+
assert_eq!(values, vec![2, 3], "unexpected computed values: {values:?}");
611+
}
612+
613+
/// Direct ORDER BY UDF with an unaliased computed projection relies on the
614+
/// outer projection rebuilding by schema name rather than by raw expression.
615+
#[tokio::test]
616+
async fn exec_split_provider_order_by_udf_with_computed_expr() {
617+
let ctx = make_split_provider_ctx("items::vector").await;
618+
let sql = format!(
619+
"SELECT CAST(id + 1 AS BIGINT) FROM items ORDER BY l2_distance(vector, {Q}) ASC LIMIT 2"
620+
);
621+
let values = collect_first_i64_column(&ctx, &sql).await;
622+
assert_eq!(values, vec![2, 3], "unexpected computed values: {values:?}");
623+
}
624+
533625
/// SELECT * with distance UDF — should fall back to UDF brute-force
534626
/// (since vector column is not in lookup provider schema).
535627
#[tokio::test]

0 commit comments

Comments
 (0)