Skip to content

Commit 0533c6d

Browse files
fix dialectic held connection (plastic-labs#477)
* fix: dialectic held connection * fix: (agent) pre-compute embeddings for agent tools * fix: (tests) refactor tests to use smaller test db connections * fix: Embedding client to branch depending on vector store * fix: reflect dedup-skipped observations in created counts and isolate DB sessions in extract_preferences * fix: (tests) update tests to match changes * fix: expunge docs + don't pass in db to query_documents --------- Co-authored-by: Rajat Ahuja <rahuja445@gmail.com>
1 parent cc3483b commit 0533c6d

15 files changed

Lines changed: 659 additions & 515 deletions

File tree

src/crud/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@
99
create_observations,
1010
delete_document,
1111
delete_document_by_id,
12+
fetch_documents_by_ids,
1213
get_all_documents,
1314
get_child_observations,
1415
get_documents_by_ids,
1516
get_documents_with_filters,
1617
query_documents,
1718
query_documents_most_derived,
1819
query_documents_recent,
20+
query_external_vector_document_ids,
1921
)
2022
from .message import (
2123
create_messages,
@@ -82,13 +84,15 @@
8284
# Document
8385
"create_documents",
8486
"create_observations",
87+
"fetch_documents_by_ids",
8588
"get_all_documents",
8689
"get_child_observations",
8790
"get_documents_by_ids",
8891
"get_documents_with_filters",
8992
"query_documents",
9093
"query_documents_most_derived",
9194
"query_documents_recent",
95+
"query_external_vector_document_ids",
9296
"delete_document",
9397
"delete_document_by_id",
9498
# Message

src/crud/document.py

Lines changed: 181 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from src.crud.collection import get_or_create_collection
1616
from src.crud.peer import get_peer
1717
from src.crud.session import get_session
18+
from src.dependencies import tracked_db
1819
from src.embedding_client import embedding_client
1920
from src.exceptions import ResourceNotFoundException, ValidationException
2021
from src.utils.filter import apply_filter
@@ -190,70 +191,34 @@ async def query_documents_most_derived(
190191
return result.scalars().all()
191192

192193

193-
async def query_documents(
194-
db: AsyncSession,
194+
def _uses_pgvector() -> bool:
195+
"""Check whether queries should go through pgvector (DB-only) path."""
196+
return (
197+
settings.VECTOR_STORE.TYPE == "pgvector" or not settings.VECTOR_STORE.MIGRATED
198+
)
199+
200+
201+
async def query_external_vector_document_ids(
195202
workspace_name: str,
196-
query: str,
197-
*,
198203
observer: str,
199204
observed: str,
200-
filters: dict[str, Any] | None = None,
201-
max_distance: float | None = None,
205+
embedding: list[float],
202206
top_k: int = 5,
203-
embedding: list[float] | None = None,
204-
) -> Sequence[models.Document]:
205-
"""
206-
Query documents using semantic similarity.
207+
max_distance: float | None = None,
208+
filters: dict[str, Any] | None = None,
209+
) -> list[str] | None:
210+
"""Query external vector store for document IDs sorted by similarity.
207211
208-
Args:
209-
db: Database session
210-
workspace_name: Name of the workspace
211-
query: Search query text
212-
observer: Name of the observing peer
213-
observed: Name of the observed peer
214-
filters: Optional filters to apply at vector store level (supports: level, session_name)
215-
max_distance: Maximum cosine distance for results
216-
top_k: Number of results to return
217-
embedding: Optional pre-computed embedding for the query (avoids extra API call if possible)
212+
No DB session needed — safe to call outside a tracked_db scope.
218213
219214
Returns:
220-
Sequence of matching documents
215+
Ordered list of document IDs on the external-store path,
216+
empty list when the external store has no results,
217+
or None when the pgvector (DB-only) path should be used instead.
221218
"""
222-
# Use provided embedding or generate one
223-
if embedding is None:
224-
try:
225-
embedding = await embedding_client.embed(query)
226-
except ValueError as e:
227-
raise ValidationException(
228-
f"Query exceeds maximum token limit of {settings.MAX_EMBEDDING_TOKENS}."
229-
) from e
219+
if _uses_pgvector():
220+
return None
230221

231-
# Query Postgres directly when using pgvector OR during migration (not yet migrated)
232-
# This ensures we use pgvector as source of truth until migration is complete
233-
if settings.VECTOR_STORE.TYPE == "pgvector" or not settings.VECTOR_STORE.MIGRATED:
234-
stmt = (
235-
select(models.Document)
236-
.where(models.Document.workspace_name == workspace_name)
237-
.where(models.Document.observer == observer)
238-
.where(models.Document.observed == observed)
239-
.where(models.Document.embedding.isnot(None))
240-
.where(models.Document.deleted_at.is_(None))
241-
)
242-
243-
if max_distance is not None:
244-
stmt = stmt.where(
245-
models.Document.embedding.cosine_distance(embedding) <= max_distance
246-
)
247-
248-
stmt = apply_filter(stmt, models.Document, filters)
249-
stmt = stmt.order_by(
250-
models.Document.embedding.cosine_distance(embedding)
251-
).limit(top_k)
252-
253-
result = await db.execute(stmt)
254-
return list(result.scalars().all())
255-
256-
# FALLBACK: Use external vector store (Turbopuffer, LanceDB)
257222
external_vector_store = get_external_vector_store()
258223
if external_vector_store is None:
259224
return []
@@ -262,18 +227,12 @@ async def query_documents(
262227
"document", workspace_name, observer, observed
263228
)
264229

265-
# Build vector store filters
266-
# Convert filter dict to vector store format (handles level, session_name, etc.)
267230
vector_filters: dict[str, Any] = {}
268231
if filters:
269-
# Direct pass-through for simple equality filters
270-
# The filters dict can contain: level, session_name, or other document fields
271-
# We can push level and session_name to vector store since they're in metadata
272232
for key in ["level", "session_name"]:
273233
if key in filters:
274234
vector_filters[key] = filters[key]
275235

276-
# Query external vector store for similar documents with filters applied
277236
vector_results = await external_vector_store.query(
278237
namespace,
279238
embedding,
@@ -285,10 +244,21 @@ async def query_documents(
285244
if not vector_results:
286245
return []
287246

288-
# Get document IDs from vector results (vector ID = document ID for documents)
289-
document_ids = [result.id for result in vector_results]
247+
return [result.id for result in vector_results]
248+
249+
250+
async def fetch_documents_by_ids(
251+
db: AsyncSession,
252+
workspace_name: str,
253+
observer: str,
254+
observed: str,
255+
document_ids: list[str],
256+
filters: dict[str, Any] | None = None,
257+
) -> list[models.Document]:
258+
"""Fetch documents by IDs, preserving input order. DB-only operation."""
259+
if not document_ids:
260+
return []
290261

291-
# Fetch documents from database
292262
stmt = (
293263
select(models.Document)
294264
.where(models.Document.workspace_name == workspace_name)
@@ -297,20 +267,153 @@ async def query_documents(
297267
.where(models.Document.deleted_at.is_(None))
298268
.where(models.Document.id.in_(document_ids))
299269
)
300-
# Re-apply all filters at the database layer to catch any constraints
301-
# that aren't supported by the vector store metadata.
302270
stmt = apply_filter(stmt, models.Document, filters)
303271

304272
result = await db.execute(stmt)
305273
documents = {doc.id: doc for doc in result.scalars().all()}
306274

307-
# Return documents in order of similarity (preserving vector store order)
308-
ordered_docs: list[models.Document] = []
309-
for vr in vector_results:
310-
if vr.id in documents:
311-
ordered_docs.append(documents[vr.id])
275+
return [documents[doc_id] for doc_id in document_ids if doc_id in documents]
312276

313-
return ordered_docs
277+
278+
async def _query_documents_pgvector(
279+
db: AsyncSession,
280+
workspace_name: str,
281+
observer: str,
282+
observed: str,
283+
embedding: list[float],
284+
filters: dict[str, Any] | None,
285+
max_distance: float | None,
286+
top_k: int,
287+
) -> list[models.Document]:
288+
"""pgvector similarity search — pure DB operation."""
289+
stmt = (
290+
select(models.Document)
291+
.where(models.Document.workspace_name == workspace_name)
292+
.where(models.Document.observer == observer)
293+
.where(models.Document.observed == observed)
294+
.where(models.Document.embedding.isnot(None))
295+
.where(models.Document.deleted_at.is_(None))
296+
)
297+
298+
if max_distance is not None:
299+
stmt = stmt.where(
300+
models.Document.embedding.cosine_distance(embedding) <= max_distance
301+
)
302+
303+
stmt = apply_filter(stmt, models.Document, filters)
304+
stmt = stmt.order_by(models.Document.embedding.cosine_distance(embedding)).limit(
305+
top_k
306+
)
307+
308+
result = await db.execute(stmt)
309+
return list(result.scalars().all())
310+
311+
312+
async def query_documents(
313+
db: AsyncSession | None,
314+
workspace_name: str,
315+
query: str,
316+
*,
317+
observer: str,
318+
observed: str,
319+
filters: dict[str, Any] | None = None,
320+
max_distance: float | None = None,
321+
top_k: int = 5,
322+
embedding: list[float] | None = None,
323+
) -> Sequence[models.Document]:
324+
"""
325+
Query documents using semantic similarity.
326+
327+
When *db* is provided the caller owns the session lifetime. When *db* is
328+
``None`` the function opens (and closes) its own short-lived session so that
329+
no DB connection is held during external vector-store calls.
330+
331+
Args:
332+
db: Database session, or None to let the function manage its own
333+
workspace_name: Name of the workspace
334+
query: Search query text
335+
observer: Name of the observing peer
336+
observed: Name of the observed peer
337+
filters: Optional filters to apply at vector store level (supports: level, session_name)
338+
max_distance: Maximum cosine distance for results
339+
top_k: Number of results to return
340+
embedding: Optional pre-computed embedding for the query (avoids extra API call if possible)
341+
342+
Returns:
343+
Sequence of matching documents
344+
"""
345+
# Use provided embedding or generate one
346+
if embedding is None:
347+
try:
348+
embedding = await embedding_client.embed(query)
349+
except ValueError as e:
350+
raise ValidationException(
351+
f"Query exceeds maximum token limit of {settings.MAX_EMBEDDING_TOKENS}."
352+
) from e
353+
354+
if _uses_pgvector():
355+
# pgvector path — pure DB, open a short session if none provided
356+
if db is not None:
357+
return await _query_documents_pgvector(
358+
db,
359+
workspace_name,
360+
observer,
361+
observed,
362+
embedding,
363+
filters,
364+
max_distance,
365+
top_k,
366+
)
367+
async with tracked_db("query_documents.pgvector") as managed_db:
368+
docs = await _query_documents_pgvector(
369+
managed_db,
370+
workspace_name,
371+
observer,
372+
observed,
373+
embedding,
374+
filters,
375+
max_distance,
376+
top_k,
377+
)
378+
for doc in docs:
379+
managed_db.expunge(doc)
380+
return docs
381+
382+
# External vector store — network call first, DB only for the ID fetch
383+
document_ids = await query_external_vector_document_ids(
384+
workspace_name=workspace_name,
385+
observer=observer,
386+
observed=observed,
387+
embedding=embedding,
388+
top_k=top_k,
389+
max_distance=max_distance,
390+
filters=filters,
391+
)
392+
393+
if not document_ids:
394+
return []
395+
396+
if db is not None:
397+
return await fetch_documents_by_ids(
398+
db=db,
399+
workspace_name=workspace_name,
400+
observer=observer,
401+
observed=observed,
402+
document_ids=document_ids,
403+
filters=filters,
404+
)
405+
async with tracked_db("query_documents.fetch") as managed_db:
406+
docs = await fetch_documents_by_ids(
407+
db=managed_db,
408+
workspace_name=workspace_name,
409+
observer=observer,
410+
observed=observed,
411+
document_ids=document_ids,
412+
filters=filters,
413+
)
414+
for doc in docs:
415+
managed_db.expunge(doc)
416+
return docs
314417

315418

316419
async def create_documents(
@@ -321,7 +424,7 @@ async def create_documents(
321424
observer: str,
322425
observed: str,
323426
deduplicate: bool = False,
324-
) -> int:
427+
) -> list[schemas.DocumentCreate]:
325428
"""
326429
Create multiple documents with optional duplicate detection.
327430
@@ -333,9 +436,11 @@ async def create_documents(
333436
observed: Name of the observed peer
334437
335438
Returns:
336-
Count of new documents
439+
List of DocumentCreate schemas that were actually inserted (excludes
440+
duplicates and failures).
337441
"""
338442
honcho_documents: list[models.Document] = []
443+
accepted_documents: list[schemas.DocumentCreate] = []
339444
# Store (document_model, embedding) pairs - IDs aren't available until after commit
340445
docs_with_embeddings: list[tuple[models.Document, list[float]]] = []
341446

@@ -391,6 +496,7 @@ async def create_documents(
391496
if doc.embedding:
392497
new_doc.sync_state = "pending"
393498
honcho_documents.append(new_doc)
499+
accepted_documents.append(doc)
394500

395501
# Track embedding for vector store (ID will be available after commit)
396502
if doc.embedding:
@@ -489,7 +595,7 @@ async def create_documents(
489595
"Failed to create documents due to integrity constraint violation"
490596
) from e
491597

492-
return len(honcho_documents)
598+
return accepted_documents
493599

494600

495601
async def delete_document(

0 commit comments

Comments
 (0)