2626COLLECTION_NAME = 'ccb_store'
2727DOCUMENTS_TABLE_NAME = 'docs'
2828ACCESS_LIST_TABLE_NAME = 'access_list'
29+ PG_BATCH_SIZE = 50000
2930
3031logger = logging .getLogger ('ccb.vectordb' )
3132
@@ -130,11 +131,17 @@ def get_users(self) -> list[str]:
130131 def add_indocuments (self , indocuments : list [InDocument ]) -> tuple [list [str ], list [str ]]:
131132 added_sources = []
132133 not_added_sources = []
134+ batch_size = PG_BATCH_SIZE // 5
133135
134136 with self .session_maker () as session :
135137 for indoc in indocuments :
136138 try :
137- chunk_ids = self .client .add_documents (indoc .documents )
139+ # The query parameter limit in PostgreSQL is 65535 (https://www.postgresql.org/docs/current/limits.html),
140+ # so we insert the documents in batches of (5 values per row * 10k rows).
141+ # Adjust the batch size if the number of inserted values per document grows.
142+ chunk_ids = []
143+ for i in range (0 , len (indoc .documents ), batch_size ):
144+ chunk_ids .extend (self .client .add_documents (indoc .documents [i :i + batch_size ]))
138145
139146 doc = DocumentsStore (
140147 source_id = indoc .source_id ,
@@ -497,24 +504,17 @@ def doc_search(
497504
498505 try :
499506 with self .session_maker () as session :
500- # get user's access list
501- stmt = (
502- sa .select (AccessListStore .source_id )
503- .filter (AccessListStore .uid == user_id )
504- )
505- result = session .execute (stmt ).fetchall ()
506- source_ids = [r .source_id for r in result ]
507-
508- doc_filters = [DocumentsStore .source_id .in_ (source_ids )]
507+ doc_filters = [AccessListStore .uid == user_id ]
509508 match scope_type :
510509 case ScopeType .PROVIDER :
511510 doc_filters .append (DocumentsStore .provider .in_ (scope_list )) # pyright: ignore[reportArgumentType]
512511 case ScopeType .SOURCE :
513512 doc_filters .append (DocumentsStore .source_id .in_ (scope_list )) # pyright: ignore[reportArgumentType]
514513
515- # get chunks associated with the source_ids
514+ # get chunks associated with the user
516515 stmt = (
517516 sa .select (DocumentsStore .chunks )
517+ .join (AccessListStore , AccessListStore .source_id == DocumentsStore .source_id )
518518 .filter (* doc_filters )
519519 )
520520 result = session .execute (stmt ).fetchall ()
@@ -538,30 +538,43 @@ def _similarity_search(
538538 if not collection :
539539 raise DbException ('Collection not found' )
540540
541- filter_by = [
542- self .client .EmbeddingStore .collection_id == collection .uuid ,
543- self .client .EmbeddingStore .id .in_ (chunk_ids ),
544- ]
541+ # Initialize results list to store all potential matches
542+ all_results = []
543+ # Process chunk_ids in batches to prevent db errors
544+ # The query parameter limit in PostgreSQL is 65535 (https://www.postgresql.org/docs/current/limits.html)
545+ for i in range (0 , len (chunk_ids ), PG_BATCH_SIZE ):
546+ batch_chunk_ids = chunk_ids [i :i + PG_BATCH_SIZE ]
545547
546- results = (
547- session .query (
548- self .client .EmbeddingStore ,
549- self .client .distance_strategy (embedding ).label ('distance' ),
550- )
551- .filter (* filter_by )
552- .order_by (sa .asc ('distance' ))
553- .join (
554- self .client .CollectionStore ,
555- self .client .EmbeddingStore .collection_id == self .client .CollectionStore .uuid ,
548+ filter_by = [
549+ self .client .EmbeddingStore .collection_id == collection .uuid ,
550+ self .client .EmbeddingStore .id .in_ (batch_chunk_ids ),
551+ ]
552+
553+ batch_results = (
554+ session .query (
555+ self .client .EmbeddingStore ,
556+ self .client .distance_strategy (embedding ).label ('distance' ),
557+ )
558+ .filter (* filter_by )
559+ .join (
560+ self .client .CollectionStore ,
561+ self .client .EmbeddingStore .collection_id == self .client .CollectionStore .uuid ,
562+ )
563+ .order_by (sa .asc ('distance' ))
564+ .limit (k ) # Get up to k results from the batch
556565 )
557- .limit (k )
558- .all ()
559- )
566+
567+ all_results .extend (batch_results )
568+
569+ # Sort all collected results by distance and take the top k.
570+ # This must be unconditional: guarding it on len(chunk_ids) > PG_BATCH_SIZE
571+ # would leave top_k_results undefined (NameError) on the common
572+ # single-batch path. For one batch the sort is a harmless no-op, since
573+ # the query already returned at most k rows ordered by distance.
574+ all_results .sort (key = lambda x : x .distance )
575+ top_k_results = all_results [:k ]
560573
561574 return [
562575 Document (
563576 id = str (result .EmbeddingStore .id ),
564577 page_content = result .EmbeddingStore .document ,
565578 metadata = result .EmbeddingStore .cmetadata ,
566- ) for result in results
579+ ) for result in top_k_results
567580 ]
0 commit comments