Skip to content

Commit d803c54

Browse files
authored
fix: (search) add logic to use external vectore store for message search (plastic-labs#479)
* fix: (search) add logic to use external vectore store for message search * fix: (search) oversample to reduce duplicate errors
1 parent 0533c6d commit d803c54

1 file changed

Lines changed: 151 additions & 39 deletions

File tree

src/crud/message.py

Lines changed: 151 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import Sequence
12
from datetime import datetime
23
from logging import getLogger
34
from typing import Any
@@ -18,6 +19,21 @@
1819
logger = getLogger(__name__)
1920

2021

22+
def _deduplicate_messages(
23+
messages: Sequence[models.Message], limit: int
24+
) -> list[models.Message]:
25+
"""Deduplicate messages by public_id, preserving input order."""
26+
seen: set[str] = set()
27+
result: list[models.Message] = []
28+
for msg in messages:
29+
if msg.public_id not in seen:
30+
seen.add(msg.public_id)
31+
result.append(msg)
32+
if len(result) >= limit:
33+
break
34+
return result
35+
36+
2137
def _apply_token_limit(
2238
base_conditions: list[ColumnElement[Any]], token_limit: int
2339
) -> Select[tuple[models.Message]]:
@@ -578,6 +594,78 @@ async def update_message(
578594
return honcho_message
579595

580596

597+
async def _search_messages_external(
598+
db: AsyncSession,
599+
workspace_name: str,
600+
query_embedding: list[float],
601+
limit: int,
602+
*,
603+
session_name: str | None = None,
604+
after_date: datetime | None = None,
605+
before_date: datetime | None = None,
606+
) -> list[models.Message]:
607+
"""Query the external vector store for messages and fetch them from the DB.
608+
609+
Multiple vector records can map to the same message (chunked embeddings),
610+
so we oversample from the vector store and deduplicate by message_id.
611+
612+
Date filters are applied at the DB level since external vector stores
613+
don't support temporal filtering.
614+
"""
615+
external_vector_store = get_external_vector_store()
616+
if external_vector_store is None:
617+
return []
618+
619+
namespace = external_vector_store.get_vector_namespace("message", workspace_name)
620+
621+
vector_filters: dict[str, Any] = {}
622+
if session_name:
623+
vector_filters["session_name"] = session_name
624+
625+
# Oversample: chunks can map to the same message, and date filters are
626+
# applied post-fetch (vector stores don't support temporal filtering),
627+
# so fetch extra to compensate for both deduplication and filtering.
628+
has_date_filters = after_date is not None or before_date is not None
629+
oversample = 6 if has_date_filters else 3
630+
vector_results = await external_vector_store.query(
631+
namespace,
632+
query_embedding,
633+
top_k=limit * oversample,
634+
filters=vector_filters if vector_filters else None,
635+
)
636+
637+
if not vector_results:
638+
return []
639+
640+
# Deduplicate by message_id preserving similarity order
641+
seen: dict[str, None] = {}
642+
for vr in vector_results:
643+
mid = vr.metadata.get("message_id")
644+
if mid and mid not in seen:
645+
seen[mid] = None
646+
message_ids = list(seen.keys())
647+
648+
if not message_ids:
649+
return []
650+
651+
# Fetch from DB with optional date filtering
652+
fetch_stmt = (
653+
select(models.Message)
654+
.where(models.Message.public_id.in_(message_ids))
655+
.where(models.Message.workspace_name == workspace_name)
656+
)
657+
if after_date:
658+
fetch_stmt = fetch_stmt.where(models.Message.created_at >= after_date)
659+
if before_date:
660+
fetch_stmt = fetch_stmt.where(models.Message.created_at <= before_date)
661+
662+
result = await db.execute(fetch_stmt)
663+
messages_by_id = {msg.public_id: msg for msg in result.scalars().all()}
664+
665+
# Preserve vector store similarity order, apply limit
666+
return [messages_by_id[mid] for mid in message_ids if mid in messages_by_id][:limit]
667+
668+
581669
async def search_messages(
582670
db: AsyncSession,
583671
workspace_name: str,
@@ -612,25 +700,36 @@ async def search_messages(
612700
embedding if embedding is not None else await embedding_client.embed(query)
613701
)
614702

615-
# First, find the top matching messages
616-
match_stmt = (
617-
select(models.Message)
618-
.join(
619-
models.MessageEmbedding,
620-
models.Message.public_id == models.MessageEmbedding.message_id,
703+
if settings.VECTOR_STORE.TYPE == "pgvector" or not settings.VECTOR_STORE.MIGRATED:
704+
# pgvector path: cosine distance in SQL
705+
# Oversample because a message with multiple embedding chunks can
706+
# produce duplicate rows; we deduplicate in Python to preserve HNSW
707+
# index usage (a DISTINCT ON subquery would prevent the index scan).
708+
match_stmt = (
709+
select(models.Message)
710+
.join(
711+
models.MessageEmbedding,
712+
models.Message.public_id == models.MessageEmbedding.message_id,
713+
)
714+
.where(models.MessageEmbedding.workspace_name == workspace_name)
715+
.order_by(
716+
models.MessageEmbedding.embedding.cosine_distance(query_embedding)
717+
)
718+
.limit(limit * 2)
621719
)
622-
.where(models.MessageEmbedding.workspace_name == workspace_name)
623-
.order_by(models.MessageEmbedding.embedding.cosine_distance(query_embedding))
624-
.limit(limit)
625-
)
626720

627-
if session_name:
628-
match_stmt = match_stmt.where(
629-
models.MessageEmbedding.session_name == session_name
630-
)
721+
if session_name:
722+
match_stmt = match_stmt.where(
723+
models.MessageEmbedding.session_name == session_name
724+
)
631725

632-
result = await db.execute(match_stmt)
633-
matched_messages = list(result.scalars().all())
726+
result = await db.execute(match_stmt)
727+
matched_messages = _deduplicate_messages(result.scalars().all(), limit)
728+
else:
729+
# External vector store path
730+
matched_messages = await _search_messages_external(
731+
db, workspace_name, query_embedding, limit, session_name=session_name
732+
)
634733

635734
return await _build_merged_snippets(
636735
db, workspace_name, matched_messages, context_window
@@ -767,34 +866,47 @@ async def search_messages_temporal(
767866
embedding if embedding is not None else await embedding_client.embed(query)
768867
)
769868

770-
# Build query with date filters
771-
match_stmt = (
772-
select(models.Message)
773-
.join(
774-
models.MessageEmbedding,
775-
models.Message.public_id == models.MessageEmbedding.message_id,
869+
if settings.VECTOR_STORE.TYPE == "pgvector" or not settings.VECTOR_STORE.MIGRATED:
870+
# pgvector path: cosine distance in SQL with date filters
871+
# Oversample to handle chunk duplicates (see search_messages comment)
872+
match_stmt = (
873+
select(models.Message)
874+
.join(
875+
models.MessageEmbedding,
876+
models.Message.public_id == models.MessageEmbedding.message_id,
877+
)
878+
.where(models.MessageEmbedding.workspace_name == workspace_name)
776879
)
777-
.where(models.MessageEmbedding.workspace_name == workspace_name)
778-
)
779880

780-
if session_name:
781-
match_stmt = match_stmt.where(
782-
models.MessageEmbedding.session_name == session_name
783-
)
881+
if session_name:
882+
match_stmt = match_stmt.where(
883+
models.MessageEmbedding.session_name == session_name
884+
)
784885

785-
# Apply date filters on the Message table
786-
if after_date:
787-
match_stmt = match_stmt.where(models.Message.created_at >= after_date)
788-
if before_date:
789-
match_stmt = match_stmt.where(models.Message.created_at <= before_date)
886+
# Apply date filters on the Message table
887+
if after_date:
888+
match_stmt = match_stmt.where(models.Message.created_at >= after_date)
889+
if before_date:
890+
match_stmt = match_stmt.where(models.Message.created_at <= before_date)
790891

791-
# Order by similarity and limit
792-
match_stmt = match_stmt.order_by(
793-
models.MessageEmbedding.embedding.cosine_distance(query_embedding)
794-
).limit(limit)
892+
# Order by similarity and limit
893+
match_stmt = match_stmt.order_by(
894+
models.MessageEmbedding.embedding.cosine_distance(query_embedding)
895+
).limit(limit * 2)
795896

796-
result = await db.execute(match_stmt)
797-
matched_messages = list(result.scalars().all())
897+
result = await db.execute(match_stmt)
898+
matched_messages = _deduplicate_messages(result.scalars().all(), limit)
899+
else:
900+
# External vector store path with post-fetch date filtering
901+
matched_messages = await _search_messages_external(
902+
db,
903+
workspace_name,
904+
query_embedding,
905+
limit,
906+
session_name=session_name,
907+
after_date=after_date,
908+
before_date=before_date,
909+
)
798910

799911
return await _build_merged_snippets(
800912
db, workspace_name, matched_messages, context_window

0 commit comments

Comments
 (0)