Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions config/graphql/discover_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
DISCOVER_DEFAULT_LIMIT,
DISCOVER_OVERSAMPLE,
DISCOVER_QUERY_VECTOR_CACHE_SIZE,
DISCOVER_TEXT_SEARCH_MAX_LENGTH,
FTS_CONFIG,
RRF_K,
)
Expand Down Expand Up @@ -115,6 +116,18 @@ def _default_embedder_path() -> Optional[str]:
return get_default_embedder_path()


def _normalise_text_search(text_search: Optional[str]) -> Optional[str]:
"""Strip and validate a Discover search string before any search arm runs."""
text = (text_search or "").strip()
if not text or len(text) > DISCOVER_TEXT_SEARCH_MAX_LENGTH:
return None
return text


class _UncacheableQueryVector(Exception):
"""Raised inside the LRU wrapper so failed embeddings are not cached."""


def _query_vector(query_text: str, embedder_path: Optional[str]) -> Optional[list]:
"""Embed ``query_text`` with the default embedder, or ``None`` on failure.

Expand Down Expand Up @@ -144,12 +157,15 @@ def _cached_query_vector(query_text: str, embedder_path: str) -> Optional[list]:

Caveats (acceptable for a best-effort arm): there is no TTL, so a vector
lives until LRU-evicted — fine, because the same inputs always produce the
same vector. A transient embedder failure (``None``) is also cached for the
LRU window; the consequence is text-only results for that exact query until
eviction, never a wrong result, and the text arm always returns on its own.
same vector. Failed embeddings are deliberately not cached: callers catch
``_UncacheableQueryVector`` and fall back to text-only results so transient
failures do not pin attacker-controlled query strings in worker memory.
Tests reset the cache in ``setUp`` (``_cached_query_vector.cache_clear()``).
"""
return _query_vector(query_text, embedder_path)
vector = _query_vector(query_text, embedder_path)
if not vector:
raise _UncacheableQueryVector
return vector


def _text_ids(
Expand Down Expand Up @@ -200,8 +216,9 @@ def _semantic_ids(
# No embedder configured → semantic arm is a no-op. Guard here (rather
# than relying on the cache) so we never seed the LRU with a null key.
return []
vector = _cached_query_vector(query_text, embedder_path)
if not vector:
try:
vector = _cached_query_vector(query_text, embedder_path)
except _UncacheableQueryVector:
return []
try:
results = visible_qs.search_by_embedding( # type: ignore[attr-defined]
Expand Down Expand Up @@ -287,7 +304,7 @@ class DiscoverSearchQueryMixin:
def resolve_discover_annotations(
self, info, text_search, limit=DISCOVER_DEFAULT_LIMIT
) -> Any:
text = (text_search or "").strip()
text = _normalise_text_search(text_search)
if not text:
return []
limit = _clamp_limit(limit)
Expand Down Expand Up @@ -323,7 +340,7 @@ def resolve_discover_annotations(
def resolve_discover_documents(
self, info, text_search, limit=DISCOVER_DEFAULT_LIMIT
) -> Any:
text = (text_search or "").strip()
text = _normalise_text_search(text_search)
if not text:
return []
limit = _clamp_limit(limit)
Expand All @@ -347,7 +364,7 @@ def resolve_discover_documents(
def resolve_discover_notes(
self, info, text_search, limit=DISCOVER_DEFAULT_LIMIT
) -> Any:
text = (text_search or "").strip()
text = _normalise_text_search(text_search)
if not text:
return []
limit = _clamp_limit(limit)
Expand Down Expand Up @@ -379,7 +396,7 @@ def resolve_discover_notes(
def resolve_discover_corpuses(
self, info, text_search, limit=DISCOVER_DEFAULT_LIMIT
) -> Any:
text = (text_search or "").strip()
text = _normalise_text_search(text_search)
if not text:
return []
limit = _clamp_limit(limit)
Expand Down Expand Up @@ -462,7 +479,7 @@ def resolve_discover_corpuses(
def resolve_discover_discussions(
self, info, text_search, limit=DISCOVER_DEFAULT_LIMIT
) -> Any:
text = (text_search or "").strip()
text = _normalise_text_search(text_search)
if not text:
return []
limit = _clamp_limit(limit)
Expand Down
7 changes: 7 additions & 0 deletions opencontractserver/constants/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@
# small module-level cache lets those requests share one embedding call.
DISCOVER_QUERY_VECTOR_CACHE_SIZE = 32

# Maximum accepted Discover text search length. Keep this intentionally small
# because the same user-controlled string is used for database text predicates
# and, when semantic search is enabled, as input to the per-process query-vector
# cache. Longer GraphQL values are ignored before reaching either arm so workers
# never retain attacker-sized strings in cache keys.
DISCOVER_TEXT_SEARCH_MAX_LENGTH = 512

# Extra oversample applied to the corpus "content match" pre-filters
# (documents/annotations whose text matches), on top of ``fetch_k``. A corpus
# is reached transitively through many matching documents/annotations, so this
Expand Down
26 changes: 26 additions & 0 deletions opencontractserver/tests/test_discover_search_graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,18 @@ def test_discover_discussions_matches_message_body_not_just_title(self):
# CHAT-type conversation is excluded even though its title matches.
self.assertNotIn("indemnification chat", titles)

def test_overlong_query_returns_empty_before_searching(self):
from opencontractserver.constants.search import DISCOVER_TEXT_SEARCH_MAX_LENGTH

with patch("config.graphql.discover_queries._query_vector") as query_vector:
result = self.graphene_client.execute(
"query D($t: String!){ discoverAnnotations(textSearch:$t){ id } }",
variables={"t": "x" * (DISCOVER_TEXT_SEARCH_MAX_LENGTH + 1)},
)
self.assertIsNone(result.get("errors"), result.get("errors"))
self.assertEqual(result["data"]["discoverAnnotations"], [])
query_vector.assert_not_called()

def test_empty_query_returns_empty(self):
for field in (
"discoverAnnotations",
Expand Down Expand Up @@ -403,6 +415,20 @@ def test_clamp_limit_caps_at_semantic_max(self):
SEMANTIC_SEARCH_MAX_RESULTS,
)

def test_failed_query_vectors_are_not_cached(self):
from config.graphql.discover_queries import (
_cached_query_vector,
_UncacheableQueryVector,
)

_cached_query_vector.cache_clear()
self.addCleanup(_cached_query_vector.cache_clear)
with patch("config.graphql.discover_queries._query_vector", return_value=None):
with self.assertRaises(_UncacheableQueryVector):
_cached_query_vector("uncacheable failure", "embedder")

self.assertEqual(_cached_query_vector.cache_info().currsize, 0)

def test_rrf_tie_break_is_deterministic_and_type_agnostic(self):
from config.graphql.discover_queries import _rrf

Expand Down