Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# CLAUDE.md

## Project Overview

A DataFusion extension crate for vector similarity search (ANN) using USearch HNSW indices. Provides optimizer rules, UDFs (l2_distance, cosine_distance, negative_dot), and pluggable lookup providers (Parquet, SQLite) for retrieving non-embedding columns by key.

## Git Commits

Use conventional format: `<type>(<scope>): <subject>` where type = feat|fix|docs|style|refactor|test|chore|perf. Subject: 50 chars max, imperative mood ("add" not "added"), no period.

## Pre-Push Checklist

Always run these before pushing:

1. `cargo fmt --check` — fix any formatting issues
2. `cargo clippy -- -D warnings` — no warnings allowed
3. `cargo test --features sqlite-provider` — all tests must pass (sqlite-provider feature needed for SQLite provider tests)
25 changes: 10 additions & 15 deletions src/sqlite_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ use datafusion::physical_plan::{
use rusqlite::{Connection, types::Value as SqlValue};
use tokio::sync::Semaphore;

use crate::keys::{DatasetLayout, pack_key};
use crate::lookup::PointLookupProvider;

// ── Provider ──────────────────────────────────────────────────────────────────
Expand Down Expand Up @@ -106,15 +105,15 @@ impl SqliteLookupProvider {
/// parquet files on first run. Opens a pool of `pool_size` read
/// connections (WAL allows N concurrent readers).
///
/// `local_parquet_files`, `layout`, `schema`, and `parquet_col_indices`
/// are only used if the table does not yet exist.
#[allow(clippy::too_many_arguments)]
/// `local_parquet_files`, `schema`, and `parquet_col_indices`
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: callers pairing this provider with a USearch index now need to use monotonic keys (0, 1, 2, …) when calling index.add() — the previous implicit contract was pack_key(file_idx, rg_idx, lo). Worth a one-liner here so the contract is visible at the call site:

Suggested change
/// `local_parquet_files`, `schema`, and `parquet_col_indices`
/// are only used if the table does not yet exist. Row keys are assigned
/// as monotonic integers (0, 1, 2, …) in file-iteration order; any
/// USearch index used alongside this provider must use the same scheme.

/// are only used if the table does not yet exist. Row keys are assigned
/// as monotonic integers (0, 1, 2, …) in file-iteration order; any
/// USearch index used alongside this provider must use the same scheme.
pub fn open_or_build(
db_path: &str,
table_name: &str,
pool_size: usize,
local_parquet_files: &[String],
layout: &DatasetLayout,
schema: SchemaRef,
parquet_col_indices: &[usize],
) -> DFResult<Self> {
Expand Down Expand Up @@ -156,7 +155,6 @@ impl SqliteLookupProvider {
&conn,
table_name,
local_parquet_files,
layout,
&schema,
parquet_col_indices,
)?;
Expand Down Expand Up @@ -581,7 +579,6 @@ fn build_table(
conn: &Connection,
table_name: &str,
parquet_files: &[String],
layout: &DatasetLayout,
schema: &SchemaRef,
parquet_col_indices: &[usize],
) -> DFResult<()> {
Expand Down Expand Up @@ -628,7 +625,9 @@ fn build_table(
.prepare(&insert_sql)
.map_err(|e| DataFusionError::Execution(e.to_string()))?;

for (file_idx, file_path) in parquet_files.iter().enumerate() {
let mut global_row_idx: u64 = 0;

for file_path in parquet_files {
let f = std::fs::File::open(file_path)
.map_err(|e| DataFusionError::Execution(format!("open {file_path}: {e}")))?;
let builder = parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder::try_new(f)
Expand All @@ -637,20 +636,17 @@ fn build_table(
.with_batch_size(2048)
.build()
.map_err(|e| DataFusionError::Execution(e.to_string()))?;
let mut file_row: u64 = 0;

for batch_result in reader {
let batch = batch_result.map_err(|e| DataFusionError::Execution(e.to_string()))?;
let n = batch.num_rows();

for row_i in 0..n {
let r = file_row + row_i as u64;
let rg = layout.rg_cum_rows[file_idx].partition_point(|&s| s <= r) - 1;
let lo = (r - layout.rg_cum_rows[file_idx][rg]) as usize;
let packed_key = pack_key(file_idx, rg, lo);
let key = global_row_idx;
global_row_idx += 1;

let mut params: Vec<SqlValue> = Vec::with_capacity(schema.fields().len());
params.push(SqlValue::Integer(packed_key as i64));
params.push(SqlValue::Integer(key as i64));

for &ci in parquet_col_indices {
params.push(arrow_cell_to_sql(batch.column(ci), row_i));
Expand All @@ -659,7 +655,6 @@ fn build_table(
stmt.execute(rusqlite::params_from_iter(params.iter()))
.map_err(|e| DataFusionError::Execution(e.to_string()))?;
}
file_row += n as u64;
}
}
}
Expand Down
37 changes: 6 additions & 31 deletions tests/sqlite_provider_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ use arrow_array::{Array, RecordBatch, StringArray, UInt64Array};
use arrow_schema::{DataType, Field, Schema};
use datafusion::catalog::TableProvider;
use datafusion::prelude::SessionContext;
use datafusion_vector_search_ext::{
DatasetLayout, PointLookupProvider, SqliteLookupProvider, pack_key,
};
use datafusion_vector_search_ext::{PointLookupProvider, SqliteLookupProvider};
use parquet::arrow::ArrowWriter;
use tempfile::tempdir;

Expand Down Expand Up @@ -38,13 +36,6 @@ fn make_provider(dir: &tempfile::TempDir) -> SqliteLookupProvider {
writer.write(&batch).unwrap();
writer.close().unwrap();

// Build a minimal DatasetLayout for 1 file with 1 row group of 3 rows.
let layout = DatasetLayout {
file_keys: vec!["parquet/test.parquet".to_string()],
file_cum_rows: vec![0, 3],
rg_cum_rows: vec![vec![0, 3]],
};

let db_path = dir.path().join("test.db");
let parquet_files = vec![parquet_path.to_str().unwrap().to_string()];

Expand All @@ -53,7 +44,6 @@ fn make_provider(dir: &tempfile::TempDir) -> SqliteLookupProvider {
"models",
4,
&parquet_files,
&layout,
provider_schema,
&[0], // parquet col 0 (name) → provider col 1
)
Expand All @@ -65,10 +55,8 @@ async fn test_fetch_existing_keys() {
let dir = tempdir().unwrap();
let provider = make_provider(&dir);

let key0 = pack_key(0, 0, 0);
let key2 = pack_key(0, 0, 2);
let batches = provider
.fetch_by_keys(&[key0, key2], "row_idx", None)
.fetch_by_keys(&[0, 2], "row_idx", None)
.await
.unwrap();

Expand Down Expand Up @@ -98,10 +86,9 @@ async fn test_projection() {
let dir = tempdir().unwrap();
let provider = make_provider(&dir);

let key1 = pack_key(0, 0, 1);
// Project only row_idx (index 0).
let batches = provider
.fetch_by_keys(&[key1], "row_idx", Some(&[0]))
.fetch_by_keys(&[1], "row_idx", Some(&[0]))
.await
.unwrap();

Expand All @@ -114,17 +101,16 @@ async fn test_projection() {
.as_any()
.downcast_ref::<UInt64Array>()
.unwrap();
assert_eq!(row_idx_col.value(0), key1);
assert_eq!(row_idx_col.value(0), 1);
}

#[tokio::test]
async fn test_missing_keys_return_empty() {
let dir = tempdir().unwrap();
let provider = make_provider(&dir);

let missing = pack_key(0, 0, 99); // offset 99 doesn't exist
let batches = provider
.fetch_by_keys(&[missing], "row_idx", None)
.fetch_by_keys(&[99], "row_idx", None)
.await
.unwrap();
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
Expand Down Expand Up @@ -203,29 +189,18 @@ async fn test_table_name_with_spaces() {
writer.write(&batch).unwrap();
writer.close().unwrap();

let layout = DatasetLayout {
file_keys: vec!["parquet/test.parquet".to_string()],
file_cum_rows: vec![0, 1],
rg_cum_rows: vec![vec![0, 1]],
};

let db_path = dir.path().join("test.db");
// Table name with spaces — previously this would have produced a SQL syntax error.
let provider = SqliteLookupProvider::open_or_build(
db_path.to_str().unwrap(),
"my models",
2,
&[parquet_path.to_str().unwrap().to_string()],
&layout,
provider_schema,
&[0],
)
.unwrap();

let key0 = pack_key(0, 0, 0);
let batches = provider
.fetch_by_keys(&[key0], "row_idx", None)
.await
.unwrap();
let batches = provider.fetch_by_keys(&[0], "row_idx", None).await.unwrap();
assert_eq!(batches.iter().map(|b| b.num_rows()).sum::<usize>(), 1);
}
Loading