|
| 1 | +from collections.abc import Sequence |
1 | 2 | from datetime import datetime |
2 | 3 | from logging import getLogger |
3 | 4 | from typing import Any |
|
18 | 19 | logger = getLogger(__name__) |
19 | 20 |
|
20 | 21 |
|
| 22 | +def _deduplicate_messages( |
| 23 | + messages: Sequence[models.Message], limit: int |
| 24 | +) -> list[models.Message]: |
| 25 | + """Deduplicate messages by public_id, preserving input order.""" |
| 26 | + seen: set[str] = set() |
| 27 | + result: list[models.Message] = [] |
| 28 | + for msg in messages: |
| 29 | + if msg.public_id not in seen: |
| 30 | + seen.add(msg.public_id) |
| 31 | + result.append(msg) |
| 32 | + if len(result) >= limit: |
| 33 | + break |
| 34 | + return result |
| 35 | + |
| 36 | + |
21 | 37 | def _apply_token_limit( |
22 | 38 | base_conditions: list[ColumnElement[Any]], token_limit: int |
23 | 39 | ) -> Select[tuple[models.Message]]: |
@@ -578,6 +594,78 @@ async def update_message( |
578 | 594 | return honcho_message |
579 | 595 |
|
580 | 596 |
|
async def _search_messages_external(
    db: AsyncSession,
    workspace_name: str,
    query_embedding: list[float],
    limit: int,
    *,
    session_name: str | None = None,
    after_date: datetime | None = None,
    before_date: datetime | None = None,
) -> list[models.Message]:
    """Search the external vector store, then load the matching messages.

    Several vector records may carry the same ``message_id`` (one per
    embedding chunk), so more hits than ``limit`` are requested and then
    collapsed to unique ids in the order the store returned them.

    The date bounds are not sent to the vector store; they are applied as
    ``created_at`` conditions on the database query that loads the messages.

    Args:
        db: Session used to load the message rows.
        workspace_name: Workspace whose namespace is queried and whose
            messages are loaded.
        query_embedding: Embedding vector to search with.
        limit: Maximum number of messages to return.
        session_name: When set, passed to the vector store as a
            ``session_name`` filter.
        after_date: When set, only messages created at or after this time.
        before_date: When set, only messages created at or before this time.

    Returns:
        Up to ``limit`` messages in the vector store's result order, or an
        empty list when no store is configured or nothing matches.
    """
    store = get_external_vector_store()
    if store is None:
        return []

    namespace = store.get_vector_namespace("message", workspace_name)
    store_filters: dict[str, Any] | None = (
        {"session_name": session_name} if session_name else None
    )

    # Ask for extra hits: chunk duplicates shrink the set, and date bounds
    # (applied later in the DB query) can shrink it further.
    if after_date is None and before_date is None:
        multiplier = 3
    else:
        multiplier = 6

    hits = await store.query(
        namespace,
        query_embedding,
        top_k=limit * multiplier,
        filters=store_filters,
    )
    if not hits:
        return []

    # dict.fromkeys keeps the first occurrence of each id in hit order.
    hit_ids = (hit.metadata.get("message_id") for hit in hits)
    ordered_ids = list(dict.fromkeys(mid for mid in hit_ids if mid))
    if not ordered_ids:
        return []

    # Build the row lookup, adding the optional date bounds.
    conditions = [
        models.Message.public_id.in_(ordered_ids),
        models.Message.workspace_name == workspace_name,
    ]
    if after_date:
        conditions.append(models.Message.created_at >= after_date)
    if before_date:
        conditions.append(models.Message.created_at <= before_date)

    rows = (
        await db.execute(select(models.Message).where(*conditions))
    ).scalars().all()
    by_public_id = {row.public_id: row for row in rows}

    # Re-rank the loaded rows by vector store order, then cut to the limit.
    ranked = [by_public_id[mid] for mid in ordered_ids if mid in by_public_id]
    return ranked[:limit]
| 667 | + |
| 668 | + |
581 | 669 | async def search_messages( |
582 | 670 | db: AsyncSession, |
583 | 671 | workspace_name: str, |
@@ -612,25 +700,36 @@ async def search_messages( |
612 | 700 | embedding if embedding is not None else await embedding_client.embed(query) |
613 | 701 | ) |
614 | 702 |
|
615 | | - # First, find the top matching messages |
616 | | - match_stmt = ( |
617 | | - select(models.Message) |
618 | | - .join( |
619 | | - models.MessageEmbedding, |
620 | | - models.Message.public_id == models.MessageEmbedding.message_id, |
| 703 | + if settings.VECTOR_STORE.TYPE == "pgvector" or not settings.VECTOR_STORE.MIGRATED: |
| 704 | + # pgvector path: cosine distance in SQL |
| 705 | + # Oversample because a message with multiple embedding chunks can |
| 706 | + # produce duplicate rows; we deduplicate in Python to preserve HNSW |
| 707 | + # index usage (a DISTINCT ON subquery would prevent the index scan). |
| 708 | + match_stmt = ( |
| 709 | + select(models.Message) |
| 710 | + .join( |
| 711 | + models.MessageEmbedding, |
| 712 | + models.Message.public_id == models.MessageEmbedding.message_id, |
| 713 | + ) |
| 714 | + .where(models.MessageEmbedding.workspace_name == workspace_name) |
| 715 | + .order_by( |
| 716 | + models.MessageEmbedding.embedding.cosine_distance(query_embedding) |
| 717 | + ) |
| 718 | + .limit(limit * 2) |
621 | 719 | ) |
622 | | - .where(models.MessageEmbedding.workspace_name == workspace_name) |
623 | | - .order_by(models.MessageEmbedding.embedding.cosine_distance(query_embedding)) |
624 | | - .limit(limit) |
625 | | - ) |
626 | 720 |
|
627 | | - if session_name: |
628 | | - match_stmt = match_stmt.where( |
629 | | - models.MessageEmbedding.session_name == session_name |
630 | | - ) |
| 721 | + if session_name: |
| 722 | + match_stmt = match_stmt.where( |
| 723 | + models.MessageEmbedding.session_name == session_name |
| 724 | + ) |
631 | 725 |
|
632 | | - result = await db.execute(match_stmt) |
633 | | - matched_messages = list(result.scalars().all()) |
| 726 | + result = await db.execute(match_stmt) |
| 727 | + matched_messages = _deduplicate_messages(result.scalars().all(), limit) |
| 728 | + else: |
| 729 | + # External vector store path |
| 730 | + matched_messages = await _search_messages_external( |
| 731 | + db, workspace_name, query_embedding, limit, session_name=session_name |
| 732 | + ) |
634 | 733 |
|
635 | 734 | return await _build_merged_snippets( |
636 | 735 | db, workspace_name, matched_messages, context_window |
@@ -767,34 +866,47 @@ async def search_messages_temporal( |
767 | 866 | embedding if embedding is not None else await embedding_client.embed(query) |
768 | 867 | ) |
769 | 868 |
|
770 | | - # Build query with date filters |
771 | | - match_stmt = ( |
772 | | - select(models.Message) |
773 | | - .join( |
774 | | - models.MessageEmbedding, |
775 | | - models.Message.public_id == models.MessageEmbedding.message_id, |
| 869 | + if settings.VECTOR_STORE.TYPE == "pgvector" or not settings.VECTOR_STORE.MIGRATED: |
| 870 | + # pgvector path: cosine distance in SQL with date filters |
| 871 | + # Oversample to handle chunk duplicates (see search_messages comment) |
| 872 | + match_stmt = ( |
| 873 | + select(models.Message) |
| 874 | + .join( |
| 875 | + models.MessageEmbedding, |
| 876 | + models.Message.public_id == models.MessageEmbedding.message_id, |
| 877 | + ) |
| 878 | + .where(models.MessageEmbedding.workspace_name == workspace_name) |
776 | 879 | ) |
777 | | - .where(models.MessageEmbedding.workspace_name == workspace_name) |
778 | | - ) |
779 | 880 |
|
780 | | - if session_name: |
781 | | - match_stmt = match_stmt.where( |
782 | | - models.MessageEmbedding.session_name == session_name |
783 | | - ) |
| 881 | + if session_name: |
| 882 | + match_stmt = match_stmt.where( |
| 883 | + models.MessageEmbedding.session_name == session_name |
| 884 | + ) |
784 | 885 |
|
785 | | - # Apply date filters on the Message table |
786 | | - if after_date: |
787 | | - match_stmt = match_stmt.where(models.Message.created_at >= after_date) |
788 | | - if before_date: |
789 | | - match_stmt = match_stmt.where(models.Message.created_at <= before_date) |
| 886 | + # Apply date filters on the Message table |
| 887 | + if after_date: |
| 888 | + match_stmt = match_stmt.where(models.Message.created_at >= after_date) |
| 889 | + if before_date: |
| 890 | + match_stmt = match_stmt.where(models.Message.created_at <= before_date) |
790 | 891 |
|
791 | | - # Order by similarity and limit |
792 | | - match_stmt = match_stmt.order_by( |
793 | | - models.MessageEmbedding.embedding.cosine_distance(query_embedding) |
794 | | - ).limit(limit) |
| 892 | + # Order by similarity and limit |
| 893 | + match_stmt = match_stmt.order_by( |
| 894 | + models.MessageEmbedding.embedding.cosine_distance(query_embedding) |
| 895 | + ).limit(limit * 2) |
795 | 896 |
|
796 | | - result = await db.execute(match_stmt) |
797 | | - matched_messages = list(result.scalars().all()) |
| 897 | + result = await db.execute(match_stmt) |
| 898 | + matched_messages = _deduplicate_messages(result.scalars().all(), limit) |
| 899 | + else: |
| 900 | + # External vector store path with post-fetch date filtering |
| 901 | + matched_messages = await _search_messages_external( |
| 902 | + db, |
| 903 | + workspace_name, |
| 904 | + query_embedding, |
| 905 | + limit, |
| 906 | + session_name=session_name, |
| 907 | + after_date=after_date, |
| 908 | + before_date=before_date, |
| 909 | + ) |
798 | 910 |
|
799 | 911 | return await _build_merged_snippets( |
800 | 912 | db, workspace_name, matched_messages, context_window |
|
0 commit comments