Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions src/lookup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,16 @@ use datafusion::physical_plan::ExecutionPlan;

// ── Trait ─────────────────────────────────────────────────────────────────────

/// A [`TableProvider`] that guarantees efficient row retrieval by primary key.
/// Trait for efficient row retrieval by primary key.
///
/// Implementors provide O(k) or O(k log N) row lookups — no full-table scan.
/// The `USearchRegistry` requires this trait instead of a bare `TableProvider`
/// to enforce the performance contract at registration time.
///
/// # Contract
///
/// - `schema()` MUST return the Arrow schema of the rows returned by
/// `fetch_by_keys` (when `projection` is `None`).
/// - `fetch_by_keys` MUST return only rows whose key column value is in `keys`.
/// - Keys not found in the table are silently omitted — not an error.
/// - Returned batches must use a schema consistent with `self.schema()`. When
Expand All @@ -46,6 +48,8 @@ use datafusion::physical_plan::ExecutionPlan;
/// ```rust,ignore
/// #[async_trait]
/// impl PointLookupProvider for MyEngineTable {
/// fn schema(&self) -> SchemaRef { self.schema.clone() }
///
/// async fn fetch_by_keys(
/// &self,
/// keys: &[u64],
Expand All @@ -60,7 +64,10 @@ use datafusion::physical_plan::ExecutionPlan;
/// }
/// ```
#[async_trait]
pub trait PointLookupProvider: TableProvider + Send + Sync {
pub trait PointLookupProvider: Send + Sync {
/// Arrow schema of the rows this provider returns (without `_distance`).
fn schema(&self) -> SchemaRef;

async fn fetch_by_keys(
&self,
keys: &[u64],
Expand Down Expand Up @@ -145,6 +152,10 @@ impl fmt::Debug for HashKeyProvider {

#[async_trait]
impl PointLookupProvider for HashKeyProvider {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}

async fn fetch_by_keys(
&self,
keys: &[u64],
Expand Down
4 changes: 4 additions & 0 deletions src/parquet_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,10 @@ async fn load_metadata_cache(

#[async_trait]
impl PointLookupProvider for ParquetLookupProvider {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}

async fn fetch_by_keys(
&self,
keys: &[u64],
Expand Down
44 changes: 30 additions & 14 deletions src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ impl ExtensionPlanner for USearchExecPlanner {
let provider_scan = if !node.filters.is_empty() {
Some(
registered
.provider
.scan_provider
.scan(session_state, None, &[], None)
.await?,
)
Expand Down Expand Up @@ -313,7 +313,7 @@ async fn usearch_execute(
.collect();

let data_batches = registered
.provider
.lookup_provider
.fetch_by_keys(&matches.keys, &params.key_col, None)
.await?;

Expand Down Expand Up @@ -352,8 +352,16 @@ async fn adaptive_filtered_execute(
scan_plan: Arc<dyn ExecutionPlan>,
task_ctx: Arc<TaskContext>,
) -> Result<Vec<RecordBatch>> {
let provider_schema = registered.provider.schema();
let key_col_idx = provider_key_col_idx(registered)?;
let provider_schema = registered.scan_provider.schema();
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 — key_col_idx from lookup_provider is applied to scan_provider batches

provider_schema is correctly taken from scan_provider here, but the very next line:

let key_col_idx = provider_key_col_idx(registered)?;

calls lookup_provider.schema().index_of(key_col). That index is then used at the scan-phase loop to read batch.column(key_col_idx) where batch comes from scan_plan (i.e. scan_provider). When the two providers have different schemas — the exact split this PR is designed to enable — the column positions can differ, silently reading the wrong column or panicking.

The tests don't catch this because both test registrations pass provider.clone() for both arguments, making the schemas identical.

Fix: derive two separate indices:

// for scan batches
let scan_key_col_idx = provider_schema
    .index_of(&registered.key_col)
    .map_err(|_| DataFusionError::Execution(format!(
        "key column '{}' not found in scan provider schema", registered.key_col
    )))?;

// for lookup batches passed to attach_distances
let lookup_key_col_idx = provider_key_col_idx(registered)?;

Use scan_key_col_idx in the scan loop and lookup_key_col_idx at the two attach_distances(…) call sites.

// Key column index in scan_provider schema — used when reading scan batches.
let scan_key_col_idx = provider_schema.index_of(&registered.key_col).map_err(|_| {
DataFusionError::Execution(format!(
"USearchExecPlanner: key column '{}' not found in scan provider schema",
registered.key_col
))
})?;
// Key column index in lookup_provider schema — used by attach_distances.
let lookup_key_col_idx = provider_key_col_idx(registered)?;
let vec_col_idx = provider_schema.index_of(&params.vector_col).ok();
let has_vec_col = vec_col_idx.is_some();

Expand All @@ -371,7 +379,7 @@ async fn adaptive_filtered_execute(
while let Some(batch_result) = stream.next().await {
let batch = batch_result?;
let mask = evaluate_filters(&params.physical_filters, &batch)?;
let keys = extract_keys_as_u64(batch.column(key_col_idx).as_ref())?;
let keys = extract_keys_as_u64(batch.column(scan_key_col_idx).as_ref())?;

for row_idx in 0..batch.num_rows() {
if !mask.is_null(row_idx)
Expand Down Expand Up @@ -445,12 +453,16 @@ async fn adaptive_filtered_execute(
let top_keys: Vec<u64> = top_k.iter().map(|(k, _)| *k).collect();

let data_batches = registered
.provider
.lookup_provider
.fetch_by_keys(&top_keys, &params.key_col, None)
.await?;

let result_batches =
attach_distances(data_batches, key_col_idx, &key_to_dist, &params.schema)?;
let result_batches = attach_distances(
data_batches,
lookup_key_col_idx,
&key_to_dist,
&params.schema,
)?;

tracing::Span::current().record(
"usearch.result_count",
Expand Down Expand Up @@ -486,12 +498,16 @@ async fn adaptive_filtered_execute(
.collect();

let data_batches = registered
.provider
.lookup_provider
.fetch_by_keys(&matches.keys, &params.key_col, None)
.await?;

let result_batches =
attach_distances(data_batches, key_col_idx, &key_to_dist, &params.schema)?;
let result_batches = attach_distances(
data_batches,
lookup_key_col_idx,
&key_to_dist,
&params.schema,
)?;

tracing::Span::current().record(
"usearch.result_count",
Expand Down Expand Up @@ -729,15 +745,15 @@ fn heap_select_top_k(pairs: &mut [(u64, f32)], k: usize) -> Vec<(u64, f32)> {
result
}

/// Index of the key column in the provider schema.
/// Index of the key column in the lookup provider schema.
fn provider_key_col_idx(registered: &crate::registry::RegisteredTable) -> Result<usize> {
registered
.provider
.lookup_provider
.schema()
.index_of(&registered.key_col)
.map_err(|_| {
DataFusionError::Execution(format!(
"USearchExecPlanner: key column '{}' not found in provider schema",
"USearchExecPlanner: key column '{}' not found in lookup provider schema",
registered.key_col
))
})
Expand Down
40 changes: 29 additions & 11 deletions src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use datafusion::common::Result;
use datafusion::error::DataFusionError;
use usearch::{Index, IndexOptions, MetricKind, ScalarKind};

use datafusion::catalog::TableProvider;

use crate::lookup::PointLookupProvider;

// ── USearchIndexConfig ────────────────────────────────────────────────────────
Expand Down Expand Up @@ -177,7 +179,10 @@ impl Default for USearchTableConfig {

pub struct RegisteredTable {
pub index: Arc<Index>,
pub provider: Arc<dyn PointLookupProvider>,
/// Scan provider for WHERE evaluation and low-selectivity Parquet-native path.
pub scan_provider: Arc<dyn TableProvider>,
/// Lookup provider for efficient key-based row fetch (e.g. SQLite).
pub lookup_provider: Arc<dyn PointLookupProvider>,
pub key_col: String,
pub metric: MetricKind,
/// Native scalar type of the vector column. Determines which typed search
Expand Down Expand Up @@ -213,31 +218,36 @@ impl USearchRegistry {
/// [`USearchTableConfig::default()`] (ef_search=64, threshold=5%).
///
/// - `index` — must already be loaded / populated.
/// - `provider` — must implement [`PointLookupProvider`].
/// - `scan_provider` — [`TableProvider`] used for WHERE evaluation and
/// low-selectivity Parquet-native scanning.
/// - `lookup_provider` — [`PointLookupProvider`] for O(k) key-based fetch.
/// [`HashKeyProvider`] is the bundled in-memory implementation.
/// For production, implement the trait on your storage engine's table type.
/// - `key_col` — column in `provider.schema()` that stores the USearch key
/// (`u64`). Supported Arrow types: `UInt64`, `Int64`, `UInt32`, `Int32`.
/// - `key_col` — column in `lookup_provider.schema()` that stores the
/// USearch key (`u64`). Supported Arrow types: `UInt64`, `Int64`,
/// `UInt32`, `Int32`.
/// - `metric` — must match how the index was built. The optimizer rule
/// validates this and refuses to rewrite on mismatch.
/// - `scalar_kind` — native element type of the vector column (`F32` or
/// `F64`). Controls which typed search method the planner dispatches to.
///
/// [`add_with_config`]: USearchRegistry::add_with_config
/// [`HashKeyProvider`]: crate::lookup::HashKeyProvider
#[allow(clippy::too_many_arguments)]
pub fn add(
&self,
name: &str,
index: Arc<Index>,
provider: Arc<dyn PointLookupProvider>,
scan_provider: Arc<dyn TableProvider>,
lookup_provider: Arc<dyn PointLookupProvider>,
key_col: &str,
metric: MetricKind,
scalar_kind: ScalarKind,
) -> Result<()> {
self.add_with_config(
name,
index,
provider,
scan_provider,
lookup_provider,
key_col,
metric,
scalar_kind,
Expand All @@ -254,7 +264,8 @@ impl USearchRegistry {
&self,
name: &str,
index: Arc<Index>,
provider: Arc<dyn PointLookupProvider>,
scan_provider: Arc<dyn TableProvider>,
lookup_provider: Arc<dyn PointLookupProvider>,
key_col: &str,
metric: MetricKind,
scalar_kind: ScalarKind,
Expand All @@ -263,11 +274,17 @@ impl USearchRegistry {
// Set ef_search once, here, before any query touches the index.
index.change_expansion_search(config.expansion_search);

let data_schema = provider.schema();
let data_schema = lookup_provider.schema();
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 — Missing validation of key_col in scan_provider.schema()

add_with_config checks that key_col exists in lookup_provider.schema() (lines 279–283), but never checks scan_provider.schema(). In adaptive_filtered_execute, the planner calls scan_provider.schema().index_of(&registered.key_col) at query execution time — so a scan_provider that lacks the key column silently passes registration and only fails on the first filtered query.

The whole motivation of this PR is to allow different providers for scan vs lookup; the tests mask the gap because they pass the same HashKeyProvider for both.

Fix: add an equivalent guard for scan_provider immediately after the existing lookup_provider check:

scan_provider
    .schema()
    .index_of(key_col)
    .map_err(|_| {
        DataFusionError::Execution(format!(
            "USearchRegistry: key column '{key_col}' not found in scan provider schema for table '{name}'"
        ))
    })?;


let _ = data_schema.index_of(key_col).map_err(|_| {
DataFusionError::Execution(format!(
"USearchRegistry: key column '{key_col}' not found in table '{name}' schema"
"USearchRegistry: key column '{key_col}' not found in lookup provider schema for table '{name}'"
))
})?;

let _ = scan_provider.schema().index_of(key_col).map_err(|_| {
DataFusionError::Execution(format!(
"USearchRegistry: key column '{key_col}' not found in scan provider schema for table '{name}'"
))
})?;

Expand All @@ -286,7 +303,8 @@ impl USearchRegistry {
name.to_string(),
Arc::new(RegisteredTable {
index,
provider,
scan_provider,
lookup_provider,
key_col: key_col.to_string(),
metric,
scalar_kind,
Expand Down
4 changes: 4 additions & 0 deletions src/sqlite_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ impl SqliteLookupProvider {

#[async_trait]
impl PointLookupProvider for SqliteLookupProvider {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}

async fn fetch_by_keys(
&self,
keys: &[u64],
Expand Down
46 changes: 46 additions & 0 deletions tests/execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ async fn make_exec_ctx(reg_key: &str) -> SessionContext {
reg_key,
make_populated_index(),
provider.clone(),
provider.clone(),
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both scan_provider and lookup_provider are the same HashKeyProvider here (and in optimizer_rule.rs), so the tests never exercise the case where the two schemas differ — the exact scenario this PR was designed to enable. Consider adding a test that uses a minimal scan_provider schema (id + vector columns) alongside a lookup_provider schema that has additional non-vector columns, to verify scan_key_col_idx and lookup_key_col_idx are resolved independently and correctly under divergent schemas.

"id",
MetricKind::L2sq,
ScalarKind::F32,
Expand Down Expand Up @@ -256,6 +257,51 @@ async fn exec_qualified_where_order_by_alias() {
assert_eq!(ids[0], 1, "closest alpha row must be row 1\nids: {ids:?}");
}

// ═══════════════════════════════════════════════════════════════════════════════
// Registration validation
// ═══════════════════════════════════════════════════════════════════════════════

/// Registration must fail when scan_provider schema is missing the key column.
#[tokio::test]
async fn reg_scan_provider_missing_key_col_errors() {
// scan_provider schema: only "label" and "vector" — no "id".
let scan_schema = Arc::new(Schema::new(vec![
Field::new("label", DataType::Utf8, false),
Field::new(
"vector",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4),
false,
),
]));
let scan_provider =
Arc::new(HashKeyProvider::try_new(scan_schema, vec![], "label").expect("HashKeyProvider"));

// lookup_provider has "id".
let lookup_schema = exec_schema();
let lookup_provider =
Arc::new(HashKeyProvider::try_new(lookup_schema, vec![], "id").expect("HashKeyProvider"));

let reg = USearchRegistry::new();
let result = reg.add(
"test::vector",
make_populated_index(),
scan_provider,
lookup_provider,
"id",
MetricKind::L2sq,
ScalarKind::F32,
);
assert!(
result.is_err(),
"registration must fail when scan_provider lacks key column"
);
let msg = result.unwrap_err().to_string();
assert!(
msg.contains("scan provider"),
"error must mention scan provider: {msg}"
);
}

/// Qualified table, WHERE clause, ORDER BY UDF directly.
#[tokio::test]
async fn exec_qualified_where_order_by_udf() {
Expand Down
2 changes: 2 additions & 0 deletions tests/optimizer_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ async fn make_ctx(metric: MetricKind) -> SessionContext {
"items::vector",
make_index(metric),
provider.clone(),
provider.clone(),
"id",
metric,
ScalarKind::F32,
Expand Down Expand Up @@ -347,6 +348,7 @@ async fn make_ctx_qualified(metric: MetricKind) -> SessionContext {
"datafusion::public::items::vector",
make_index(metric),
provider.clone(),
provider.clone(),
"id",
metric,
ScalarKind::F32,
Expand Down
Loading