Skip to content

Commit bdf98ff

Browse files
committed
perf(database): optimize connection and update methods
* Implement performance optimizations for the database connection.
* Use WAL journal mode and adjust synchronous settings.
* Improve add and update methods with batch processing.
* Enhance fetching records with efficient rowid tracking.
1 parent 2363209 commit bdf98ff

1 file changed

Lines changed: 125 additions & 20 deletions

File tree

sqlite_vec_client/base.py

Lines changed: 125 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,13 @@ def create_connection(db_path: str) -> sqlite3.Connection:
6363
connection.enable_load_extension(True)
6464
sqlite_vec.load(connection)
6565
connection.enable_load_extension(False)
66+
67+
# Performance optimizations
68+
connection.execute("PRAGMA journal_mode=WAL")
69+
connection.execute("PRAGMA synchronous=NORMAL")
70+
connection.execute("PRAGMA cache_size=-64000") # 64MB cache
71+
connection.execute("PRAGMA temp_store=MEMORY")
72+
6673
logger.info(f"Successfully connected to database: {db_path}")
6774
return connection
6875
except sqlite3.Error as e:
@@ -262,31 +269,33 @@ def add(
262269
validate_embeddings_match(texts, embeddings, metadata)
263270
logger.debug(f"Adding {len(texts)} records to table '{self.table}'")
264271
try:
265-
max_id = self.connection.execute(
266-
f"SELECT max(rowid) as rowid FROM {self.table}"
267-
).fetchone()["rowid"]
268-
269-
if max_id is None:
270-
max_id = 0
271-
272272
if metadata is None:
273273
metadata = [dict() for _ in texts]
274274

275275
data_input = [
276276
(text, json.dumps(md), serialize_f32(embedding))
277277
for text, md, embedding in zip(texts, metadata, embeddings)
278278
]
279-
self.connection.executemany(
279+
280+
cur = self.connection.cursor()
281+
282+
# Get max rowid before insert
283+
max_before = cur.execute(
284+
f"SELECT COALESCE(MAX(rowid), 0) FROM {self.table}"
285+
).fetchone()[0]
286+
287+
cur.executemany(
280288
f"""INSERT INTO {self.table}(text, metadata, text_embedding)
281289
VALUES (?,?,?)""",
282290
data_input,
283291
)
292+
293+
# Calculate rowids from max_before
294+
rowids = list(range(max_before + 1, max_before + len(texts) + 1))
295+
284296
if not self._in_transaction:
285297
self.connection.commit()
286-
results = self.connection.execute(
287-
f"SELECT rowid FROM {self.table} WHERE rowid > {max_id}"
288-
)
289-
rowids = [row["rowid"] for row in results]
298+
290299
logger.info(f"Added {len(rowids)} records to table '{self.table}'")
291300
return rowids
292301
except sqlite3.OperationalError as e:
@@ -475,10 +484,93 @@ def update_many(
475484
if not updates:
476485
return 0
477486
logger.debug(f"Updating {len(updates)} records")
478-
updated_count = 0
487+
488+
# Group updates by which fields are being updated
489+
text_updates = []
490+
metadata_updates = []
491+
embedding_updates = []
492+
full_updates = []
493+
494+
mixed_updates = []
495+
479496
for rowid, text, metadata, embedding in updates:
480-
if self.update(rowid, text=text, metadata=metadata, embedding=embedding):
481-
updated_count += 1
497+
has_text = text is not None
498+
has_metadata = metadata is not None
499+
has_embedding = embedding is not None
500+
501+
if has_text and has_metadata and has_embedding:
502+
if text is not None and metadata is not None and embedding is not None:
503+
full_updates.append(
504+
(text, json.dumps(metadata), serialize_f32(embedding), rowid)
505+
)
506+
elif has_text and not has_metadata and not has_embedding:
507+
text_updates.append((text, rowid))
508+
elif has_metadata and not has_text and not has_embedding:
509+
metadata_updates.append((json.dumps(metadata), rowid))
510+
elif has_embedding and not has_text and not has_metadata:
511+
if embedding is not None:
512+
embedding_updates.append((serialize_f32(embedding), rowid))
513+
else:
514+
# Mixed updates - store for individual execution
515+
mixed_updates.append((rowid, text, metadata, embedding))
516+
517+
cur = self.connection.cursor()
518+
updated_count = 0
519+
520+
# Batch execute grouped updates
521+
if full_updates:
522+
cur.executemany(
523+
f"""
524+
UPDATE {self.table}
525+
SET text = ?, metadata = ?, text_embedding = ? WHERE rowid = ?
526+
""",
527+
full_updates,
528+
)
529+
updated_count += cur.rowcount
530+
531+
if text_updates:
532+
cur.executemany(
533+
f"UPDATE {self.table} SET text = ? WHERE rowid = ?", text_updates
534+
)
535+
updated_count += cur.rowcount
536+
537+
if metadata_updates:
538+
cur.executemany(
539+
f"UPDATE {self.table} SET metadata = ? WHERE rowid = ?",
540+
metadata_updates,
541+
)
542+
updated_count += cur.rowcount
543+
544+
if embedding_updates:
545+
cur.executemany(
546+
f"UPDATE {self.table} SET text_embedding = ? WHERE rowid = ?",
547+
embedding_updates,
548+
)
549+
updated_count += cur.rowcount
550+
551+
# Handle mixed updates individually
552+
for rowid, text, metadata, embedding in mixed_updates:
553+
sets = []
554+
params: list[Any] = []
555+
if text is not None:
556+
sets.append("text = ?")
557+
params.append(text)
558+
if metadata is not None:
559+
sets.append("metadata = ?")
560+
params.append(json.dumps(metadata))
561+
if embedding is not None:
562+
sets.append("text_embedding = ?")
563+
params.append(serialize_f32(embedding))
564+
params.append(rowid)
565+
566+
if sets:
567+
sql = f"UPDATE {self.table} SET " + ", ".join(sets) + " WHERE rowid = ?"
568+
cur.execute(sql, params)
569+
updated_count += cur.rowcount
570+
571+
if not self._in_transaction:
572+
self.connection.commit()
573+
482574
logger.info(f"Updated {updated_count} records in table '{self.table}'")
483575
return updated_count
484576

@@ -493,13 +585,26 @@ def get_all(self, batch_size: int = 100) -> Generator[Result, None, None]:
493585
"""
494586
validate_limit(batch_size)
495587
logger.debug(f"Fetching all records with batch_size={batch_size}")
496-
offset = 0
588+
last_rowid = 0
589+
cursor = self.connection.cursor()
590+
497591
while True:
498-
batch = self.list_results(limit=batch_size, offset=offset)
499-
if not batch:
592+
cursor.execute(
593+
f"""
594+
SELECT rowid, text, metadata, text_embedding FROM {self.table}
595+
WHERE rowid > ?
596+
ORDER BY rowid ASC
597+
LIMIT ?
598+
""",
599+
[last_rowid, batch_size],
600+
)
601+
rows = cursor.fetchall()
602+
if not rows:
500603
break
501-
yield from batch
502-
offset += batch_size
604+
605+
results = self.rows_to_results(rows)
606+
yield from results
607+
last_rowid = results[-1][0] # Get last rowid from batch
503608

504609
@contextmanager
505610
def transaction(self) -> Generator[None, None, None]:

0 commit comments

Comments
 (0)