
Commit 02a4d28

Fix: query timeout when too many sources/chunks are passed to a Postgres IN clause (#166)
When a context chat has too many documents/chunks (more than the 65535 query-parameter limit) passed into certain IN statements, the query times out. See https://www.postgresql.org/docs/current/limits.html

Changes:
1. doc_search now gets the user's accessible chunks in one SQL statement using a join on the access list instead of a separate query plus an IN over source_ids.
2. _similarity_search now batches the embedding search in batches of 50k chunks.
3. Document insertion now happens in batches of 10k chunks.
2 parents 82f1aef + 267aa44 commit 02a4d28
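
Background for the fix (an editor's sketch, not part of the commit): PostgreSQL accepts at most 65535 bind parameters per statement, and an expression like column.in_(ids) binds one parameter per list element, so oversized ID lists have to be sliced before they reach the database. The slicing pattern the commit inlines in both code paths looks like the following; batched is a hypothetical helper name used only for illustration:

    PG_MAX_QUERY_PARAMS = 65535  # per-statement cap, per the Postgres limits page

    def batched(items: list, batch_size: int):
        # yield successive slices of at most batch_size items
        for i in range(0, len(items), batch_size):
            yield items[i:i + batch_size]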

1 file changed: 42 additions & 29 deletions

context_chat_backend/vectordb/pgvector.py
@@ -26,6 +26,7 @@
 COLLECTION_NAME = 'ccb_store'
 DOCUMENTS_TABLE_NAME = 'docs'
 ACCESS_LIST_TABLE_NAME = 'access_list'
+PG_BATCH_SIZE = 50000
 
 logger = logging.getLogger('ccb.vectordb')
 
@@ -130,11 +131,17 @@ def get_users(self) -> list[str]:
     def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str], list[str]]:
         added_sources = []
         not_added_sources = []
+        batch_size = PG_BATCH_SIZE // 5
 
         with self.session_maker() as session:
             for indoc in indocuments:
                 try:
-                    chunk_ids = self.client.add_documents(indoc.documents)
+                    # query parameters limitation in postgres is 65535 (https://www.postgresql.org/docs/current/limits.html)
+                    # so we chunk the documents into (5 values * 10k) chunks
+                    # change the chunk size when there are more inserted values per document
+                    chunk_ids = []
+                    for i in range(0, len(indoc.documents), batch_size):
+                        chunk_ids.extend(self.client.add_documents(indoc.documents[i:i+batch_size]))
 
                     doc = DocumentsStore(
                         source_id=indoc.source_id,
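
A note on the batch size above (my arithmetic from the "5 values * 10k" comment; the per-row value count is an assumption, not stated elsewhere in the commit): if each inserted chunk row binds about 5 values, a batch of PG_BATCH_SIZE // 5 = 10000 rows stays near 50k bound parameters, under the 65535 cap:

    PG_BATCH_SIZE = 50000
    VALUES_PER_CHUNK_ROW = 5  # assumed from the '5 values * 10k' comment
    batch_size = PG_BATCH_SIZE // VALUES_PER_CHUNK_ROW  # 10000 rows per insert batch
    assert batch_size * VALUES_PER_CHUNK_ROW <= 65535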
@@ -497,24 +504,17 @@ def doc_search(
 
         try:
             with self.session_maker() as session:
-                # get user's access list
-                stmt = (
-                    sa.select(AccessListStore.source_id)
-                    .filter(AccessListStore.uid == user_id)
-                )
-                result = session.execute(stmt).fetchall()
-                source_ids = [r.source_id for r in result]
-
-                doc_filters = [DocumentsStore.source_id.in_(source_ids)]
+                doc_filters = [AccessListStore.uid == user_id]
                 match scope_type:
                     case ScopeType.PROVIDER:
                         doc_filters.append(DocumentsStore.provider.in_(scope_list))  # pyright: ignore[reportArgumentType]
                     case ScopeType.SOURCE:
                         doc_filters.append(DocumentsStore.source_id.in_(scope_list))  # pyright: ignore[reportArgumentType]
 
-                # get chunks associated with the source_ids
+                # get chunks associated with the user
                 stmt = (
                     sa.select(DocumentsStore.chunks)
+                    .join(AccessListStore, AccessListStore.source_id==DocumentsStore.source_id)
                     .filter(*doc_filters)
                 )
                 result = session.execute(stmt).fetchall()
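
The point of the join (a sketch restating the change, not extra commit code): the old path ran one query to fetch the user's source_ids and then bound one parameter per source via DocumentsStore.source_id.in_(source_ids), which runs into the same parameter limit for users with many accessible sources; the new statement binds a single uid parameter no matter how many sources the user can access:

    # before: 1 bind parameter per accessible source
    # doc_filters = [DocumentsStore.source_id.in_(source_ids)]
    # after: exactly 1 bind parameter for the user id
    stmt = (
        sa.select(DocumentsStore.chunks)
        .join(AccessListStore, AccessListStore.source_id == DocumentsStore.source_id)
        .filter(AccessListStore.uid == user_id)
    )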
@@ -538,30 +538,43 @@ def _similarity_search(
             if not collection:
                 raise DbException('Collection not found')
 
-            filter_by = [
-                self.client.EmbeddingStore.collection_id == collection.uuid,
-                self.client.EmbeddingStore.id.in_(chunk_ids),
-            ]
+            # Initialize results list to store all potential matches
+            all_results = []
+            # Process chunk_ids in batches to prevent db errors
+            # query parameters limitation in postgres is 65535 (https://www.postgresql.org/docs/current/limits.html)
+            for i in range(0, len(chunk_ids), PG_BATCH_SIZE):
+                batch_chunk_ids = chunk_ids[i:i+PG_BATCH_SIZE]
 
-            results = (
-                session.query(
-                    self.client.EmbeddingStore,
-                    self.client.distance_strategy(embedding).label('distance'),
-                )
-                .filter(*filter_by)
-                .order_by(sa.asc('distance'))
-                .join(
-                    self.client.CollectionStore,
-                    self.client.EmbeddingStore.collection_id == self.client.CollectionStore.uuid,
+                filter_by = [
+                    self.client.EmbeddingStore.collection_id == collection.uuid,
+                    self.client.EmbeddingStore.id.in_(batch_chunk_ids),
+                ]
+
+                batch_results = (
+                    session.query(
+                        self.client.EmbeddingStore,
+                        self.client.distance_strategy(embedding).label('distance'),
+                    )
+                    .filter(*filter_by)
+                    .join(
+                        self.client.CollectionStore,
+                        self.client.EmbeddingStore.collection_id == self.client.CollectionStore.uuid,
+                    )
+                    .order_by(sa.asc('distance'))
+                    .limit(k)  # Get up to k results from the batch
                 )
-                .limit(k)
-                .all()
-            )
+
+                all_results.extend(batch_results)
+
+            # Sort all collected results by distance and take top k
+            if len(chunk_ids) > PG_BATCH_SIZE:
+                all_results.sort(key=lambda x: x.distance)
+            top_k_results = all_results[:k]
 
             return [
                 Document(
                     id=str(result.EmbeddingStore.id),
                     page_content=result.EmbeddingStore.document,
                     metadata=result.EmbeddingStore.cmetadata,
-                ) for result in results
+                ) for result in top_k_results
             ]
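
Why per-batch .limit(k) plus a final sort is sufficient (reasoning, with an equivalent reformulation rather than the commit's code): any row in the global top k must also rank in the top k of its own batch, so the union of per-batch results always contains the true answer; taking the k smallest distances from that union yields the global result. heapq expresses the final selection without fully sorting:

    import heapq

    def merge_top_k(all_results: list, k: int) -> list:
        # 'distance' is the label the query attaches to each row
        return heapq.nsmallest(k, all_results, key=lambda r: r.distance)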
