Commit 3d8c7a2

fix(rule): index ORDER BY-only distance with WHERE (#26)
A k-NN query whose distance appears only in ORDER BY (not the SELECT list)
silently fell back to brute-force whenever a WHERE clause was present.

With a Filter, DataFusion materializes the raw vector column in an
intermediate projection to feed the Sort, then trims it with an outer
projection:

    Projection: id                      <- real output (no vector)
      Sort: l2_distance(vector, lit)
        Projection: id, vector          <- vector only feeds the Sort
          Filter: ...
            TableScan

The Sort-anchored match judged producibility on the inner projection, saw
the vector column the node cannot produce, and declined. Without a WHERE
clause projection pushdown eats the intermediate projection, so the
passthrough arm fired and the gap went unnoticed.

Extend the Projection-anchored arm to recognize the trimmed shape and judge
producibility on the OUTER projection (the query's real output). SELECT * /
SELECT id, vector still fall back (#508 behavior preserved, re-tested); the
aliased-distance shape still rewrites via the Sort arm (the new arm finds no
distance in the outer exprs and declines).

New tests use a ducklake-style rowid Int64 addressing key, covering the
key-column-agnostic path alongside the parquet-style _key fixtures.
test_bare_orderby_with_where_rewrites fails on the unfixed rule.
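
The shape matching described in the message can be sketched with a toy plan type. This is only an illustration under stated assumptions: `Plan`, `can_use_index`, and `trimmed_plan` are hypothetical names, not DataFusion's `LogicalPlan` or this crate's rule, and the real rule also handles fetch limits, filters, and key columns that the sketch ignores.

```rust
// Toy stand-ins for illustration; these are NOT DataFusion's plan types.
#[allow(dead_code)] // some fields exist only to give the tree its shape
enum Plan {
    Projection { cols: Vec<&'static str>, input: Box<Plan> },
    Sort { distance_col: &'static str, input: Box<Plan> },
    Filter { input: Box<Plan> },
    TableScan,
}

/// Decide whether a (hypothetical) index node may replace `plan`, given that
/// the node cannot produce `vector_col`.
fn can_use_index(plan: &Plan, vector_col: &str) -> bool {
    // Anchor on the OUTER projection: its columns are the query's real output.
    let Plan::Projection { cols: outer_cols, input } = plan else {
        return false;
    };
    let Plan::Sort { distance_col, input: below_sort } = input.as_ref() else {
        return false;
    };
    if *distance_col != vector_col {
        return false;
    }
    let shape_ok = match below_sort.as_ref() {
        // Passthrough shape: the Sort rests directly on the scan or its filter.
        Plan::TableScan | Plan::Filter { .. } => true,
        // Trimmed shape: an inner projection that only feeds the Sort.
        Plan::Projection { input: inner_input, .. } => {
            matches!(inner_input.as_ref(), Plan::TableScan | Plan::Filter { .. })
        }
        _ => false,
    };
    // Producibility is judged on the outer columns only.
    shape_ok && !outer_cols.iter().any(|c| *c == vector_col)
}

/// Build Projection(outer) over Sort over Projection(id, embedding) over Filter over TableScan.
fn trimmed_plan(outer_cols: Vec<&'static str>) -> Plan {
    let filter = Plan::Filter { input: Box::new(Plan::TableScan) };
    let inner = Plan::Projection { cols: vec!["id", "embedding"], input: Box::new(filter) };
    let sort = Plan::Sort { distance_col: "embedding", input: Box::new(inner) };
    Plan::Projection { cols: outer_cols, input: Box::new(sort) }
}

fn main() {
    // `SELECT id ... ORDER BY l2_distance(embedding, ...)`: the index is usable.
    assert!(can_use_index(&trimmed_plan(vec!["id"]), "embedding"));
    // `SELECT id, embedding ...`: the vector is in the output, so fall back.
    assert!(!can_use_index(&trimmed_plan(vec!["id", "embedding"]), "embedding"));
}
```

In the sketch, the inner projection's `embedding` column never influences the decision; only the outer column list does, which is the behavior the commit message describes.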
1 parent 2c65279 commit 3d8c7a2

2 files changed

Lines changed: 273 additions & 8 deletions

src/rule.rs

Lines changed: 41 additions & 8 deletions
@@ -11,6 +11,19 @@
 //       Filter(predicate)   ← WHERE clause absorbed
 //         TableScan(name)
 //
+// Patterns matched (TopDown, Projection node):
+//
+//   Projection([output cols])                   ← SELECT list without the distance
+//     Sort(l2_distance(vector, lit), fetch=k)   ← distance inline in ORDER BY
+//       [Projection([output cols + vector])]    ← optional; DataFusion materializes
+//         [Filter(predicate)]                     the vector only to feed the Sort
+//           TableScan(name)
+//
+// In the Projection-anchored shape, producibility is judged on the OUTER
+// projection (the query's real output): `SELECT id … ORDER BY l2_distance(…)`
+// rewrites, while `SELECT *` / `SELECT id, vector` still fall back because the
+// node cannot produce the vector column (issue #508).
+//
 // When a Filter node is present its predicate is stored in USearchNode.filters.
 // The physical planner then runs adaptive filtered search:
 //   - high selectivity → usearch::Index::filtered_search (in-graph filtering)
@@ -82,15 +95,35 @@ impl USearchRule {
                 let LogicalPlan::Sort(sort) = outer.input.as_ref() else {
                     return None;
                 };
-                // Only the passthrough shape; the remap shape (projection *below*
-                // the Sort) is handled when we visit the Sort above.
-                if !matches!(
-                    sort.input.as_ref(),
-                    LogicalPlan::TableScan(_) | LogicalPlan::Filter(_)
-                ) {
-                    return None;
+                match sort.input.as_ref() {
+                    // Passthrough shape: Sort rests directly on the scan.
+                    LogicalPlan::TableScan(_) | LogicalPlan::Filter(_) => {
+                        self.build_rewrite(sort, &outer.expr, sort.input.as_ref())
+                    }
+                    // Trimmed shape: `SELECT id … ORDER BY l2_distance(vec, …)`
+                    // with the distance NOT in the SELECT list. DataFusion
+                    // materializes the raw vector column in an intermediate
+                    // projection purely to feed the Sort, then trims it with
+                    // this outer projection. Producibility must be judged on
+                    // the OUTER (real output) columns — the inner projection's
+                    // vector column never reaches the user. The Sort visit
+                    // would wrongly decline this shape (it sees the vector
+                    // among the inner projection's outputs and the node cannot
+                    // produce it). When the distance is instead aliased inside
+                    // the inner projection (`SELECT …, l2_distance(…) AS d …
+                    // ORDER BY d`), `find_distance_info` finds no distance in
+                    // the outer exprs, `build_rewrite` declines here, and the
+                    // Sort visit handles it exactly as before.
+                    LogicalPlan::Projection(inner)
+                        if matches!(
+                            inner.input.as_ref(),
+                            LogicalPlan::TableScan(_) | LogicalPlan::Filter(_)
+                        ) =>
+                    {
+                        self.build_rewrite(sort, &outer.expr, inner.input.as_ref())
+                    }
+                    _ => None,
                 }
-                self.build_rewrite(sort, &outer.expr, sort.input.as_ref())
             }

             _ => None,
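
The comment in the new arm says the aliased shape is declined because no distance is found in the outer expressions. The sketch below illustrates why that can happen. It assumes a toy expression model: `Expr` and `find_distance` are hypothetical, and the real `find_distance_info` signature and logic are not shown in this commit.

```rust
// Hypothetical expression model; NOT DataFusion's `Expr`.
#[derive(Debug)]
enum Expr {
    Column(&'static str),
    /// A distance call such as `l2_distance(<column>, <literal>)`.
    Distance(&'static str),
    Alias(Box<Expr>, &'static str),
}

/// Toy stand-in for the distance lookup: returns the vector column when the
/// sort key is a distance call, either inline or via an alias defined in `exprs`.
fn find_distance(sort_key: &Expr, exprs: &[Expr]) -> Option<&'static str> {
    match sort_key {
        // Bare shape: the distance call sits directly in ORDER BY.
        Expr::Distance(col) => Some(*col),
        // Aliased shape: ORDER BY names a column, so look for its definition.
        Expr::Column(name) => exprs.iter().find_map(|e| match e {
            Expr::Alias(inner, alias) if alias == name => match inner.as_ref() {
                Expr::Distance(col) => Some(*col),
                _ => None,
            },
            _ => None,
        }),
        Expr::Alias(..) => None,
    }
}

fn main() {
    // `SELECT id ... ORDER BY l2_distance(embedding, ...)`
    let inline_key = Expr::Distance("embedding");
    let outer_bare = [Expr::Column("id")];
    assert_eq!(find_distance(&inline_key, &outer_bare), Some("embedding"));

    // `SELECT id, l2_distance(embedding, ...) AS dist ... ORDER BY dist`
    let alias_key = Expr::Column("dist");
    let outer_aliased = [Expr::Column("id"), Expr::Column("dist")];
    let inner_aliased = [
        Expr::Column("id"),
        Expr::Alias(Box::new(Expr::Distance("embedding")), "dist"),
    ];
    // The outer projection only passes `dist` through, so nothing is found there.
    assert_eq!(find_distance(&alias_key, &outer_aliased), None);
    // The inner projection defines the alias, which is what the Sort visit sees.
    assert_eq!(find_distance(&alias_key, &inner_aliased), Some("embedding"));
}
```

Under these assumptions, the Projection-anchored arm returns nothing for the aliased query, and the Sort-anchored match remains responsible for it.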

tests/orderby_distance_trimmed.rs

Lines changed: 232 additions & 0 deletions
@@ -0,0 +1,232 @@
+// tests/orderby_distance_trimmed.rs — Regression tests for the "trimmed"
+// k-NN shape: `ORDER BY distance_fn(vector, lit)` where the distance is NOT in
+// the SELECT list and the query has a WHERE clause.
+//
+// With a Filter present, DataFusion materializes the raw vector column in an
+// intermediate projection between the Sort and the Filter (it is needed to
+// evaluate the inline ORDER BY expression), then trims it with an outer
+// projection:
+//
+//   Projection: id                      ← real output (no vector)
+//     Sort: l2_distance(vector, lit)
+//       Projection: id, vector          ← vector materialized only for the Sort
+//         Filter: label = '…'
+//           TableScan
+//
+// The Sort-anchored match judges producibility on the INNER projection — which
+// contains the vector the node cannot produce — and would wrongly fall back.
+// The Projection-anchored match must recognize this shape and judge
+// producibility on the OUTER projection instead (issue: ORDER-BY-only distance
+// silently losing the index whenever a WHERE clause is present).
+//
+// Unlike tests/vector_col_projection.rs, the fixtures here use a ducklake-style
+// addressing key — `rowid: Int64` — rather than the parquet-style
+// `_key: UInt64`, so the key-column-agnostic path is covered too.
+
+use std::sync::Arc;
+
+use arrow_schema::{DataType, Field, Schema};
+use datafusion::logical_expr::LogicalPlan;
+use datafusion::prelude::SessionContext;
+use usearch::{Index, IndexOptions, MetricKind, ScalarKind};
+
+use datafusion_vector_search_ext::{HashKeyProvider, USearchNode, USearchRegistry, register_all};
+
+/// The user-visible table: addressing key absent, vector column present.
+fn table_schema() -> Arc<Schema> {
+    Arc::new(Schema::new(vec![
+        Field::new("id", DataType::Int32, false),
+        Field::new("label", DataType::Utf8, true),
+        Field::new(
+            "embedding",
+            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
+            false,
+        ),
+    ]))
+}
+
+/// Sidecar/lookup schema: ducklake-style `rowid` key + non-vector columns.
+/// The vector column is excluded — exactly as the SQLite sidecar stores it.
+fn lookup_schema() -> Arc<Schema> {
+    Arc::new(Schema::new(vec![
+        Field::new("rowid", DataType::Int64, false),
+        Field::new("id", DataType::Int32, false),
+        Field::new("label", DataType::Utf8, true),
+    ]))
+}
+
+/// Scan-provider schema: full column set including the vector, with the
+/// `rowid` key — mirrors the snapshot-pinned DuckLake scan provider.
+fn scan_schema() -> Arc<Schema> {
+    Arc::new(Schema::new(vec![
+        Field::new("id", DataType::Int32, false),
+        Field::new("label", DataType::Utf8, true),
+        Field::new(
+            "embedding",
+            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
+            false,
+        ),
+        Field::new("rowid", DataType::Int64, false),
+    ]))
+}
+
+fn make_index() -> Arc<Index> {
+    let options = IndexOptions {
+        dimensions: 4,
+        metric: MetricKind::L2sq,
+        quantization: ScalarKind::F32,
+        ..Default::default()
+    };
+    Arc::new(Index::new(&options).expect("usearch Index::new failed"))
+}
+
+async fn make_ctx() -> SessionContext {
+    let scan_provider = Arc::new(
+        HashKeyProvider::try_new(scan_schema(), vec![], "rowid")
+            .expect("scan HashKeyProvider::try_new failed"),
+    );
+    let lookup_provider = Arc::new(
+        HashKeyProvider::try_new(lookup_schema(), vec![], "rowid")
+            .expect("lookup HashKeyProvider::try_new failed"),
+    );
+
+    let reg = USearchRegistry::new();
+    reg.add(
+        "items::embedding",
+        make_index(),
+        scan_provider,
+        lookup_provider,
+        "rowid",
+        MetricKind::L2sq,
+        ScalarKind::F32,
+    )
+    .expect("USearchRegistry::add failed");
+    let registry = reg.into_arc();
+
+    let ctx = SessionContext::default();
+    register_all(&ctx, registry).expect("register_all failed");
+
+    let table = Arc::new(
+        HashKeyProvider::try_new(table_schema(), vec![], "id")
+            .expect("table HashKeyProvider::try_new failed"),
+    );
+    ctx.register_table("items", table)
+        .expect("register_table failed");
+    ctx
+}
+
+fn contains_usearch_node(plan: &LogicalPlan) -> bool {
+    if let LogicalPlan::Extension(ext) = plan
+        && ext.node.as_any().downcast_ref::<USearchNode>().is_some()
+    {
+        return true;
+    }
+    plan.inputs().iter().any(|c| contains_usearch_node(c))
+}
+
+const Q: &str = "ARRAY[0.1, 0.2, 0.3, 0.4]";
+
+/// The shape this file exists for: distance only in ORDER BY, WHERE present.
+/// The vector appears in DataFusion's intermediate projection but not in the
+/// query output — the rule must use the index.
+#[tokio::test]
+async fn test_bare_orderby_with_where_rewrites() {
+    let ctx = make_ctx().await;
+    let sql = format!(
+        "SELECT id FROM items WHERE label = 'x' \
+         ORDER BY l2_distance(embedding, {Q}) ASC LIMIT 2"
+    );
+    let plan = ctx
+        .sql(&sql)
+        .await
+        .expect("SQL analysis failed")
+        .into_optimized_plan()
+        .expect("optimization must not error");
+    assert!(
+        contains_usearch_node(&plan),
+        "vector not in output → rule must use the index despite the WHERE-induced \
+         intermediate projection\nPlan: {plan:?}"
+    );
+}
+
+/// Multiple non-vector output columns, same shape.
+#[tokio::test]
+async fn test_bare_orderby_with_where_multiple_columns_rewrites() {
+    let ctx = make_ctx().await;
+    let sql = format!(
+        "SELECT id, label FROM items WHERE label LIKE 'x%' \
+         ORDER BY l2_distance(embedding, {Q}) ASC LIMIT 5"
+    );
+    let plan = ctx
+        .sql(&sql)
+        .await
+        .expect("SQL analysis failed")
+        .into_optimized_plan()
+        .expect("optimization must not error");
+    assert!(
+        contains_usearch_node(&plan),
+        "all output columns producible → rule must use the index\nPlan: {plan:?}"
+    );
+}
+
+/// `SELECT *` with WHERE: output includes the vector → must still fall back.
+#[tokio::test]
+async fn test_select_star_with_where_still_falls_back() {
+    let ctx = make_ctx().await;
+    let sql = format!(
+        "SELECT * FROM items WHERE label = 'x' \
+         ORDER BY l2_distance(embedding, {Q}) ASC LIMIT 2"
+    );
+    let plan = ctx
+        .sql(&sql)
+        .await
+        .expect("SQL analysis failed")
+        .into_optimized_plan()
+        .expect("optimization must not error when the vector column is in the output");
+    assert!(
+        !contains_usearch_node(&plan),
+        "vector column in output → rule must fall back, WHERE or not\nPlan: {plan:?}"
+    );
+}
+
+/// Explicit vector column in the output with WHERE: must still fall back.
+#[tokio::test]
+async fn test_select_vector_with_where_still_falls_back() {
+    let ctx = make_ctx().await;
+    let sql = format!(
+        "SELECT id, embedding FROM items WHERE label = 'x' \
+         ORDER BY l2_distance(embedding, {Q}) ASC LIMIT 2"
+    );
+    let plan = ctx
+        .sql(&sql)
+        .await
+        .expect("SQL analysis failed")
+        .into_optimized_plan()
+        .expect("optimization must not error when the vector column is in the output");
+    assert!(
+        !contains_usearch_node(&plan),
+        "vector column in output → rule must fall back, WHERE or not\nPlan: {plan:?}"
+    );
+}
+
+/// The canonical aliased-distance shape with WHERE keeps rewriting via the
+/// Sort-anchored match (regression guard: the new Projection-anchored arm must
+/// decline it cleanly and leave it to the Sort visit).
+#[tokio::test]
+async fn test_aliased_distance_with_where_still_rewrites() {
+    let ctx = make_ctx().await;
+    let sql = format!(
+        "SELECT id, l2_distance(embedding, {Q}) AS dist FROM items \
+         WHERE label = 'x' ORDER BY dist ASC LIMIT 2"
+    );
+    let plan = ctx
+        .sql(&sql)
+        .await
+        .expect("SQL analysis failed")
+        .into_optimized_plan()
+        .expect("optimization failed");
+    assert!(
+        contains_usearch_node(&plan),
+        "aliased-distance shape must keep rewriting\nPlan: {plan:?}"
+    );
+}
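
Every test above relies on `contains_usearch_node`, which walks the optimized plan and downcasts extension nodes. The std-only sketch below shows the same traversal on a toy tree. `ExtNode`, `IndexNode`, `OtherNode`, and `Plan` are hypothetical stand-ins, not the DataFusion or crate types, and nested `if` blocks replace the let-chain so the sketch also compiles on editions before 2024.

```rust
use std::any::Any;

// Hypothetical stand-ins; NOT DataFusion's `LogicalPlan` or the crate's `USearchNode`.
trait ExtNode {
    fn as_any(&self) -> &dyn Any;
}

struct IndexNode; // plays the role of the index-backed extension node
struct OtherNode; // some unrelated extension node

impl ExtNode for IndexNode {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

impl ExtNode for OtherNode {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

enum Plan {
    Extension(Box<dyn ExtNode>),
    Unary(Box<Plan>), // any single-input operator (projection, sort, filter, ...)
    Leaf,
}

impl Plan {
    /// Children of this node, mirroring the role of `LogicalPlan::inputs`.
    fn inputs(&self) -> Vec<&Plan> {
        match self {
            Plan::Unary(child) => vec![child.as_ref()],
            Plan::Extension(_) | Plan::Leaf => vec![],
        }
    }
}

/// Depth-first search for an `IndexNode` anywhere in the tree.
fn contains_index_node(plan: &Plan) -> bool {
    if let Plan::Extension(node) = plan {
        // `downcast_ref` succeeds only when the concrete type matches.
        if node.as_any().downcast_ref::<IndexNode>().is_some() {
            return true;
        }
    }
    plan.inputs().iter().any(|child| contains_index_node(child))
}

fn main() {
    let rewritten = Plan::Unary(Box::new(Plan::Extension(Box::new(IndexNode))));
    let fallback = Plan::Unary(Box::new(Plan::Unary(Box::new(Plan::Leaf))));
    let unrelated = Plan::Unary(Box::new(Plan::Extension(Box::new(OtherNode))));

    assert!(contains_index_node(&rewritten));
    assert!(!contains_index_node(&fallback));
    assert!(!contains_index_node(&unrelated));
}
```

Checking for the node's presence, rather than comparing full plan strings, keeps assertions of this kind independent of how the optimizer prints or orders the rest of the plan.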
