Skip to content

Commit 4605f0e

Browse files
authored
refactor(registry): split provider into scan_provider + lookup_provider (#6)
* refactor(registry): split provider into scan_provider + lookup_provider

  Separate RegisteredTable.provider into scan_provider (TableProvider for WHERE evaluation) and lookup_provider (PointLookupProvider for key-based fetch), preparing for Parquet-native adaptive filtering.

* fix(planner): use separate key column indices for scan and lookup schemas

  The scan_provider and lookup_provider may have different schemas, so the key column can be at different indices. Use scan_key_col_idx when reading scan batches and lookup_key_col_idx for attach_distances on fetched rows.

* style(planner): fix rustfmt formatting

* fix(registry): validate key column in scan_provider schema at registration

  Add scan_provider.schema().index_of(key_col) guard in add_with_config to catch misconfigured registrations early instead of at query execution time.
1 parent 3c3b8cf commit 4605f0e

7 files changed

Lines changed: 128 additions & 27 deletions

File tree

src/lookup.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,16 @@ use datafusion::physical_plan::ExecutionPlan;
2525

2626
// ── Trait ─────────────────────────────────────────────────────────────────────
2727

28-
/// A [`TableProvider`] that guarantees efficient row retrieval by primary key.
28+
/// Trait for efficient row retrieval by primary key.
2929
///
3030
/// Implementors provide O(k) or O(k log N) row lookups — no full-table scan.
3131
/// The `USearchRegistry` requires this trait instead of a bare `TableProvider`
3232
/// to enforce the performance contract at registration time.
3333
///
3434
/// # Contract
3535
///
36+
/// - `schema()` MUST return the Arrow schema of the rows returned by
37+
/// `fetch_by_keys` (when `projection` is `None`).
3638
/// - `fetch_by_keys` MUST return only rows whose key column value is in `keys`.
3739
/// - Keys not found in the table are silently omitted — not an error.
3840
/// - Returned batches must use a schema consistent with `self.schema()`. When
@@ -46,6 +48,8 @@ use datafusion::physical_plan::ExecutionPlan;
4648
/// ```rust,ignore
4749
/// #[async_trait]
4850
/// impl PointLookupProvider for MyEngineTable {
51+
/// fn schema(&self) -> SchemaRef { self.schema.clone() }
52+
///
4953
/// async fn fetch_by_keys(
5054
/// &self,
5155
/// keys: &[u64],
@@ -60,7 +64,10 @@ use datafusion::physical_plan::ExecutionPlan;
6064
/// }
6165
/// ```
6266
#[async_trait]
63-
pub trait PointLookupProvider: TableProvider + Send + Sync {
67+
pub trait PointLookupProvider: Send + Sync {
68+
/// Arrow schema of the rows this provider returns (without `_distance`).
69+
fn schema(&self) -> SchemaRef;
70+
6471
async fn fetch_by_keys(
6572
&self,
6673
keys: &[u64],
@@ -145,6 +152,10 @@ impl fmt::Debug for HashKeyProvider {
145152

146153
#[async_trait]
147154
impl PointLookupProvider for HashKeyProvider {
155+
fn schema(&self) -> SchemaRef {
156+
self.schema.clone()
157+
}
158+
148159
async fn fetch_by_keys(
149160
&self,
150161
keys: &[u64],

src/parquet_provider.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,10 @@ async fn load_metadata_cache(
148148

149149
#[async_trait]
150150
impl PointLookupProvider for ParquetLookupProvider {
151+
fn schema(&self) -> SchemaRef {
152+
self.schema.clone()
153+
}
154+
151155
async fn fetch_by_keys(
152156
&self,
153157
keys: &[u64],

src/planner.rs

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ impl ExtensionPlanner for USearchExecPlanner {
145145
let provider_scan = if !node.filters.is_empty() {
146146
Some(
147147
registered
148-
.provider
148+
.scan_provider
149149
.scan(session_state, None, &[], None)
150150
.await?,
151151
)
@@ -313,7 +313,7 @@ async fn usearch_execute(
313313
.collect();
314314

315315
let data_batches = registered
316-
.provider
316+
.lookup_provider
317317
.fetch_by_keys(&matches.keys, &params.key_col, None)
318318
.await?;
319319

@@ -352,8 +352,16 @@ async fn adaptive_filtered_execute(
352352
scan_plan: Arc<dyn ExecutionPlan>,
353353
task_ctx: Arc<TaskContext>,
354354
) -> Result<Vec<RecordBatch>> {
355-
let provider_schema = registered.provider.schema();
356-
let key_col_idx = provider_key_col_idx(registered)?;
355+
let provider_schema = registered.scan_provider.schema();
356+
// Key column index in scan_provider schema — used when reading scan batches.
357+
let scan_key_col_idx = provider_schema.index_of(&registered.key_col).map_err(|_| {
358+
DataFusionError::Execution(format!(
359+
"USearchExecPlanner: key column '{}' not found in scan provider schema",
360+
registered.key_col
361+
))
362+
})?;
363+
// Key column index in lookup_provider schema — used by attach_distances.
364+
let lookup_key_col_idx = provider_key_col_idx(registered)?;
357365
let vec_col_idx = provider_schema.index_of(&params.vector_col).ok();
358366
let has_vec_col = vec_col_idx.is_some();
359367

@@ -371,7 +379,7 @@ async fn adaptive_filtered_execute(
371379
while let Some(batch_result) = stream.next().await {
372380
let batch = batch_result?;
373381
let mask = evaluate_filters(&params.physical_filters, &batch)?;
374-
let keys = extract_keys_as_u64(batch.column(key_col_idx).as_ref())?;
382+
let keys = extract_keys_as_u64(batch.column(scan_key_col_idx).as_ref())?;
375383

376384
for row_idx in 0..batch.num_rows() {
377385
if !mask.is_null(row_idx)
@@ -445,12 +453,16 @@ async fn adaptive_filtered_execute(
445453
let top_keys: Vec<u64> = top_k.iter().map(|(k, _)| *k).collect();
446454

447455
let data_batches = registered
448-
.provider
456+
.lookup_provider
449457
.fetch_by_keys(&top_keys, &params.key_col, None)
450458
.await?;
451459

452-
let result_batches =
453-
attach_distances(data_batches, key_col_idx, &key_to_dist, &params.schema)?;
460+
let result_batches = attach_distances(
461+
data_batches,
462+
lookup_key_col_idx,
463+
&key_to_dist,
464+
&params.schema,
465+
)?;
454466

455467
tracing::Span::current().record(
456468
"usearch.result_count",
@@ -486,12 +498,16 @@ async fn adaptive_filtered_execute(
486498
.collect();
487499

488500
let data_batches = registered
489-
.provider
501+
.lookup_provider
490502
.fetch_by_keys(&matches.keys, &params.key_col, None)
491503
.await?;
492504

493-
let result_batches =
494-
attach_distances(data_batches, key_col_idx, &key_to_dist, &params.schema)?;
505+
let result_batches = attach_distances(
506+
data_batches,
507+
lookup_key_col_idx,
508+
&key_to_dist,
509+
&params.schema,
510+
)?;
495511

496512
tracing::Span::current().record(
497513
"usearch.result_count",
@@ -729,15 +745,15 @@ fn heap_select_top_k(pairs: &mut [(u64, f32)], k: usize) -> Vec<(u64, f32)> {
729745
result
730746
}
731747

732-
/// Index of the key column in the provider schema.
748+
/// Index of the key column in the lookup provider schema.
733749
fn provider_key_col_idx(registered: &crate::registry::RegisteredTable) -> Result<usize> {
734750
registered
735-
.provider
751+
.lookup_provider
736752
.schema()
737753
.index_of(&registered.key_col)
738754
.map_err(|_| {
739755
DataFusionError::Execution(format!(
740-
"USearchExecPlanner: key column '{}' not found in provider schema",
756+
"USearchExecPlanner: key column '{}' not found in lookup provider schema",
741757
registered.key_col
742758
))
743759
})

src/registry.rs

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ use datafusion::common::Result;
88
use datafusion::error::DataFusionError;
99
use usearch::{Index, IndexOptions, MetricKind, ScalarKind};
1010

11+
use datafusion::catalog::TableProvider;
12+
1113
use crate::lookup::PointLookupProvider;
1214

1315
// ── USearchIndexConfig ────────────────────────────────────────────────────────
@@ -177,7 +179,10 @@ impl Default for USearchTableConfig {
177179

178180
pub struct RegisteredTable {
179181
pub index: Arc<Index>,
180-
pub provider: Arc<dyn PointLookupProvider>,
182+
/// Scan provider for WHERE evaluation and low-selectivity Parquet-native path.
183+
pub scan_provider: Arc<dyn TableProvider>,
184+
/// Lookup provider for efficient key-based row fetch (e.g. SQLite).
185+
pub lookup_provider: Arc<dyn PointLookupProvider>,
181186
pub key_col: String,
182187
pub metric: MetricKind,
183188
/// Native scalar type of the vector column. Determines which typed search
@@ -213,31 +218,36 @@ impl USearchRegistry {
213218
/// [`USearchTableConfig::default()`] (ef_search=64, threshold=5%).
214219
///
215220
/// - `index` — must already be loaded / populated.
216-
/// - `provider` — must implement [`PointLookupProvider`].
221+
/// - `scan_provider` — [`TableProvider`] used for WHERE evaluation and
222+
/// low-selectivity Parquet-native scanning.
223+
/// - `lookup_provider` — [`PointLookupProvider`] for O(k) key-based fetch.
217224
/// [`HashKeyProvider`] is the bundled in-memory implementation.
218-
/// For production, implement the trait on your storage engine's table type.
219-
/// - `key_col` — column in `provider.schema()` that stores the USearch key
220-
/// (`u64`). Supported Arrow types: `UInt64`, `Int64`, `UInt32`, `Int32`.
225+
/// - `key_col` — column in `lookup_provider.schema()` that stores the
226+
/// USearch key (`u64`). Supported Arrow types: `UInt64`, `Int64`,
227+
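/// `UInt32`, `Int32`. The same column name must also exist in
/// `scan_provider.schema()`; registration fails otherwise.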
221228
/// - `metric` — must match how the index was built. The optimizer rule
222229
/// validates this and refuses to rewrite on mismatch.
223230
/// - `scalar_kind` — native element type of the vector column (`F32` or
224231
/// `F64`). Controls which typed search method the planner dispatches to.
225232
///
226233
/// [`add_with_config`]: USearchRegistry::add_with_config
227234
/// [`HashKeyProvider`]: crate::lookup::HashKeyProvider
235+
#[allow(clippy::too_many_arguments)]
228236
pub fn add(
229237
&self,
230238
name: &str,
231239
index: Arc<Index>,
232-
provider: Arc<dyn PointLookupProvider>,
240+
scan_provider: Arc<dyn TableProvider>,
241+
lookup_provider: Arc<dyn PointLookupProvider>,
233242
key_col: &str,
234243
metric: MetricKind,
235244
scalar_kind: ScalarKind,
236245
) -> Result<()> {
237246
self.add_with_config(
238247
name,
239248
index,
240-
provider,
249+
scan_provider,
250+
lookup_provider,
241251
key_col,
242252
metric,
243253
scalar_kind,
@@ -254,7 +264,8 @@ impl USearchRegistry {
254264
&self,
255265
name: &str,
256266
index: Arc<Index>,
257-
provider: Arc<dyn PointLookupProvider>,
267+
scan_provider: Arc<dyn TableProvider>,
268+
lookup_provider: Arc<dyn PointLookupProvider>,
258269
key_col: &str,
259270
metric: MetricKind,
260271
scalar_kind: ScalarKind,
@@ -263,11 +274,17 @@ impl USearchRegistry {
263274
// Set ef_search once, here, before any query touches the index.
264275
index.change_expansion_search(config.expansion_search);
265276

266-
let data_schema = provider.schema();
277+
let data_schema = lookup_provider.schema();
267278

268279
let _ = data_schema.index_of(key_col).map_err(|_| {
269280
DataFusionError::Execution(format!(
270-
"USearchRegistry: key column '{key_col}' not found in table '{name}' schema"
281+
"USearchRegistry: key column '{key_col}' not found in lookup provider schema for table '{name}'"
282+
))
283+
})?;
284+
285+
let _ = scan_provider.schema().index_of(key_col).map_err(|_| {
286+
DataFusionError::Execution(format!(
287+
"USearchRegistry: key column '{key_col}' not found in scan provider schema for table '{name}'"
271288
))
272289
})?;
273290

@@ -286,7 +303,8 @@ impl USearchRegistry {
286303
name.to_string(),
287304
Arc::new(RegisteredTable {
288305
index,
289-
provider,
306+
scan_provider,
307+
lookup_provider,
290308
key_col: key_col.to_string(),
291309
metric,
292310
scalar_kind,

src/sqlite_provider.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,10 @@ impl SqliteLookupProvider {
177177

178178
#[async_trait]
179179
impl PointLookupProvider for SqliteLookupProvider {
180+
fn schema(&self) -> SchemaRef {
181+
self.schema.clone()
182+
}
183+
180184
async fn fetch_by_keys(
181185
&self,
182186
keys: &[u64],

tests/execution.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ async fn make_exec_ctx(reg_key: &str) -> SessionContext {
106106
reg_key,
107107
make_populated_index(),
108108
provider.clone(),
109+
provider.clone(),
109110
"id",
110111
MetricKind::L2sq,
111112
ScalarKind::F32,
@@ -256,6 +257,51 @@ async fn exec_qualified_where_order_by_alias() {
256257
assert_eq!(ids[0], 1, "closest alpha row must be row 1\nids: {ids:?}");
257258
}
258259

260+
// ═══════════════════════════════════════════════════════════════════════════════
261+
// Registration validation
262+
// ═══════════════════════════════════════════════════════════════════════════════
263+
264+
/// Registration must fail when scan_provider schema is missing the key column.
265+
#[tokio::test]
266+
async fn reg_scan_provider_missing_key_col_errors() {
267+
// scan_provider schema: only "label" and "vector" — no "id".
268+
let scan_schema = Arc::new(Schema::new(vec![
269+
Field::new("label", DataType::Utf8, false),
270+
Field::new(
271+
"vector",
272+
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
273+
false,
274+
),
275+
]));
276+
let scan_provider =
277+
Arc::new(HashKeyProvider::try_new(scan_schema, vec![], "label").expect("HashKeyProvider"));
278+
279+
// lookup_provider has "id".
280+
let lookup_schema = exec_schema();
281+
let lookup_provider =
282+
Arc::new(HashKeyProvider::try_new(lookup_schema, vec![], "id").expect("HashKeyProvider"));
283+
284+
let reg = USearchRegistry::new();
285+
let result = reg.add(
286+
"test::vector",
287+
make_populated_index(),
288+
scan_provider,
289+
lookup_provider,
290+
"id",
291+
MetricKind::L2sq,
292+
ScalarKind::F32,
293+
);
294+
assert!(
295+
result.is_err(),
296+
"registration must fail when scan_provider lacks key column"
297+
);
298+
let msg = result.unwrap_err().to_string();
299+
assert!(
300+
msg.contains("scan provider"),
301+
"error must mention scan provider: {msg}"
302+
);
303+
}
304+
259305
/// Qualified table, WHERE clause, ORDER BY UDF directly.
260306
#[tokio::test]
261307
async fn exec_qualified_where_order_by_udf() {

tests/optimizer_rule.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ async fn make_ctx(metric: MetricKind) -> SessionContext {
6262
"items::vector",
6363
make_index(metric),
6464
provider.clone(),
65+
provider.clone(),
6566
"id",
6667
metric,
6768
ScalarKind::F32,
@@ -347,6 +348,7 @@ async fn make_ctx_qualified(metric: MetricKind) -> SessionContext {
347348
"datafusion::public::items::vector",
348349
make_index(metric),
349350
provider.clone(),
351+
provider.clone(),
350352
"id",
351353
metric,
352354
ScalarKind::F32,

0 commit comments

Comments
 (0)