Skip to content

Commit b40666d

Browse files
committed
fix(rule): preserve hidden distance for sort
Keep _distance in an inner projection when ORDER BY uses a vector distance expression that is not part of the final select list.

This fixes split-provider execution for queries like SELECT id ORDER BY l2_distance(vector, ARRAY[...]) LIMIT k while preserving the final output schema. Add an execution test for the direct ORDER BY shape to cover the production case.
1 parent f21e8a4 commit b40666d

2 files changed

Lines changed: 81 additions & 18 deletions

File tree

src/rule.rs

Lines changed: 70 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@
1818
//
1919
// Replacement:
2020
//
21-
// Sort(fetch=k) ← kept (sort order)
22-
// Projection([col(a), col(b), col("_distance").alias("dist")])
23-
// USearchNode ← executes ANN
21+
// Projection([final output cols])
22+
// Sort(fetch=k)
23+
// Projection([final output cols + optional hidden _distance])
24+
// USearchNode
2425

2526
use std::collections::HashMap;
2627
use std::sync::Arc;
@@ -151,25 +152,33 @@ impl USearchRule {
151152
node: Arc::new(node) as Arc<dyn UserDefinedLogicalNode>,
152153
}));
153154

154-
// Build Projection over USearchNode matching the original output schema.
155+
// Build the final user-visible projection over USearchNode output.
155156
let dist_alias_str = dist_alias.as_deref().unwrap_or("_distance");
156-
let new_proj_exprs = if proj_exprs_slice.is_empty() {
157+
let final_proj_exprs = if proj_exprs_slice.is_empty() {
157158
passthrough_projection(&vsn_df_schema, &table_ref)
158159
} else {
159160
remap_projections(proj_exprs_slice, dist_alias_str, &table_ref)
160161
};
161-
let new_proj = Projection::try_new(new_proj_exprs, node_plan).ok()?;
162-
163-
// Keep the Sort node so DataFusion handles ordering by _distance / dist.
164-
// USearch returns results in arbitrary (internal) order when the underlying
165-
// data is fetched from the TableProvider.
166-
Some(LogicalPlan::Sort(
167-
datafusion::logical_expr::logical_plan::Sort {
168-
expr: sort.expr.clone(),
169-
input: Arc::new(LogicalPlan::Projection(new_proj)),
170-
fetch: sort.fetch,
171-
},
172-
))
162+
let remapped_sort_exprs = remap_sort_exprs(&sort.expr, dist_alias.as_deref());
163+
let needs_hidden_distance = remapped_sort_exprs.iter().any(|e| {
164+
matches!(&e.expr, Expr::Column(c) if c.relation.is_none() && c.name == "_distance")
165+
}) && !projection_exposes_name(&final_proj_exprs, "_distance");
166+
167+
let mut sort_input_exprs = final_proj_exprs.clone();
168+
if needs_hidden_distance {
169+
sort_input_exprs.push(col("_distance"));
170+
}
171+
172+
let sort_input = Projection::try_new(sort_input_exprs, node_plan).ok()?;
173+
let sorted = LogicalPlan::Sort(datafusion::logical_expr::logical_plan::Sort {
174+
expr: remapped_sort_exprs,
175+
input: Arc::new(LogicalPlan::Projection(sort_input)),
176+
fetch: sort.fetch,
177+
});
178+
179+
let outer_proj_exprs = build_outer_projection(&final_proj_exprs);
180+
let outer_proj = Projection::try_new(outer_proj_exprs, Arc::new(sorted)).ok()?;
181+
Some(LogicalPlan::Projection(outer_proj))
173182
}
174183
}
175184

@@ -283,7 +292,11 @@ fn dist_type_matches_metric(dist_type: &DistanceType, metric: MetricKind) -> boo
283292
}
284293

285294
fn is_distance_expr(expr: &Expr) -> bool {
286-
matches!(expr, Expr::ScalarFunction(sf) if is_dist_udf_name(sf.func.name()))
295+
let inner = match expr {
296+
Expr::Alias(a) => a.expr.as_ref(),
297+
other => other,
298+
};
299+
matches!(inner, Expr::ScalarFunction(sf) if is_dist_udf_name(sf.func.name()))
287300
}
288301

289302
fn try_extract_distance(expr: &Expr) -> Option<(String, String, Vec<f64>)> {
@@ -322,6 +335,45 @@ fn remap_projections(
322335
.collect()
323336
}
324337

338+
fn remap_sort_exprs(
339+
sort_exprs: &[datafusion::logical_expr::SortExpr],
340+
dist_alias_name: Option<&str>,
341+
) -> Vec<datafusion::logical_expr::SortExpr> {
342+
sort_exprs
343+
.iter()
344+
.map(|sort_expr| {
345+
let remapped_expr = match &sort_expr.expr {
346+
Expr::Column(c) if Some(c.name.as_str()) == dist_alias_name => col(c.name.as_str()),
347+
expr if is_distance_expr(expr) => col("_distance"),
348+
other => other.clone(),
349+
};
350+
datafusion::logical_expr::SortExpr {
351+
expr: remapped_expr,
352+
asc: sort_expr.asc,
353+
nulls_first: sort_expr.nulls_first,
354+
}
355+
})
356+
.collect()
357+
}
358+
359+
fn projection_exposes_name(exprs: &[Expr], name: &str) -> bool {
360+
exprs.iter().any(|expr| match expr {
361+
Expr::Alias(a) => a.name == name,
362+
Expr::Column(c) => c.name == name,
363+
_ => false,
364+
})
365+
}
366+
367+
fn build_outer_projection(exprs: &[Expr]) -> Vec<Expr> {
368+
exprs.iter()
369+
.filter_map(|expr| match expr {
370+
Expr::Alias(a) => Some(col(a.name.as_str())),
371+
Expr::Column(c) => Some(Expr::Column(c.clone())),
372+
_ => None,
373+
})
374+
.collect()
375+
}
376+
325377
/// Build a passthrough Projection for SELECT * queries (no original Projection node).
326378
/// Projects only the original table columns (not `_distance`) so the output schema
327379
/// matches the original Sort schema. The Sort re-evaluates the distance UDF expression

tests/execution.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,17 @@ async fn exec_split_provider_select_specific_columns() {
530530
assert_eq!(ids.len(), 2, "expected 2 results; got {ids:?}");
531531
}
532532

533+
/// SELECT specific columns without projecting the distance expression.
534+
/// This is the production shape behind `vector_distance(...)`.
535+
#[tokio::test]
536+
async fn exec_split_provider_order_by_udf_direct() {
537+
let ctx = make_split_provider_ctx("items::vector").await;
538+
let sql = format!("SELECT id FROM items ORDER BY l2_distance(vector, {Q}) ASC LIMIT 2");
539+
let ids = collect_ids(&ctx, &sql).await;
540+
assert_eq!(ids[0], 1, "closest must be row 1\nids: {ids:?}");
541+
assert_eq!(ids.len(), 2, "expected 2 results; got {ids:?}");
542+
}
543+
533544
/// SELECT * with distance UDF — should fall back to UDF brute-force
534545
/// (since vector column is not in lookup provider schema).
535546
#[tokio::test]

0 commit comments

Comments
 (0)