@@ -63,6 +63,13 @@ def create_connection(db_path: str) -> sqlite3.Connection:
6363 connection .enable_load_extension (True )
6464 sqlite_vec .load (connection )
6565 connection .enable_load_extension (False )
66+
67+ # Performance optimizations
68+ connection .execute ("PRAGMA journal_mode=WAL" )
69+ connection .execute ("PRAGMA synchronous=NORMAL" )
70+ connection .execute ("PRAGMA cache_size=-64000" ) # 64MB cache
71+ connection .execute ("PRAGMA temp_store=MEMORY" )
72+
6673 logger .info (f"Successfully connected to database: { db_path } " )
6774 return connection
6875 except sqlite3 .Error as e :
@@ -262,31 +269,33 @@ def add(
262269 validate_embeddings_match (texts , embeddings , metadata )
263270 logger .debug (f"Adding { len (texts )} records to table '{ self .table } '" )
264271 try :
265- max_id = self .connection .execute (
266- f"SELECT max(rowid) as rowid FROM { self .table } "
267- ).fetchone ()["rowid" ]
268-
269- if max_id is None :
270- max_id = 0
271-
272272 if metadata is None :
273273 metadata = [dict () for _ in texts ]
274274
275275 data_input = [
276276 (text , json .dumps (md ), serialize_f32 (embedding ))
277277 for text , md , embedding in zip (texts , metadata , embeddings )
278278 ]
279- self .connection .executemany (
279+
280+ cur = self .connection .cursor ()
281+
282+ # Get max rowid before insert
283+ max_before = cur .execute (
284+ f"SELECT COALESCE(MAX(rowid), 0) FROM { self .table } "
285+ ).fetchone ()[0 ]
286+
287+ cur .executemany (
280288 f"""INSERT INTO { self .table } (text, metadata, text_embedding)
281289 VALUES (?,?,?)""" ,
282290 data_input ,
283291 )
292+
293+ # Calculate rowids from max_before
294+ rowids = list (range (max_before + 1 , max_before + len (texts ) + 1 ))
295+
284296 if not self ._in_transaction :
285297 self .connection .commit ()
286- results = self .connection .execute (
287- f"SELECT rowid FROM { self .table } WHERE rowid > { max_id } "
288- )
289- rowids = [row ["rowid" ] for row in results ]
298+
290299 logger .info (f"Added { len (rowids )} records to table '{ self .table } '" )
291300 return rowids
292301 except sqlite3 .OperationalError as e :
@@ -475,10 +484,93 @@ def update_many(
475484 if not updates :
476485 return 0
477486 logger .debug (f"Updating { len (updates )} records" )
478- updated_count = 0
487+
488+ # Group updates by which fields are being updated
489+ text_updates = []
490+ metadata_updates = []
491+ embedding_updates = []
492+ full_updates = []
493+
494+ mixed_updates = []
495+
479496 for rowid , text , metadata , embedding in updates :
480- if self .update (rowid , text = text , metadata = metadata , embedding = embedding ):
481- updated_count += 1
497+ has_text = text is not None
498+ has_metadata = metadata is not None
499+ has_embedding = embedding is not None
500+
501+ if has_text and has_metadata and has_embedding :
502+ if text is not None and metadata is not None and embedding is not None :
503+ full_updates .append (
504+ (text , json .dumps (metadata ), serialize_f32 (embedding ), rowid )
505+ )
506+ elif has_text and not has_metadata and not has_embedding :
507+ text_updates .append ((text , rowid ))
508+ elif has_metadata and not has_text and not has_embedding :
509+ metadata_updates .append ((json .dumps (metadata ), rowid ))
510+ elif has_embedding and not has_text and not has_metadata :
511+ if embedding is not None :
512+ embedding_updates .append ((serialize_f32 (embedding ), rowid ))
513+ else :
514+ # Mixed updates - store for individual execution
515+ mixed_updates .append ((rowid , text , metadata , embedding ))
516+
517+ cur = self .connection .cursor ()
518+ updated_count = 0
519+
520+ # Batch execute grouped updates
521+ if full_updates :
522+ cur .executemany (
523+ f"""
524+ UPDATE { self .table }
525+ SET text = ?, metadata = ?, text_embedding = ? WHERE rowid = ?
526+ """ ,
527+ full_updates ,
528+ )
529+ updated_count += cur .rowcount
530+
531+ if text_updates :
532+ cur .executemany (
533+ f"UPDATE { self .table } SET text = ? WHERE rowid = ?" , text_updates
534+ )
535+ updated_count += cur .rowcount
536+
537+ if metadata_updates :
538+ cur .executemany (
539+ f"UPDATE { self .table } SET metadata = ? WHERE rowid = ?" ,
540+ metadata_updates ,
541+ )
542+ updated_count += cur .rowcount
543+
544+ if embedding_updates :
545+ cur .executemany (
546+ f"UPDATE { self .table } SET text_embedding = ? WHERE rowid = ?" ,
547+ embedding_updates ,
548+ )
549+ updated_count += cur .rowcount
550+
551+ # Handle mixed updates individually
552+ for rowid , text , metadata , embedding in mixed_updates :
553+ sets = []
554+ params : list [Any ] = []
555+ if text is not None :
556+ sets .append ("text = ?" )
557+ params .append (text )
558+ if metadata is not None :
559+ sets .append ("metadata = ?" )
560+ params .append (json .dumps (metadata ))
561+ if embedding is not None :
562+ sets .append ("text_embedding = ?" )
563+ params .append (serialize_f32 (embedding ))
564+ params .append (rowid )
565+
566+ if sets :
567+ sql = f"UPDATE { self .table } SET " + ", " .join (sets ) + " WHERE rowid = ?"
568+ cur .execute (sql , params )
569+ updated_count += cur .rowcount
570+
571+ if not self ._in_transaction :
572+ self .connection .commit ()
573+
482574 logger .info (f"Updated { updated_count } records in table '{ self .table } '" )
483575 return updated_count
484576
@@ -493,13 +585,26 @@ def get_all(self, batch_size: int = 100) -> Generator[Result, None, None]:
493585 """
494586 validate_limit (batch_size )
495587 logger .debug (f"Fetching all records with batch_size={ batch_size } " )
496- offset = 0
588+ last_rowid = 0
589+ cursor = self .connection .cursor ()
590+
497591 while True :
498- batch = self .list_results (limit = batch_size , offset = offset )
499- if not batch :
592+ cursor .execute (
593+ f"""
594+ SELECT rowid, text, metadata, text_embedding FROM { self .table }
595+ WHERE rowid > ?
596+ ORDER BY rowid ASC
597+ LIMIT ?
598+ """ ,
599+ [last_rowid , batch_size ],
600+ )
601+ rows = cursor .fetchall ()
602+ if not rows :
500603 break
501- yield from batch
502- offset += batch_size
604+
605+ results = self .rows_to_results (rows )
606+ yield from results
607+ last_rowid = results [- 1 ][0 ] # Get last rowid from batch
503608
504609 @contextmanager
505610 def transaction (self ) -> Generator [None , None , None ]:
0 commit comments