-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbase.py
More file actions
648 lines (566 loc) · 21.9 KB
/
Copy pathbase.py
File metadata and controls
648 lines (566 loc) · 21.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
"""High-level client for vector search on SQLite using the sqlite-vec extension.
This module provides `SQLiteVecClient`, a thin wrapper around `sqlite3` and
`sqlite-vec` to store texts, JSON metadata, and float32 embeddings, and to run
similarity search through a virtual vector table.
"""
from __future__ import annotations
import json
import sqlite3
from collections.abc import Generator
from contextlib import contextmanager
from types import TracebackType
from typing import Any, Literal
import sqlite_vec
from .exceptions import ConnectionError as VecConnectionError
from .exceptions import TableNotFoundError
from .logger import get_logger
from .types import Embeddings, Metadata, Result, Rowids, SimilaritySearchResult, Text
from .utils import deserialize_f32, serialize_f32
from .validation import (
validate_dimension,
validate_embeddings_match,
validate_limit,
validate_offset,
validate_table_name,
validate_top_k,
)
# Module-level logger obtained from the package's `get_logger` helper; the
# SQLiteVecClient methods below use it for debug/info/error messages.
logger = get_logger()
class SQLiteVecClient:
    """Manage a text+embedding table and its sqlite-vec index.

    The client maintains two tables:

    - ``{table}``: base table with columns ``text``, ``metadata`` (stored as
      ``json.dumps`` output) and ``text_embedding`` (serialized with
      ``serialize_f32``).
    - ``{table}_vec``: ``vec0`` virtual table mirroring the embeddings for
      similarity search. Triggers created by :meth:`create_table` keep it in
      sync on insert, on update of ``text_embedding``, and on delete.

    It exposes CRUD helpers and :meth:`similarity_search` over embeddings.
    Write methods commit immediately, except inside :meth:`transaction`, where
    the commit is deferred until the ``with`` block exits successfully.
    """

    # Distance metrics accepted by create_table(). The value is interpolated
    # into the vec0 DDL, so it is checked against this fixed allow-list.
    _DISTANCE_METRICS = ("L1", "L2", "cosine")

    @staticmethod
    def create_connection(db_path: str) -> sqlite3.Connection:
        """Create a SQLite connection with the sqlite-vec extension loaded.

        The connection returns ``sqlite3.Row`` rows and is configured with WAL
        journaling, ``synchronous=NORMAL``, a ~64 MB page cache and in-memory
        temp storage.

        Args:
            db_path: Path to SQLite database file

        Returns:
            SQLite connection with sqlite-vec loaded

        Raises:
            VecConnectionError: If connection, extension loading, or pragma
                setup fails. A partially initialised connection is closed
                before the error is raised so it does not leak.
        """
        connection: sqlite3.Connection | None = None
        try:
            logger.debug(f"Connecting to database: {db_path}")
            connection = sqlite3.connect(db_path)
            connection.row_factory = sqlite3.Row
            connection.enable_load_extension(True)
            sqlite_vec.load(connection)
            # Re-disable extension loading once sqlite-vec is in place.
            connection.enable_load_extension(False)
            # Performance optimizations
            connection.execute("PRAGMA journal_mode=WAL")
            connection.execute("PRAGMA synchronous=NORMAL")
            connection.execute("PRAGMA cache_size=-64000")  # 64MB cache
            connection.execute("PRAGMA temp_store=MEMORY")
            logger.info(f"Successfully connected to database: {db_path}")
            return connection
        except sqlite3.Error as e:
            if connection is not None:
                connection.close()
            logger.error(f"Failed to connect to database {db_path}: {e}")
            raise VecConnectionError(f"Failed to connect to database: {e}") from e
        except Exception as e:
            if connection is not None:
                connection.close()
            logger.error(f"Failed to load sqlite-vec extension: {e}")
            raise VecConnectionError(f"Failed to load sqlite-vec extension: {e}") from e

    @staticmethod
    def rows_to_results(rows: list[sqlite3.Row]) -> list[Result]:
        """Convert `sqlite3.Row` items into `(rowid, text, metadata, embedding)`.

        ``metadata`` is decoded with ``json.loads`` and ``text_embedding`` with
        ``deserialize_f32``.
        """
        return [
            (
                row["rowid"],
                row["text"],
                json.loads(row["metadata"]),
                deserialize_f32(row["text_embedding"]),
            )
            for row in rows
        ]

    @staticmethod
    def _set_clause(
        text: str | None,
        metadata: Metadata | None,
        embedding: Embeddings | None,
    ) -> tuple[list[str], list[Any]]:
        """Build ``SET`` fragments and bound values for the fields that are not None.

        Returns:
            ``(fragments, params)``, both empty when every field is None. The
            caller appends the ``rowid`` parameter for the ``WHERE`` clause.
        """
        sets: list[str] = []
        params: list[Any] = []
        if text is not None:
            sets.append("text = ?")
            params.append(text)
        if metadata is not None:
            sets.append("metadata = ?")
            params.append(json.dumps(metadata))
        if embedding is not None:
            sets.append("text_embedding = ?")
            params.append(serialize_f32(embedding))
        return sets, params

    def __init__(self, table: str, db_path: str) -> None:
        """Initialize the client for a given base table and database file.

        Args:
            table: Name of the base table
            db_path: Path to SQLite database file

        Raises:
            TableNameError: If table name is invalid
            VecConnectionError: If connection fails
        """
        validate_table_name(table)
        self.table = table
        # True while inside transaction(); write methods then skip their
        # per-call commit.
        self._in_transaction = False
        logger.debug(f"Initializing SQLiteVecClient for table: {table}")
        self.connection = self.create_connection(db_path)

    def __enter__(self) -> SQLiteVecClient:
        """Support context manager protocol and return `self`."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: TracebackType | None,
    ) -> None:
        """Close the connection on exit; do not suppress exceptions."""
        self.close()

    def create_table(
        self,
        dim: int,
        distance: Literal["L1", "L2", "cosine"] = "cosine",
    ) -> None:
        """Create base table, vector table, and triggers to keep them in sync.

        All statements use ``IF NOT EXISTS``, so calling this on an existing
        schema is a no-op. Inside :meth:`transaction` the commit is deferred to
        the end of the transaction, as for every other write method.

        Args:
            dim: Embedding dimension (must be positive)
            distance: Distance metric for similarity search

        Raises:
            ValidationError: If dimension is invalid
            ValueError: If ``distance`` is not one of "L1", "L2", "cosine"
        """
        validate_dimension(dim)
        if distance not in self._DISTANCE_METRICS:
            raise ValueError(
                f"distance must be one of {self._DISTANCE_METRICS}, got {distance!r}"
            )
        logger.info(
            f"Creating table '{self.table}' with dim={dim}, distance={distance}"
        )
        self.connection.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {self.table}
            (
                rowid INTEGER PRIMARY KEY AUTOINCREMENT,
                text TEXT,
                metadata BLOB,
                text_embedding BLOB
            )
            ;
            """
        )
        self.connection.execute(
            f"""
            CREATE VIRTUAL TABLE IF NOT EXISTS {self.table}_vec USING vec0(
                rowid INTEGER PRIMARY KEY,
                text_embedding float[{dim}] distance_metric={distance}
            )
            ;
            """
        )
        # Mirror inserts into the vector table.
        self.connection.execute(
            f"""
            CREATE TRIGGER IF NOT EXISTS {self.table}_embed_text
            AFTER INSERT ON {self.table}
            BEGIN
                INSERT INTO {self.table}_vec(rowid, text_embedding)
                VALUES (new.rowid, new.text_embedding)
                ;
            END;
            """
        )
        # Mirror embedding changes (only fires when text_embedding is updated).
        self.connection.execute(
            f"""
            CREATE TRIGGER IF NOT EXISTS {self.table}_update_text_embedding
            AFTER UPDATE OF text_embedding ON {self.table}
            BEGIN
                UPDATE {self.table}_vec
                SET text_embedding = new.text_embedding
                WHERE rowid = new.rowid
                ;
            END;
            """
        )
        # Mirror deletes.
        self.connection.execute(
            f"""
            CREATE TRIGGER IF NOT EXISTS {self.table}_delete_row
            AFTER DELETE ON {self.table}
            BEGIN
                DELETE FROM {self.table}_vec WHERE rowid = old.rowid
                ;
            END;
            """
        )
        if not self._in_transaction:
            self.connection.commit()
        logger.debug(f"Table '{self.table}' and triggers created successfully")

    def similarity_search(
        self,
        embedding: Embeddings,
        top_k: int = 5,
    ) -> list[SimilaritySearchResult]:
        """Return top-k nearest neighbors for the given embedding.

        Args:
            embedding: Query embedding vector
            top_k: Number of results to return (must be positive)

        Returns:
            List of (rowid, text, distance) tuples, ordered by distance

        Raises:
            ValidationError: If top_k is invalid
            TableNotFoundError: If table doesn't exist
        """
        validate_top_k(top_k)
        logger.debug(f"Performing similarity search with top_k={top_k}")
        try:
            cursor = self.connection.cursor()
            cursor.execute(
                f"""
                SELECT
                    e.rowid AS rowid,
                    text,
                    distance
                FROM {self.table} AS e
                INNER JOIN {self.table}_vec AS v on v.rowid = e.rowid
                WHERE
                    v.text_embedding MATCH ?
                    AND k = ?
                ORDER BY v.distance
                """,
                [serialize_f32(embedding), top_k],
            )
            results = cursor.fetchall()
            logger.debug(f"Similarity search returned {len(results)} results")
            return [(row["rowid"], row["text"], row["distance"]) for row in results]
        except sqlite3.OperationalError as e:
            if "no such table" in str(e).lower():
                logger.error(f"Table '{self.table}' not found during similarity search")
                raise TableNotFoundError(
                    f"Table '{self.table}' or '{self.table}_vec' does not exist. "
                    "Call create_table() first."
                ) from e
            raise

    def add(
        self,
        texts: list[Text],
        embeddings: list[Embeddings],
        metadata: list[Metadata] | None = None,
    ) -> Rowids:
        """Insert texts with embeddings (and optional metadata) and return rowids.

        Rowids are the values SQLite actually assigned, in input order. They are
        read from ``cursor.lastrowid`` after each insert rather than derived
        from ``MAX(rowid)``: the base table uses AUTOINCREMENT, which never
        reuses ids of deleted rows, so ``MAX(rowid) + 1`` can be wrong.

        Outside :meth:`transaction`, a SQLite error part-way through the batch
        rolls back the rows already inserted by this call.

        Args:
            texts: List of text strings
            embeddings: List of embedding vectors
            metadata: Optional list of metadata dicts (defaults to ``{}`` each)

        Returns:
            List of rowids for inserted records

        Raises:
            ValidationError: If list lengths don't match
            TableNotFoundError: If table doesn't exist
        """
        validate_embeddings_match(texts, embeddings, metadata)
        logger.debug(f"Adding {len(texts)} records to table '{self.table}'")
        if metadata is None:
            metadata = [{} for _ in texts]
        # Serialize everything up front so an unserializable value fails before
        # any row has been written.
        data_input = [
            (text, json.dumps(md), serialize_f32(embedding))
            for text, md, embedding in zip(texts, metadata, embeddings)
        ]
        sql = f"""INSERT INTO {self.table}(text, metadata, text_embedding)
            VALUES (?,?,?)"""
        cur = self.connection.cursor()
        rowids = []
        try:
            for row in data_input:
                cur.execute(sql, row)
                # The trigger's insert into {table}_vec does not change
                # lastrowid once the statement completes.
                rowids.append(cur.lastrowid)
            if not self._in_transaction:
                self.connection.commit()
        except sqlite3.Error as e:
            if not self._in_transaction:
                # Drop the partial batch so a later commit cannot persist it.
                self.connection.rollback()
            if (
                isinstance(e, sqlite3.OperationalError)
                and "no such table" in str(e).lower()
            ):
                logger.error(f"Table '{self.table}' not found during add operation")
                raise TableNotFoundError(
                    f"Table '{self.table}' does not exist. Call create_table() first."
                ) from e
            raise
        logger.info(f"Added {len(rowids)} records to table '{self.table}'")
        return rowids

    def get_by_id(self, rowid: int) -> Result | None:
        """Get a single record by rowid; return `None` if not found."""
        cursor = self.connection.cursor()
        cursor.execute(
            f"""
            SELECT rowid, text, metadata, text_embedding
            FROM {self.table} WHERE rowid = ?
            """,
            [rowid],
        )
        row = cursor.fetchone()
        if row is None:
            return None
        return self.rows_to_results([row])[0]

    def get_many(self, rowids: list[int]) -> list[Result]:
        """Get multiple records by rowids; returns empty list if input is empty.

        Result order is whatever SQLite returns; missing rowids are skipped.
        """
        if not rowids:
            return []
        placeholders = ",".join(["?"] * len(rowids))
        cursor = self.connection.cursor()
        cursor.execute(
            f"""SELECT rowid, text, metadata, text_embedding FROM {self.table}
            WHERE rowid IN ({placeholders})""",
            rowids,
        )
        rows = cursor.fetchall()
        return self.rows_to_results(rows)

    def get_by_text(self, text: str) -> list[Result]:
        """Get all records with exact `text`, ordered by rowid ascending."""
        cursor = self.connection.cursor()
        cursor.execute(
            f"""
            SELECT rowid, text, metadata, text_embedding FROM {self.table}
            WHERE text = ?
            ORDER BY rowid ASC
            """,
            [text],
        )
        rows = cursor.fetchall()
        return self.rows_to_results(rows)

    def get_by_metadata(self, metadata: dict[str, Any]) -> list[Result]:
        """Get all records whose metadata exactly equals the given dict.

        The comparison is on the stored JSON text versus ``json.dumps(metadata)``,
        so dict key order must match the order used when the record was
        written; ``{"a": 1, "b": 2}`` does not match ``{"b": 2, "a": 1}``.
        """
        cursor = self.connection.cursor()
        cursor.execute(
            f"""
            SELECT rowid, text, metadata, text_embedding FROM {self.table}
            WHERE metadata = ?
            ORDER BY rowid ASC
            """,
            [json.dumps(metadata)],
        )
        rows = cursor.fetchall()
        return self.rows_to_results(rows)

    def list_results(
        self,
        limit: int = 50,
        offset: int = 0,
        order: Literal["asc", "desc"] = "asc",
    ) -> list[Result]:
        """List records with pagination and order by rowid.

        Args:
            limit: Maximum number of results (must be positive)
            offset: Number of results to skip (must be non-negative)
            order: Sort order ('asc' or 'desc', case-insensitive)

        Returns:
            List of (rowid, text, metadata, embedding) tuples

        Raises:
            ValueError: If order is not 'asc' or 'desc'
            ValidationError: If limit or offset is invalid
        """
        # The direction is interpolated into the SQL text, so only the two
        # keywords are allowed through.
        direction = order.upper()
        if direction not in ("ASC", "DESC"):
            raise ValueError(f"order must be 'asc' or 'desc', got {order!r}")
        validate_limit(limit)
        validate_offset(offset)
        cursor = self.connection.cursor()
        cursor.execute(
            f"""
            SELECT rowid, text, metadata, text_embedding FROM {self.table}
            ORDER BY rowid {direction}
            LIMIT ? OFFSET ?
            """,
            [limit, offset],
        )
        rows = cursor.fetchall()
        return self.rows_to_results(rows)

    def count(self) -> int:
        """Return the total number of rows in the base table."""
        cursor = self.connection.cursor()
        cursor.execute(f"SELECT COUNT(1) as c FROM {self.table}")
        row = cursor.fetchone()
        return int(row["c"]) if row is not None else 0

    def update(
        self,
        rowid: int,
        *,
        text: str | None = None,
        metadata: Metadata | None = None,
        embedding: Embeddings | None = None,
    ) -> bool:
        """Update fields of a record by rowid; return True if a row changed.

        Only the fields passed as non-None are written. When no field is given,
        nothing is executed and False is returned.
        """
        logger.debug(f"Updating record with rowid={rowid}")
        sets, params = self._set_clause(text, metadata, embedding)
        if not sets:
            return False
        params.append(rowid)
        sql = f"UPDATE {self.table} SET " + ", ".join(sets) + " WHERE rowid = ?"
        cur = self.connection.cursor()
        cur.execute(sql, params)
        if not self._in_transaction:
            self.connection.commit()
        updated = cur.rowcount > 0
        if updated:
            logger.debug(f"Successfully updated record with rowid={rowid}")
        return updated

    def delete_by_id(self, rowid: int) -> bool:
        """Delete a single record by rowid; return True if a row was removed."""
        logger.debug(f"Deleting record with rowid={rowid}")
        cur = self.connection.cursor()
        cur.execute(f"DELETE FROM {self.table} WHERE rowid = ?", [rowid])
        if not self._in_transaction:
            self.connection.commit()
        deleted = cur.rowcount > 0
        if deleted:
            logger.debug(f"Successfully deleted record with rowid={rowid}")
        return deleted

    def delete_many(self, rowids: list[int]) -> int:
        """Delete multiple records by rowids; return number of rows removed."""
        if not rowids:
            return 0
        logger.debug(f"Deleting {len(rowids)} records")
        # SQLite has a limit on SQL variables (typically 999 or 32766)
        # Split into chunks to avoid "too many SQL variables" error
        chunk_size = 500
        cur = self.connection.cursor()
        deleted_count = 0
        for i in range(0, len(rowids), chunk_size):
            chunk = rowids[i : i + chunk_size]
            placeholders = ",".join(["?"] * len(chunk))
            cur.execute(
                f"DELETE FROM {self.table} WHERE rowid IN ({placeholders})",
                chunk,
            )
            deleted_count += cur.rowcount
        if not self._in_transaction:
            self.connection.commit()
        logger.info(f"Deleted {deleted_count} records from table '{self.table}'")
        return deleted_count

    def update_many(
        self,
        updates: list[tuple[int, str | None, Metadata | None, Embeddings | None]],
    ) -> int:
        """Update multiple records, committing once at the end.

        Updates that set exactly the same single field (or all three fields)
        are batched with ``executemany``. All other combinations run one
        statement each. Batches run first (all fields, text only, metadata
        only, embedding only), then the remaining updates in input order. If
        the same rowid appears more than once, the order of application may
        therefore differ from the input order.

        Args:
            updates: List of (rowid, text, metadata, embedding) tuples.
                Any field except rowid can be None to skip updating.

        Returns:
            Number of rows updated
        """
        if not updates:
            return 0
        logger.debug(f"Updating {len(updates)} records")
        # Group updates by which fields are being updated
        text_updates = []
        metadata_updates = []
        embedding_updates = []
        full_updates = []
        mixed_updates = []
        for rowid, text, metadata, embedding in updates:
            has_text = text is not None
            has_metadata = metadata is not None
            has_embedding = embedding is not None
            if has_text and has_metadata and has_embedding:
                # Repeated None checks only narrow types for static checkers.
                if text is not None and metadata is not None and embedding is not None:
                    full_updates.append(
                        (text, json.dumps(metadata), serialize_f32(embedding), rowid)
                    )
            elif has_text and not has_metadata and not has_embedding:
                text_updates.append((text, rowid))
            elif has_metadata and not has_text and not has_embedding:
                metadata_updates.append((json.dumps(metadata), rowid))
            elif has_embedding and not has_text and not has_metadata:
                if embedding is not None:
                    embedding_updates.append((serialize_f32(embedding), rowid))
            else:
                # Mixed (or empty) updates - store for individual execution
                mixed_updates.append((rowid, text, metadata, embedding))
        cur = self.connection.cursor()
        updated_count = 0
        # Batch execute grouped updates; executemany's rowcount is the sum of
        # modified rows across all parameter sets.
        if full_updates:
            cur.executemany(
                f"""
                UPDATE {self.table}
                SET text = ?, metadata = ?, text_embedding = ? WHERE rowid = ?
                """,
                full_updates,
            )
            updated_count += cur.rowcount
        if text_updates:
            cur.executemany(
                f"UPDATE {self.table} SET text = ? WHERE rowid = ?", text_updates
            )
            updated_count += cur.rowcount
        if metadata_updates:
            cur.executemany(
                f"UPDATE {self.table} SET metadata = ? WHERE rowid = ?",
                metadata_updates,
            )
            updated_count += cur.rowcount
        if embedding_updates:
            cur.executemany(
                f"UPDATE {self.table} SET text_embedding = ? WHERE rowid = ?",
                embedding_updates,
            )
            updated_count += cur.rowcount
        # Handle mixed updates individually; all-None entries are skipped.
        for rowid, text, metadata, embedding in mixed_updates:
            sets, params = self._set_clause(text, metadata, embedding)
            if sets:
                params.append(rowid)
                sql = f"UPDATE {self.table} SET " + ", ".join(sets) + " WHERE rowid = ?"
                cur.execute(sql, params)
                updated_count += cur.rowcount
        if not self._in_transaction:
            self.connection.commit()
        logger.info(f"Updated {updated_count} records in table '{self.table}'")
        return updated_count

    def get_all(self, batch_size: int = 100) -> Generator[Result, None, None]:
        """Yield all records in batches for memory-efficient iteration.

        Pages with ``WHERE rowid > last_seen ORDER BY rowid LIMIT batch_size``
        instead of OFFSET. Each query therefore starts after the last rowid
        already yielded.

        Args:
            batch_size: Number of records to fetch per batch

        Yields:
            Individual (rowid, text, metadata, embedding) tuples
        """
        validate_limit(batch_size)
        logger.debug(f"Fetching all records with batch_size={batch_size}")
        last_rowid = 0
        cursor = self.connection.cursor()
        while True:
            cursor.execute(
                f"""
                SELECT rowid, text, metadata, text_embedding FROM {self.table}
                WHERE rowid > ?
                ORDER BY rowid ASC
                LIMIT ?
                """,
                [last_rowid, batch_size],
            )
            rows = cursor.fetchall()
            if not rows:
                break
            results = self.rows_to_results(rows)
            yield from results
            last_rowid = results[-1][0]  # Get last rowid from batch

    @contextmanager
    def transaction(self) -> Generator[None, None, None]:
        """Context manager for atomic transactions.

        Write methods called inside the block skip their individual commits.
        The block commits once on success and rolls back on any exception,
        including KeyboardInterrupt/SystemExit. Nested ``transaction()`` blocks
        join the outermost one. No savepoint is created, so only the outermost
        block commits or rolls back.

        Example:
            with client.transaction():
                client.add([...], [...])
                client.update_many([...])
        """
        if self._in_transaction:
            # Nested call: the enclosing block owns commit/rollback and the
            # flag. Exceptions propagate to it unchanged.
            yield
            return
        logger.debug("Starting transaction")
        self._in_transaction = True
        try:
            yield
            self.connection.commit()
            logger.debug("Transaction committed")
        except BaseException as e:
            # BaseException, not Exception: otherwise an interrupt would leave
            # the writes pending, and the next auto-committing call made after
            # the flag is cleared would commit them.
            self.connection.rollback()
            logger.error(f"Transaction rolled back: {e}")
            raise
        finally:
            self._in_transaction = False

    def close(self) -> None:
        """Close the underlying SQLite connection, suppressing close errors.

        Errors raised by ``close()`` are logged as warnings, not re-raised.
        """
        try:
            logger.debug(f"Closing connection for table '{self.table}'")
            self.connection.close()
            logger.info(f"Connection closed for table '{self.table}'")
        except Exception as e:
            logger.warning(f"Error closing connection: {e}")