diff --git a/src/sqlite_provider.rs b/src/sqlite_provider.rs index b636093..2405b02 100644 --- a/src/sqlite_provider.rs +++ b/src/sqlite_provider.rs @@ -24,8 +24,14 @@ use async_trait::async_trait; use datafusion::catalog::{Session, TableProvider}; use datafusion::common::Result as DFResult; use datafusion::error::DataFusionError; +use datafusion::execution::{SendableRecordBatchStream, TaskContext}; use datafusion::logical_expr::{Expr, TableType}; -use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_expr::EquivalenceProperties; +use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, +}; use rusqlite::{Connection, types::Value as SqlValue}; use tokio::sync::Semaphore; @@ -311,10 +317,262 @@ impl TableProvider for SqliteLookupProvider { _filters: &[Expr], _limit: Option, ) -> DFResult> { - Err(DataFusionError::NotImplemented( - "SqliteLookupProvider does not support full table scans; use fetch_by_keys".into(), - )) + Ok(Arc::new(SqliteFullScanExec::new( + self.pool.clone(), + self.sem.clone(), + self.table_name.clone(), + self.schema.clone(), + ))) + } +} + +// ── Full-scan execution plan ────────────────────────────────────────────────── + +/// Batch size used when streaming rows from SQLite during a full table scan. +/// Larger values reduce round-trip overhead; smaller values reduce peak memory. +const SCAN_BATCH_SIZE: usize = 1024; + +/// Physical execution plan that streams all rows from a SQLite table in +/// [`SCAN_BATCH_SIZE`]-row batches. Used by the adaptive filtered path in +/// `USearchExec` to evaluate WHERE-clause predicates without loading the +/// entire table into memory at once. +#[derive(Debug)] +struct SqliteFullScanExec { + pool: Arc>>, + sem: Arc, + table_name: String, + schema: SchemaRef, + properties: PlanProperties, +} + +impl SqliteFullScanExec { + fn new( + pool: Arc>>, + sem: Arc, + table_name: String, + schema: SchemaRef, + ) -> Self { + let properties = PlanProperties::new( + EquivalenceProperties::new(schema.clone()), + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + ); + Self { + pool, + sem, + table_name, + schema, + properties, + } + } +} + +impl DisplayAs for SqliteFullScanExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "SqliteFullScanExec: table={}", self.table_name) + } +} + +impl ExecutionPlan for SqliteFullScanExec { + fn name(&self) -> &str { + "SqliteFullScanExec" + } + fn as_any(&self) -> &dyn Any { + self + } + fn properties(&self) -> &PlanProperties { + &self.properties + } + fn children(&self) -> Vec<&Arc> { + vec![] + } + fn with_new_children( + self: Arc, + children: Vec>, + ) -> DFResult> { + if children.is_empty() { + Ok(self) + } else { + Err(DataFusionError::Internal( + "SqliteFullScanExec is a leaf node and takes no children".into(), + )) + } } + + fn execute( + &self, + _partition: usize, + _ctx: Arc, + ) -> DFResult { + let pool = self.pool.clone(); + let sem = Arc::clone(&self.sem); + let table_name = self.table_name.clone(); + let schema = self.schema.clone(); + + // Bounded channel: backpressure limits how many batches are buffered + // ahead of the consumer, keeping peak memory to O(batch_size × 2). + let (tx, rx) = tokio::sync::mpsc::channel::>(2); + + let schema_task = schema.clone(); + tokio::spawn(async move { + // Acquire a semaphore permit so the scan counts against the + // same concurrency limit as fetch_by_keys. + let _permit = match sem.acquire_owned().await { + Ok(p) => p, + Err(e) => { + let _ = tx + .send(Err(DataFusionError::Execution(e.to_string()))) + .await; + return; + } + }; + + let conn = match pool.lock() { + Ok(mut g) => g.pop().ok_or_else(|| { + DataFusionError::Execution("SqliteFullScanExec: connection pool empty".into()) + }), + Err(e) => Err(DataFusionError::Execution(format!( + "connection pool mutex poisoned: {e}" + ))), + }; + let conn = match conn { + Ok(c) => c, + Err(e) => { + let _ = tx.send(Err(e)).await; + return; + } + }; + + let pool_c = pool.clone(); + let tx_c = tx.clone(); + if let Err(e) = tokio::task::spawn_blocking(move || { + let guard = ConnGuard::new(pool_c, conn); + let conn = guard.conn.as_ref().unwrap(); + + let col_list = schema_task + .fields() + .iter() + .map(|f| quote_ident(f.name())) + .collect::>() + .join(", "); + // No ORDER BY — the adaptive filter doesn't require ordering. + let sql = format!("SELECT {col_list} FROM {}", quote_ident(&table_name)); + + let mut stmt = match conn.prepare(&sql) { + Ok(s) => s, + Err(e) => { + let _ = tx_c.blocking_send(Err(DataFusionError::Execution(e.to_string()))); + return; + } + }; + + let mut rows = match stmt.query([]) { + Ok(r) => r, + Err(e) => { + let _ = tx_c.blocking_send(Err(DataFusionError::Execution(e.to_string()))); + return; + } + }; + + let n_cols = schema_task.fields().len(); + let mut col_bufs: Vec> = (0..n_cols) + .map(|_| Vec::with_capacity(SCAN_BATCH_SIZE)) + .collect(); + let mut rows_in_batch = 0usize; + + loop { + match rows.next() { + Ok(Some(row)) => { + let mut row_ok = true; + for (ci, buf) in col_bufs.iter_mut().enumerate() { + match row.get::<_, SqlValue>(ci) { + Ok(v) => buf.push(v), + Err(e) => { + let _ = tx_c.blocking_send(Err( + DataFusionError::Execution(e.to_string()), + )); + row_ok = false; + break; + } + } + } + if !row_ok { + // Error already sent on the channel — skip the + // final flush entirely to avoid sending Ok after Err. + return; + } + rows_in_batch += 1; + if rows_in_batch >= SCAN_BATCH_SIZE { + let drained: Vec> = col_bufs + .iter_mut() + .map(|b| { + std::mem::replace(b, Vec::with_capacity(SCAN_BATCH_SIZE)) + }) + .collect(); + rows_in_batch = 0; + match build_scan_batch(&schema_task, drained) { + Ok(batch) => { + if tx_c.blocking_send(Ok(batch)).is_err() { + return; // consumer dropped + } + } + Err(e) => { + let _ = tx_c.blocking_send(Err(e)); + return; + } + } + } + } + Ok(None) => break, + Err(e) => { + let _ = + tx_c.blocking_send(Err(DataFusionError::Execution(e.to_string()))); + return; + } + } + } + + // Flush the last partial batch. + if rows_in_batch > 0 { + match build_scan_batch(&schema_task, col_bufs) { + Ok(batch) => { + let _ = tx_c.blocking_send(Ok(batch)); + } + Err(e) => { + let _ = tx_c.blocking_send(Err(e)); + } + } + } + }) + .await + { + let _ = tx + .send(Err(DataFusionError::Execution(format!( + "scan task panicked: {e}" + )))) + .await; + } + }); + + // Convert the channel receiver into a RecordBatch stream. + let stream = futures::stream::unfold(rx, |mut rx| async move { + rx.recv().await.map(|item| (item, rx)) + }); + Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) + } +} + +/// Build a [`RecordBatch`] from column buffers of [`SqlValue`]s. +fn build_scan_batch(schema: &SchemaRef, col_bufs: Vec>) -> DFResult { + let arrays: Vec = schema + .fields() + .iter() + .zip(col_bufs) + .map(|(field, values)| sql_values_to_arrow(field.data_type(), values)) + .collect::>()?; + RecordBatch::try_new(schema.clone(), arrays) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None)) } // ── Build helpers ───────────────────────────────────────────────────────────── diff --git a/tests/sqlite_provider_test.rs b/tests/sqlite_provider_test.rs index a08121d..c7dfb53 100644 --- a/tests/sqlite_provider_test.rs +++ b/tests/sqlite_provider_test.rs @@ -2,7 +2,7 @@ use std::sync::Arc; -use arrow_array::{RecordBatch, StringArray, UInt64Array}; +use arrow_array::{Array, RecordBatch, StringArray, UInt64Array}; use arrow_schema::{DataType, Field, Schema}; use datafusion::catalog::TableProvider; use datafusion::prelude::SessionContext; @@ -140,23 +140,43 @@ async fn test_empty_key_slice() { assert!(batches.is_empty()); } -/// Regression test for the silent-empty-scan bug: -/// scan() used to return an empty MemTable, producing zero rows with no error. -/// It must now return NotImplemented so callers get a clear failure. +/// scan() returns a streaming ExecutionPlan that yields all rows in batches. #[tokio::test] -async fn test_scan_returns_not_implemented() { +async fn test_scan_streams_all_rows() { + use datafusion::execution::TaskContext; + use futures::StreamExt; + let dir = tempdir().unwrap(); let provider = make_provider(&dir); let ctx = SessionContext::new(); let state = ctx.state(); - let result = provider.scan(&state, None, &[], None).await; - assert!(result.is_err()); - let err = result.unwrap_err().to_string(); - assert!( - err.contains("not support full table scans"), - "expected NotImplemented error, got: {err}" - ); + let plan = provider.scan(&state, None, &[], None).await.unwrap(); + + let task_ctx = Arc::new(TaskContext::default()); + let mut stream = plan.execute(0, task_ctx).unwrap(); + + let mut total_rows = 0usize; + let mut all_names: Vec = Vec::new(); + while let Some(batch) = stream.next().await { + let batch = batch.unwrap(); + total_rows += batch.num_rows(); + + let names_col = batch + .column_by_name("name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..names_col.len() { + all_names.push(names_col.value(i).to_string()); + } + } + + assert_eq!(total_rows, 3); + assert!(all_names.contains(&"alice".to_string())); + assert!(all_names.contains(&"bob".to_string())); + assert!(all_names.contains(&"carol".to_string())); } /// Regression test for the SQL injection fix via quote_ident: