Skip to content

Commit dba4786

Browse files
authored
feat(sqlite-provider): implement streaming full-table scan for adaptive filtering (#4)
* feat(sqlite-provider): implement streaming full-table scan for adaptive filtering

  SqliteLookupProvider.scan() previously returned NotImplemented, causing the adaptive filtered path in USearchExec to fail when a WHERE clause was combined with vector search.

  Add SqliteFullScanExec: a leaf ExecutionPlan that streams all rows from the SQLite table in 1024-row batches via a bounded tokio mpsc channel. The blocking SQLite cursor runs in spawn_blocking; the async consumer processes each batch through evaluate_filters() and drops it immediately, keeping peak memory at O(batch_size) rather than O(total_rows).

  The semaphore and connection pool are shared with fetch_by_keys so concurrent scans and key lookups stay within the configured pool size.

* style: fix clippy and fmt warnings

* fix: propagate final-batch and spawn_blocking errors

  Address PR review comments:
  - Send build_scan_batch errors for the last partial batch instead of silently dropping them (truncated scan looked like success)
  - Propagate spawn_blocking JoinError so panics surface as stream errors instead of a clean end-of-stream

* fix: propagate poisoned mutex error instead of masking as pool empty

* fix: truncate column buffers on mid-row read error

  Prevents mismatched buffer lengths from causing a spurious second error during the final-batch flush, which would mask the original column read failure.

* test: replace NotImplemented assertion with streaming scan test

  Update test_scan_returns_not_implemented → test_scan_streams_all_rows: exercises the new SqliteFullScanExec end-to-end, verifying row count and data correctness across streamed batches.

* fix: return immediately on column-read error instead of flushing

  Avoids sending Ok(batch) on the channel after an Err was already sent for a mid-row column read failure.
1 parent 285e87c commit dba4786

2 files changed

Lines changed: 294 additions & 16 deletions

File tree

src/sqlite_provider.rs

Lines changed: 262 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,14 @@ use async_trait::async_trait;
2424
use datafusion::catalog::{Session, TableProvider};
2525
use datafusion::common::Result as DFResult;
2626
use datafusion::error::DataFusionError;
27+
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
2728
use datafusion::logical_expr::{Expr, TableType};
28-
use datafusion::physical_plan::ExecutionPlan;
29+
use datafusion::physical_expr::EquivalenceProperties;
30+
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
31+
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
32+
use datafusion::physical_plan::{
33+
DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
34+
};
2935
use rusqlite::{Connection, types::Value as SqlValue};
3036
use tokio::sync::Semaphore;
3137

@@ -311,10 +317,262 @@ impl TableProvider for SqliteLookupProvider {
311317
_filters: &[Expr],
312318
_limit: Option<usize>,
313319
) -> DFResult<Arc<dyn ExecutionPlan>> {
314-
Err(DataFusionError::NotImplemented(
315-
"SqliteLookupProvider does not support full table scans; use fetch_by_keys".into(),
316-
))
320+
Ok(Arc::new(SqliteFullScanExec::new(
321+
self.pool.clone(),
322+
self.sem.clone(),
323+
self.table_name.clone(),
324+
self.schema.clone(),
325+
)))
326+
}
327+
}
328+
329+
// ── Full-scan execution plan ──────────────────────────────────────────────────
330+
331+
/// Batch size used when streaming rows from SQLite during a full table scan.
332+
/// Larger values reduce round-trip overhead; smaller values reduce peak memory.
333+
const SCAN_BATCH_SIZE: usize = 1024;
334+
335+
/// Physical execution plan that streams all rows from a SQLite table in
336+
/// [`SCAN_BATCH_SIZE`]-row batches. Used by the adaptive filtered path in
337+
/// `USearchExec` to evaluate WHERE-clause predicates without loading the
338+
/// entire table into memory at once.
339+
#[derive(Debug)]
340+
struct SqliteFullScanExec {
341+
pool: Arc<Mutex<Vec<Connection>>>,
342+
sem: Arc<Semaphore>,
343+
table_name: String,
344+
schema: SchemaRef,
345+
properties: PlanProperties,
346+
}
347+
348+
impl SqliteFullScanExec {
349+
fn new(
350+
pool: Arc<Mutex<Vec<Connection>>>,
351+
sem: Arc<Semaphore>,
352+
table_name: String,
353+
schema: SchemaRef,
354+
) -> Self {
355+
let properties = PlanProperties::new(
356+
EquivalenceProperties::new(schema.clone()),
357+
Partitioning::UnknownPartitioning(1),
358+
EmissionType::Incremental,
359+
Boundedness::Bounded,
360+
);
361+
Self {
362+
pool,
363+
sem,
364+
table_name,
365+
schema,
366+
properties,
367+
}
368+
}
369+
}
370+
371+
impl DisplayAs for SqliteFullScanExec {
372+
fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
373+
write!(f, "SqliteFullScanExec: table={}", self.table_name)
374+
}
375+
}
376+
377+
impl ExecutionPlan for SqliteFullScanExec {
378+
fn name(&self) -> &str {
379+
"SqliteFullScanExec"
380+
}
381+
fn as_any(&self) -> &dyn Any {
382+
self
383+
}
384+
fn properties(&self) -> &PlanProperties {
385+
&self.properties
386+
}
387+
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
388+
vec![]
389+
}
390+
fn with_new_children(
391+
self: Arc<Self>,
392+
children: Vec<Arc<dyn ExecutionPlan>>,
393+
) -> DFResult<Arc<dyn ExecutionPlan>> {
394+
if children.is_empty() {
395+
Ok(self)
396+
} else {
397+
Err(DataFusionError::Internal(
398+
"SqliteFullScanExec is a leaf node and takes no children".into(),
399+
))
400+
}
317401
}
402+
403+
fn execute(
404+
&self,
405+
_partition: usize,
406+
_ctx: Arc<TaskContext>,
407+
) -> DFResult<SendableRecordBatchStream> {
408+
let pool = self.pool.clone();
409+
let sem = Arc::clone(&self.sem);
410+
let table_name = self.table_name.clone();
411+
let schema = self.schema.clone();
412+
413+
// Bounded channel: backpressure limits how many batches are buffered
414+
// ahead of the consumer, keeping peak memory to O(batch_size × 2).
415+
let (tx, rx) = tokio::sync::mpsc::channel::<DFResult<RecordBatch>>(2);
416+
417+
let schema_task = schema.clone();
418+
tokio::spawn(async move {
419+
// Acquire a semaphore permit so the scan counts against the
420+
// same concurrency limit as fetch_by_keys.
421+
let _permit = match sem.acquire_owned().await {
422+
Ok(p) => p,
423+
Err(e) => {
424+
let _ = tx
425+
.send(Err(DataFusionError::Execution(e.to_string())))
426+
.await;
427+
return;
428+
}
429+
};
430+
431+
let conn = match pool.lock() {
432+
Ok(mut g) => g.pop().ok_or_else(|| {
433+
DataFusionError::Execution("SqliteFullScanExec: connection pool empty".into())
434+
}),
435+
Err(e) => Err(DataFusionError::Execution(format!(
436+
"connection pool mutex poisoned: {e}"
437+
))),
438+
};
439+
let conn = match conn {
440+
Ok(c) => c,
441+
Err(e) => {
442+
let _ = tx.send(Err(e)).await;
443+
return;
444+
}
445+
};
446+
447+
let pool_c = pool.clone();
448+
let tx_c = tx.clone();
449+
if let Err(e) = tokio::task::spawn_blocking(move || {
450+
let guard = ConnGuard::new(pool_c, conn);
451+
let conn = guard.conn.as_ref().unwrap();
452+
453+
let col_list = schema_task
454+
.fields()
455+
.iter()
456+
.map(|f| quote_ident(f.name()))
457+
.collect::<Vec<_>>()
458+
.join(", ");
459+
// No ORDER BY — the adaptive filter doesn't require ordering.
460+
let sql = format!("SELECT {col_list} FROM {}", quote_ident(&table_name));
461+
462+
let mut stmt = match conn.prepare(&sql) {
463+
Ok(s) => s,
464+
Err(e) => {
465+
let _ = tx_c.blocking_send(Err(DataFusionError::Execution(e.to_string())));
466+
return;
467+
}
468+
};
469+
470+
let mut rows = match stmt.query([]) {
471+
Ok(r) => r,
472+
Err(e) => {
473+
let _ = tx_c.blocking_send(Err(DataFusionError::Execution(e.to_string())));
474+
return;
475+
}
476+
};
477+
478+
let n_cols = schema_task.fields().len();
479+
let mut col_bufs: Vec<Vec<SqlValue>> = (0..n_cols)
480+
.map(|_| Vec::with_capacity(SCAN_BATCH_SIZE))
481+
.collect();
482+
let mut rows_in_batch = 0usize;
483+
484+
loop {
485+
match rows.next() {
486+
Ok(Some(row)) => {
487+
let mut row_ok = true;
488+
for (ci, buf) in col_bufs.iter_mut().enumerate() {
489+
match row.get::<_, SqlValue>(ci) {
490+
Ok(v) => buf.push(v),
491+
Err(e) => {
492+
let _ = tx_c.blocking_send(Err(
493+
DataFusionError::Execution(e.to_string()),
494+
));
495+
row_ok = false;
496+
break;
497+
}
498+
}
499+
}
500+
if !row_ok {
501+
// Error already sent on the channel — skip the
502+
// final flush entirely to avoid sending Ok after Err.
503+
return;
504+
}
505+
rows_in_batch += 1;
506+
if rows_in_batch >= SCAN_BATCH_SIZE {
507+
let drained: Vec<Vec<SqlValue>> = col_bufs
508+
.iter_mut()
509+
.map(|b| {
510+
std::mem::replace(b, Vec::with_capacity(SCAN_BATCH_SIZE))
511+
})
512+
.collect();
513+
rows_in_batch = 0;
514+
match build_scan_batch(&schema_task, drained) {
515+
Ok(batch) => {
516+
if tx_c.blocking_send(Ok(batch)).is_err() {
517+
return; // consumer dropped
518+
}
519+
}
520+
Err(e) => {
521+
let _ = tx_c.blocking_send(Err(e));
522+
return;
523+
}
524+
}
525+
}
526+
}
527+
Ok(None) => break,
528+
Err(e) => {
529+
let _ =
530+
tx_c.blocking_send(Err(DataFusionError::Execution(e.to_string())));
531+
return;
532+
}
533+
}
534+
}
535+
536+
// Flush the last partial batch.
537+
if rows_in_batch > 0 {
538+
match build_scan_batch(&schema_task, col_bufs) {
539+
Ok(batch) => {
540+
let _ = tx_c.blocking_send(Ok(batch));
541+
}
542+
Err(e) => {
543+
let _ = tx_c.blocking_send(Err(e));
544+
}
545+
}
546+
}
547+
})
548+
.await
549+
{
550+
let _ = tx
551+
.send(Err(DataFusionError::Execution(format!(
552+
"scan task panicked: {e}"
553+
))))
554+
.await;
555+
}
556+
});
557+
558+
// Convert the channel receiver into a RecordBatch stream.
559+
let stream = futures::stream::unfold(rx, |mut rx| async move {
560+
rx.recv().await.map(|item| (item, rx))
561+
});
562+
Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
563+
}
564+
}
565+
566+
/// Build a [`RecordBatch`] from column buffers of [`SqlValue`]s.
567+
fn build_scan_batch(schema: &SchemaRef, col_bufs: Vec<Vec<SqlValue>>) -> DFResult<RecordBatch> {
568+
let arrays: Vec<ArrayRef> = schema
569+
.fields()
570+
.iter()
571+
.zip(col_bufs)
572+
.map(|(field, values)| sql_values_to_arrow(field.data_type(), values))
573+
.collect::<DFResult<_>>()?;
574+
RecordBatch::try_new(schema.clone(), arrays)
575+
.map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
318576
}
319577

320578
// ── Build helpers ─────────────────────────────────────────────────────────────

tests/sqlite_provider_test.rs

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
use std::sync::Arc;
44

5-
use arrow_array::{RecordBatch, StringArray, UInt64Array};
5+
use arrow_array::{Array, RecordBatch, StringArray, UInt64Array};
66
use arrow_schema::{DataType, Field, Schema};
77
use datafusion::catalog::TableProvider;
88
use datafusion::prelude::SessionContext;
@@ -140,23 +140,43 @@ async fn test_empty_key_slice() {
140140
assert!(batches.is_empty());
141141
}
142142

143-
/// Regression test for the silent-empty-scan bug:
144-
/// scan() used to return an empty MemTable, producing zero rows with no error.
145-
/// It must now return NotImplemented so callers get a clear failure.
143+
/// scan() returns a streaming ExecutionPlan that yields all rows in batches.
146144
#[tokio::test]
147-
async fn test_scan_returns_not_implemented() {
145+
async fn test_scan_streams_all_rows() {
146+
use datafusion::execution::TaskContext;
147+
use futures::StreamExt;
148+
148149
let dir = tempdir().unwrap();
149150
let provider = make_provider(&dir);
150151

151152
let ctx = SessionContext::new();
152153
let state = ctx.state();
153-
let result = provider.scan(&state, None, &[], None).await;
154-
assert!(result.is_err());
155-
let err = result.unwrap_err().to_string();
156-
assert!(
157-
err.contains("not support full table scans"),
158-
"expected NotImplemented error, got: {err}"
159-
);
154+
let plan = provider.scan(&state, None, &[], None).await.unwrap();
155+
156+
let task_ctx = Arc::new(TaskContext::default());
157+
let mut stream = plan.execute(0, task_ctx).unwrap();
158+
159+
let mut total_rows = 0usize;
160+
let mut all_names: Vec<String> = Vec::new();
161+
while let Some(batch) = stream.next().await {
162+
let batch = batch.unwrap();
163+
total_rows += batch.num_rows();
164+
165+
let names_col = batch
166+
.column_by_name("name")
167+
.unwrap()
168+
.as_any()
169+
.downcast_ref::<StringArray>()
170+
.unwrap();
171+
for i in 0..names_col.len() {
172+
all_names.push(names_col.value(i).to_string());
173+
}
174+
}
175+
176+
assert_eq!(total_rows, 3);
177+
assert!(all_names.contains(&"alice".to_string()));
178+
assert!(all_names.contains(&"bob".to_string()));
179+
assert!(all_names.contains(&"carol".to_string()));
160180
}
161181

162182
/// Regression test for the SQL injection fix via quote_ident:

0 commit comments

Comments
 (0)