
Commit 2debcb2

fix(pgvector): add chunking to prevent long list of args in queries (#290)
The error log for the non-chunked deletion query:

```
raise ex.with_traceback(None)
sqlalchemy.exc.OperationalError: (psycopg.OperationalError) sending query and params failed: number of parameters must be between 0 and 65535
```

Signed-off-by: Anupam Kumar <kyteinsky@gmail.com>
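For context: PostgreSQL's extended query protocol caps a single statement at 65535 bound parameters, and psycopg sends every element of an `IN (...)` list and every value in a multi-row `VALUES` clause as its own parameter. A minimal sketch of the failure mode and the chunking fix this commit adopts (the `items` table and `BATCH` constant are illustrative, not from this repository):

```python
# Sketch: a large IN (...) list overflows the 65535-parameter cap in a
# single statement; slicing the id list keeps every query under the limit.
# Table and constant names are hypothetical.
import sqlalchemy as sa

BATCH = 50_000  # hypothetical batch size, kept below the 65535-parameter cap

metadata = sa.MetaData()
items = sa.Table('items', metadata, sa.Column('id', sa.Text, primary_key=True))

def delete_ids(conn: sa.Connection, ids: list[str]) -> None:
	# `id IN (:p1, :p2, ...)` binds one parameter per element, so a list
	# of 100k ids would fail without chunking.
	for i in range(0, len(ids), BATCH):
		conn.execute(sa.delete(items).where(items.c.id.in_(ids[i:i + BATCH])))
```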
1 parent 15a9922 commit 2debcb2

1 file changed

Lines changed: 93 additions & 89 deletions

File tree

context_chat_backend/vectordb/pgvector.py

```diff
@@ -271,20 +271,22 @@ def decl_update_access(self, user_ids: list[str], source_id: str, session_: orm.
 				.filter(AccessListStore.source_id == source_id)
 			)
 			session.execute(stmt)
-			session.commit()
 
-			stmt = (
-				postgresql_dialects.insert(AccessListStore)
-				.values([
-					{
-						'uid': user_id,
-						'source_id': source_id,
-					}
-					for user_id in user_ids
-				])
-				.on_conflict_do_nothing(index_elements=['uid', 'source_id'])
-			)
-			session.execute(stmt)
+			for i in range(0, len(user_ids), PG_BATCH_SIZE):
+				batched_uids = user_ids[i:i+PG_BATCH_SIZE]
+				stmt = (
+					postgresql_dialects.insert(AccessListStore)
+					.values([
+						{
+							'uid': user_id,
+							'source_id': source_id,
+						}
+						for user_id in batched_uids
+					])
+					.on_conflict_do_nothing(index_elements=['uid', 'source_id'])
+				)
+				session.execute(stmt)
+
 			session.commit()
 		except SafeDbException as e:
 			session.rollback()
```
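Each dict in the `.values([...])` list binds two parameters (`uid` and `source_id`), so one batch of this insert costs `2 * PG_BATCH_SIZE` parameters; the constant has to be chosen with the widest statement in mind, not just the `IN (...)` cases. Its value is defined elsewhere in pgvector.py and not shown in this diff; a plausible, purely illustrative definition:

```python
# Hypothetical: PG_BATCH_SIZE is referenced by the diff but defined outside
# these hunks. Any value where (parameters per row * batch size) stays below
# 65535 is safe; 30_000 is an illustrative choice, not the repo's setting.
PG_BATCH_SIZE = 30_000
```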
```diff
@@ -324,27 +326,31 @@ def update_access(
 
 		match op:
 			case UpdateAccessOp.allow:
-				stmt = (
-					postgresql_dialects.insert(AccessListStore)
-					.values([
-						{
-							'uid': user_id,
-							'source_id': source_id,
-						}
-						for user_id in user_ids
-					])
-					.on_conflict_do_nothing(index_elements=['uid', 'source_id'])
-				)
-				session.execute(stmt)
+				for i in range(0, len(user_ids), PG_BATCH_SIZE):
+					batched_uids = user_ids[i:i+PG_BATCH_SIZE]
+					stmt = (
+						postgresql_dialects.insert(AccessListStore)
+						.values([
+							{
+								'uid': user_id,
+								'source_id': source_id,
+							}
+							for user_id in batched_uids
+						])
+						.on_conflict_do_nothing(index_elements=['uid', 'source_id'])
+					)
+					session.execute(stmt)
 				session.commit()
 
 			case UpdateAccessOp.deny:
-				stmt = (
-					sa.delete(AccessListStore)
-					.filter(AccessListStore.uid.in_(user_ids))
-					.filter(AccessListStore.source_id == source_id)
-				)
-				session.execute(stmt)
+				for i in range(0, len(user_ids), PG_BATCH_SIZE):
+					batched_uids = user_ids[i:i+PG_BATCH_SIZE]
+					stmt = (
+						sa.delete(AccessListStore)
+						.filter(AccessListStore.uid.in_(batched_uids))
+						.filter(AccessListStore.source_id == source_id)
+					)
+					session.execute(stmt)
 				session.commit()
 
 				# check if all entries related to the source were deleted
```
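The allow branch relies on PostgreSQL's `INSERT ... ON CONFLICT DO NOTHING`, so re-granting access that already exists is a no-op rather than a unique-constraint error. A self-contained sketch of that construct (throwaway table and names, not the app's models; assumes SQLAlchemy 2.x against a PostgreSQL database):

```python
# Minimal sketch of the idempotent-insert pattern used in the allow branch.
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

metadata = sa.MetaData()
access = sa.Table(
	'access_demo', metadata,
	sa.Column('uid', sa.Text, primary_key=True),
	sa.Column('source_id', sa.Text, primary_key=True),
)

def grant(conn: sa.Connection, uids: list[str], source_id: str) -> None:
	stmt = (
		postgresql.insert(access)
		.values([{'uid': u, 'source_id': source_id} for u in uids])
		.on_conflict_do_nothing(index_elements=['uid', 'source_id'])
	)
	conn.execute(stmt)  # rows whose (uid, source_id) already exist are skipped
```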
```diff
@@ -356,6 +362,8 @@ def update_access(
 			logger.info('Error: updating access list', exc_info=e, extra={
 				'source_id': source_id,
 			})
+		except DbException:
+			raise
 		except Exception as e:
 			session.rollback()
 			raise DbException('Error: updating access list') from e
```
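The new `except DbException: raise` arm matters because the deny path runs the orphan check, and `_cleanup_if_orphaned` now raises `DbException` itself (see the next hunk); without the pass-through, the broad handler below would wrap a `DbException` in a second `DbException`. A minimal sketch of the ordering rule:

```python
# Sketch: a bare re-raise placed before a broad handler keeps exceptions
# that are already domain-specific from being double-wrapped.
class DbException(Exception):
	pass

def cleanup() -> None:
	raise DbException('Error: cleaning up orphaned source ids')

def update_access() -> None:
	try:
		cleanup()
	except DbException:
		raise  # already wrapped; message and traceback pass through intact
	except Exception as e:
		raise DbException('Error: updating access list') from e
```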
```diff
@@ -388,33 +396,35 @@ def _cleanup_if_orphaned(self, source_ids: list[str], session_: orm.Session | No
 		if len(source_ids) == 0:
 			return
 
-		filter_ = [
-			AccessListStore.source_id.in_(source_ids) if len(source_ids) > 1
-			else AccessListStore.source_id == source_ids[0]
-		]
-
 		session = session_ or self.session_maker()
 
 		try:
-			stmt = (
-				sa.select(AccessListStore.source_id)
-				.filter(*filter_)
-				.distinct()
-			)
-			result = session.execute(stmt).fetchall()
+			# find orphaned source_ids (no AccessListStore entry) in batches
+			orphaned_ids = []
+			for i in range(0, len(source_ids), PG_BATCH_SIZE):
+				batched_ids = source_ids[i:i+PG_BATCH_SIZE]
+				stmt = (
+					sa.select(DocumentsStore.source_id)
+					.filter(DocumentsStore.source_id.in_(batched_ids))
+					.filter(
+						~sa.exists(  # NOT EXISTS
+							sa.select(sa.literal(1))
+							.where(AccessListStore.source_id == DocumentsStore.source_id)
+						)
+					)
+				)
+				result = session.execute(stmt).fetchall()
+				orphaned_ids.extend(str(r.source_id) for r in result)
+
+			if len(orphaned_ids) > 0:
+				self.delete_source_ids(orphaned_ids, session)
+		except DbException:
+			raise
 		except Exception as e:
+			raise DbException('Error: cleaning up orphaned source ids') from e
+		finally:
 			if session_ is None:
 				session.close()
-			raise DbException('Error: getting source ids from access list') from e
-
-		existing_links = [str(r.source_id) for r in result]
-		to_delete = [source_id for source_id in source_ids if source_id not in existing_links]
-
-		if len(to_delete) > 0:
-			self.delete_source_ids(to_delete, session_)
-
-		if session_ is None:
-			session.close()
 
 	def delete_source_ids(self, source_ids: list[str], session_: orm.Session | None = None):
 		session = session_ or self.session_maker()
```
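The old cleanup selected the surviving access-list links and diffed them against `source_ids` in Python; the new version pushes the orphan test into SQL with a `NOT EXISTS` anti-join, batched like everything else. A standalone sketch of the construct with stand-in declarative models (assumes SQLAlchemy 2.x; the real models live in pgvector.py):

```python
# Stand-in models to show the anti-join the batched cleanup builds.
# Printing the statement reveals SQL of roughly this shape:
#   SELECT documents.source_id FROM documents
#   WHERE documents.source_id IN (:p1, :p2, ...)
#     AND NOT (EXISTS (SELECT 1 FROM access_list
#                      WHERE access_list.source_id = documents.source_id))
import sqlalchemy as sa
from sqlalchemy import orm

class Base(orm.DeclarativeBase):
	pass

class DocumentsStore(Base):
	__tablename__ = 'documents'
	source_id: orm.Mapped[str] = orm.mapped_column(primary_key=True)

class AccessListStore(Base):
	__tablename__ = 'access_list'
	id: orm.Mapped[int] = orm.mapped_column(primary_key=True)
	source_id: orm.Mapped[str] = orm.mapped_column()

batched_ids = ['a', 'b']
stmt = (
	sa.select(DocumentsStore.source_id)
	.filter(DocumentsStore.source_id.in_(batched_ids))
	.filter(
		~sa.exists(
			sa.select(sa.literal(1))
			.where(AccessListStore.source_id == DocumentsStore.source_id)
		)
	)
)
print(stmt)  # inspect the compiled anti-join
```

Querying `DocumentsStore` directly also skips ids that were never stored, which the old Python-side diff would have passed to `delete_source_ids` as no-ops.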
```diff
@@ -423,35 +433,32 @@ def delete_source_ids(self, source_ids: list[str], session_: orm.Session | None
 			collection = self.client.get_collection(session)
 
 			# entry from "AccessListStore" is deleted automatically due to the foreign key constraint
-			stmt_doc = (
-				sa.delete(DocumentsStore)
-				.filter(DocumentsStore.source_id.in_(source_ids))
-				.returning(DocumentsStore.chunks)
-			)
-
-			doc_result = session.execute(stmt_doc)
-			chunks_to_delete = [str(c) for res in doc_result for c in res.chunks]
-		except Exception as e:
-			session.rollback()
-			if session_ is None:
-				session.close()
-			raise DbException('Error: deleting source ids from docs store') from e
+			# batch the deletion to avoid hitting the query parameter limit
+			chunks_to_delete = []
+			for i in range(0, len(source_ids), PG_BATCH_SIZE):
+				batched_ids = source_ids[i:i+PG_BATCH_SIZE]
+				stmt_doc = (
+					sa.delete(DocumentsStore)
+					.filter(DocumentsStore.source_id.in_(batched_ids))
+					.returning(DocumentsStore.chunks)
+				)
+				doc_result = session.execute(stmt_doc)
+				chunks_to_delete.extend(str(c) for res in doc_result for c in res.chunks)
 
-		try:
-			stmt_chunks = (
-				sa.delete(self.client.EmbeddingStore)
-				.filter(self.client.EmbeddingStore.collection_id == collection.uuid)
-				.filter(self.client.EmbeddingStore.id.in_(chunks_to_delete))
-			)
+			for i in range(0, len(chunks_to_delete), PG_BATCH_SIZE):
+				batched_chunks = chunks_to_delete[i:i+PG_BATCH_SIZE]
+				stmt_chunks = (
+					sa.delete(self.client.EmbeddingStore)
+					.filter(self.client.EmbeddingStore.collection_id == collection.uuid)
+					.filter(self.client.EmbeddingStore.id.in_(batched_chunks))
+				)
+				session.execute(stmt_chunks)
 
-			session.execute(stmt_chunks)
 			session.commit()
 		except Exception as e:
-			logger.error('Error deleting chunks, rolling back documents store deletion for source ids')
+			logger.error('Error deleting source ids, rolling back changes.')
 			session.rollback()
-			raise DbException(
-				'Error: deleting chunks, rolling back documents store deletion for source ids'
-			) from e
+			raise DbException('Error: deleting source ids, rolling back changes.') from e
 		finally:
 			if session_ is None:
 				session.close()
```
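Merging the two statements into one try block works because `DELETE ... RETURNING` hands back the `chunks` arrays of the document rows as it removes them, and nothing commits until the embedding rows are gone too, so a failure at any point rolls back the whole deletion. A sketch of the delete-returning idiom on a placeholder table:

```python
# Sketch: DELETE ... RETURNING collects child ids for a follow-up delete
# inside the same transaction. Table and column names are placeholders.
import sqlalchemy as sa

metadata = sa.MetaData()
docs = sa.Table(
	'docs_demo', metadata,
	sa.Column('source_id', sa.Text, primary_key=True),
	sa.Column('chunks', sa.ARRAY(sa.Text)),  # PostgreSQL array column
)

def delete_docs(conn: sa.Connection, source_ids: list[str]) -> list[str]:
	result = conn.execute(
		sa.delete(docs)
		.where(docs.c.source_id.in_(source_ids))
		.returning(docs.c.chunks)
	)
	# flatten the per-row chunk arrays, mirroring the diff's comprehension
	return [str(c) for row in result for c in row.chunks]
```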
```diff
@@ -469,23 +476,20 @@ def delete_provider(self, provider_key: str):
 
 			doc_result = session.execute(stmt)
 			chunks_to_delete = [str(c) for res in doc_result for c in res.chunks]
-		except Exception as e:
-			session.rollback()
-			raise DbException('Error: deleting provider from docs store') from e
 
-		try:
-			stmt = (
-				sa.delete(self.client.EmbeddingStore)
-				.filter(self.client.EmbeddingStore.collection_id == collection.uuid)
-				.filter(self.client.EmbeddingStore.id.in_(chunks_to_delete))
-			)
-			session.execute(stmt)
+			for i in range(0, len(chunks_to_delete), PG_BATCH_SIZE):
+				batched_chunks = chunks_to_delete[i:i+PG_BATCH_SIZE]
+				stmt = (
+					sa.delete(self.client.EmbeddingStore)
+					.filter(self.client.EmbeddingStore.collection_id == collection.uuid)
+					.filter(self.client.EmbeddingStore.id.in_(batched_chunks))
+				)
+				session.execute(stmt)
+
 			session.commit()
 		except Exception as e:
 			session.rollback()
-			raise DbException(
-				'Error: deleting chunks, rolling back documents store deletion for provider'
-			) from e
+			raise DbException('Error: deleting chunks, rolling back changes') from e
 
 	def delete_user(self, user_id: str):
 		with self.session_maker() as session:
```
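After this change the same three-line slicing loop appears seven times across these methods. A follow-up refactor could hoist it into a helper; on Python 3.12+ the stdlib's `itertools.batched` covers the same ground for iterables. A sketch (a suggestion, not part of this commit):

```python
# Hypothetical helper consolidating the repeated slice loops.
from collections.abc import Iterator, Sequence
from typing import TypeVar

T = TypeVar('T')

def batched(seq: Sequence[T], size: int) -> Iterator[Sequence[T]]:
	"""Yield successive slices of at most `size` items."""
	for i in range(0, len(seq), size):
		yield seq[i:i + size]

# usage in decl_update_access, for example:
#   for batched_uids in batched(user_ids, PG_BATCH_SIZE): ...
```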
