Skip to content

Commit 42abb40

Browse files
authored
fix(sqlite-provider): use caller-provided key column name (#9)
* fix(sqlite-provider): use caller-provided key column name

  The SqliteLookupProvider previously hardcoded "row_idx" as the key column name in CREATE TABLE and WHERE clauses. This caused errors when callers used a different key column name (e.g. "_key"). The provider now derives the key column name from the first field in the provided schema, making it work with any key column name.

* test(sqlite-provider): add test for custom key column name

  Exercises SqliteLookupProvider with "_key" as the key column (the scenario used by runtimedb), verifying that both fetch_by_keys and projection work with non-default key column names.
1 parent 3f303d3 commit 42abb40

2 files changed

Lines changed: 94 additions & 6 deletions

File tree

src/sqlite_provider.rs

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22
//
33
// Stores all non-embedding columns in a local SQLite database (bundled libsqlite3).
44
// Scalar columns map to INTEGER/TEXT/REAL; list columns are serialised as JSON TEXT.
5-
// Lookups use `WHERE row_idx IN (?, ...)` against the INTEGER PRIMARY KEY B-tree.
5+
// Lookups use `WHERE <key_col> IN (?, ...)` against the INTEGER PRIMARY KEY B-tree.
66
//
7-
// Schema: row_idx INTEGER PRIMARY KEY, <col> TEXT/INTEGER/REAL, ...
7+
// Schema: <key_col> INTEGER PRIMARY KEY, <col> TEXT/INTEGER/REAL, ...
8+
//
9+
// The key column name is caller-provided (e.g. "_key") and must match the first
10+
// field in the schema passed to `open_or_build`.
811
//
912
// Persistence: the database is written once to the given path and reused on
1013
// subsequent runs. The first build reads all parquet files and inserts rows
@@ -42,6 +45,7 @@ use crate::lookup::PointLookupProvider;
4245
pub struct SqliteLookupProvider {
4346
schema: SchemaRef,
4447
table_name: String,
48+
key_col: String,
4549
pool: Arc<Mutex<Vec<Connection>>>,
4650
sem: Arc<Semaphore>,
4751
}
@@ -117,6 +121,8 @@ impl SqliteLookupProvider {
117121
schema: SchemaRef,
118122
parquet_col_indices: &[usize],
119123
) -> DFResult<Self> {
124+
// The first field in the schema is the key column (INTEGER PRIMARY KEY).
125+
let key_col = schema.field(0).name().clone();
120126
if pool_size == 0 {
121127
return Err(DataFusionError::Execution(
122128
"pool_size must be at least 1".into(),
@@ -167,6 +173,7 @@ impl SqliteLookupProvider {
167173
Ok(Self {
168174
schema,
169175
table_name: table_name.to_string(),
176+
key_col,
170177
pool: Arc::new(Mutex::new(conns)),
171178
sem: Arc::new(Semaphore::new(pool_size)),
172179
})
@@ -202,6 +209,7 @@ impl PointLookupProvider for SqliteLookupProvider {
202209
let keys_vec = keys.to_vec();
203210
let pool = self.pool.clone();
204211
let table_name = self.table_name.clone();
212+
let key_col = self.key_col.clone();
205213

206214
// Acquire a semaphore permit to bound concurrency to the pool size,
207215
// then run the synchronous SQLite query on a blocking thread.
@@ -227,6 +235,7 @@ impl PointLookupProvider for SqliteLookupProvider {
227235
&keys_vec,
228236
&out_schema,
229237
&table_name,
238+
&key_col,
230239
);
231240
drop(guard); // explicit but not required — Drop handles it
232241
res
@@ -243,6 +252,7 @@ fn execute_query_sync(
243252
keys: &[u64],
244253
out_schema: &SchemaRef,
245254
table_name: &str,
255+
key_col: &str,
246256
) -> DFResult<Vec<RecordBatch>> {
247257
let placeholders = keys.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
248258
// Select only the columns in out_schema (already projection-applied by the
@@ -253,8 +263,9 @@ fn execute_query_sync(
253263
.map(|f| quote_ident(f.name()))
254264
.collect::<Vec<_>>()
255265
.join(", ");
266+
let qk = quote_ident(key_col);
256267
let sql = format!(
257-
"SELECT {col_list} FROM {tn} WHERE row_idx IN ({placeholders}) ORDER BY row_idx",
268+
"SELECT {col_list} FROM {tn} WHERE {qk} IN ({placeholders}) ORDER BY {qk}",
258269
tn = quote_ident(table_name)
259270
);
260271

@@ -586,14 +597,16 @@ fn build_table(
586597
schema: &SchemaRef,
587598
parquet_col_indices: &[usize],
588599
) -> DFResult<()> {
600+
// The first field is the key column (INTEGER PRIMARY KEY).
601+
let key_col_name = schema.field(0).name();
589602
let col_defs = schema
590603
.fields()
591604
.iter()
592605
.map(|f| {
593-
let sql_type = arrow_type_to_sql(f.data_type());
594-
if f.name() == "row_idx" {
595-
"row_idx INTEGER PRIMARY KEY".to_string()
606+
if f.name() == key_col_name {
607+
format!("{} INTEGER PRIMARY KEY", quote_ident(f.name()))
596608
} else {
609+
let sql_type = arrow_type_to_sql(f.data_type());
597610
format!("{} {}", quote_ident(f.name()), sql_type)
598611
}
599612
})

tests/sqlite_provider_test.rs

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,3 +204,78 @@ async fn test_table_name_with_spaces() {
204204
let batches = provider.fetch_by_keys(&[0], "row_idx", None).await.unwrap();
205205
assert_eq!(batches.iter().map(|b| b.num_rows()).sum::<usize>(), 1);
206206
}
207+
208+
/// Verify that a non-default key column name (e.g. "_key") works correctly.
209+
/// This is the scenario used by runtimedb where Parquet files have a `_key` column.
210+
#[tokio::test]
211+
async fn test_custom_key_column_name() {
212+
let dir = tempdir().unwrap();
213+
214+
let parquet_schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)]));
215+
216+
// Provider schema uses "_key" instead of the default "row_idx".
217+
let provider_schema = Arc::new(Schema::new(vec![
218+
Field::new("_key", DataType::UInt64, false),
219+
Field::new("name", DataType::Utf8, true),
220+
]));
221+
222+
let batch = RecordBatch::try_new(
223+
parquet_schema.clone(),
224+
vec![Arc::new(StringArray::from(vec![
225+
Some("alice"),
226+
Some("bob"),
227+
Some("carol"),
228+
]))],
229+
)
230+
.unwrap();
231+
232+
let parquet_path = dir.path().join("test.parquet");
233+
let file = std::fs::File::create(&parquet_path).unwrap();
234+
let mut writer = ArrowWriter::try_new(file, parquet_schema, None).unwrap();
235+
writer.write(&batch).unwrap();
236+
writer.close().unwrap();
237+
238+
let db_path = dir.path().join("test_key.db");
239+
let provider = SqliteLookupProvider::open_or_build(
240+
db_path.to_str().unwrap(),
241+
"vectors",
242+
2,
243+
&[parquet_path.to_str().unwrap().to_string()],
244+
provider_schema,
245+
&[0],
246+
)
247+
.unwrap();
248+
249+
// fetch_by_keys should work with the custom key column
250+
let batches = provider.fetch_by_keys(&[0, 2], "_key", None).await.unwrap();
251+
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
252+
assert_eq!(total_rows, 2);
253+
254+
let names: Vec<String> = batches
255+
.iter()
256+
.flat_map(|b| {
257+
b.column_by_name("name")
258+
.unwrap()
259+
.as_any()
260+
.downcast_ref::<StringArray>()
261+
.unwrap()
262+
.iter()
263+
.map(|v| v.unwrap().to_string())
264+
.collect::<Vec<_>>()
265+
})
266+
.collect();
267+
assert_eq!(names, vec!["alice", "carol"]);
268+
269+
// projection to only the key column should also work
270+
let batches = provider
271+
.fetch_by_keys(&[1], "_key", Some(&[0]))
272+
.await
273+
.unwrap();
274+
assert_eq!(batches[0].schema().field(0).name(), "_key");
275+
let key_col = batches[0]
276+
.column(0)
277+
.as_any()
278+
.downcast_ref::<UInt64Array>()
279+
.unwrap();
280+
assert_eq!(key_col.value(0), 1);
281+
}

0 commit comments

Comments
 (0)