Skip to content
This repository was archived by the owner on Jun 3, 2026. It is now read-only.

Commit 7c6d8d9

Browse files
Harden raw search score handling
1 parent bff1ba8 commit 7c6d8d9

4 files changed

Lines changed: 75 additions & 17 deletions

File tree

src/api/routes/memory.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
import asyncio
1010
import logging
11+
import math
12+
import threading
1113
import time
1214
from typing import Any, Dict, List
1315

@@ -113,6 +115,14 @@ def _error(request: Request, detail: str, code: int, elapsed_ms: float = 0) -> J
113115
return JSONResponse(content=body.model_dump(), status_code=code)
114116

115117

118+
def _safe_score(score: Any) -> float:
119+
try:
120+
value = float(score)
121+
except (TypeError, ValueError):
122+
return 0.0
123+
return value if math.isfinite(value) else 0.0
124+
125+
116126
def _detect_chat_provider(*urls: str) -> str:
117127
for url in urls:
118128
lowered = (url or "").lower()
@@ -150,8 +160,6 @@ async def _render_chat_share(url: str) -> tuple[str, str]:
150160
# reuse it across scrape requests. The browser is thread-safe when each
151161
# request uses its own BrowserContext.
152162

153-
import threading
154-
155163
_browser_lock = threading.Lock()
156164
_pw_instance = None
157165
_browser_instance = None
@@ -665,7 +673,7 @@ async def retrieve_memory(req: RetrieveRequest, request: Request, user: dict = D
665673
sources=[
666674
SourceRecord(
667675
domain=s.domain, content=s.content,
668-
score=round(s.score, 3), metadata=s.metadata,
676+
score=round(_safe_score(s.score), 3), metadata=s.metadata,
669677
)
670678
for s in result.sources
671679
],
@@ -717,7 +725,7 @@ async def search_memory(req: SearchRequest, request: Request, user: dict = Depen
717725
SourceRecord(
718726
domain=s.domain,
719727
content=s.content,
720-
score=round(s.score, 3),
728+
score=round(_safe_score(s.score), 3),
721729
metadata=s.metadata,
722730
)
723731
for s in all_results
@@ -741,7 +749,7 @@ def _search_profile(pipeline: RetrievalPipeline, user_id: str) -> List[SourceRec
741749
raw = pipeline.vector_store.search_by_metadata(
742750
filters={"user_id": user_id, "domain": "profile"}, top_k=100,
743751
)
744-
return [SourceRecord(domain="profile", content=r.content, score=r.score, metadata=r.metadata) for r in raw]
752+
return [SourceRecord(domain="profile", content=r.content, score=_safe_score(r.score), metadata=r.metadata) for r in raw]
745753
except Exception as exc:
746754
logger.warning("Profile search error: %s", exc)
747755
return []
@@ -768,7 +776,7 @@ def _search_temporal(pipeline: RetrievalPipeline, query: str, user_id: str, top_
768776
parts.append(f"Time: {ev['time']}")
769777
results.append(SourceRecord(
770778
domain="temporal", content=" | ".join(parts),
771-
score=ev.get("similarity_score", 0.0), metadata=ev,
779+
score=_safe_score(ev.get("similarity_score", 0.0)), metadata=ev,
772780
))
773781
return results
774782
except Exception as exc:
@@ -783,7 +791,7 @@ async def _search_summary(pipeline: RetrievalPipeline, query: str, user_id: str,
783791
filters={"user_id": user_id, "domain": "summary"},
784792
)
785793
return [
786-
SourceRecord(domain="summary", content=r.content, score=r.score, metadata={"id": r.id, **r.metadata})
794+
SourceRecord(domain="summary", content=r.content, score=_safe_score(r.score), metadata={"id": r.id, **r.metadata})
787795
for r in raw
788796
]
789797
except Exception as exc:

src/pipelines/retrieval.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import asyncio
2424
import hashlib
2525
import logging
26+
import math
2627
import time
2728
from collections import OrderedDict
2829
from typing import Any, Callable, Dict, List, Optional
@@ -323,10 +324,15 @@ async def search_raw(
323324
if not tasks:
324325
return []
325326

326-
task_results = await asyncio.gather(*tasks)
327+
task_results = await asyncio.gather(*tasks, return_exceptions=True)
327328
results = [
328-
record for domain_results in task_results for record in domain_results
329+
record
330+
for domain_results in task_results
331+
if not self._log_search_error(domain_results)
332+
for record in domain_results
329333
]
334+
for record in results:
335+
record.score = self._score_value(record.score)
330336

331337
return sorted(results, key=lambda record: record.score, reverse=True)
332338

@@ -426,7 +432,7 @@ def _search_profile(
426432
SourceRecord(
427433
domain="profile",
428434
content=r.content,
429-
score=r.score,
435+
score=self._score_value(r.score),
430436
metadata={
431437
"id": r.id,
432438
"topic": topic,
@@ -465,7 +471,7 @@ async def _search_profile_raw(
465471
SourceRecord(
466472
domain="profile",
467473
content=r.content,
468-
score=r.score,
474+
score=self._score_value(r.score),
469475
metadata={"id": r.id, **r.metadata},
470476
)
471477
)
@@ -522,7 +528,7 @@ async def _search_temporal(
522528
SourceRecord(
523529
domain="temporal",
524530
content=content,
525-
score=ev.get("similarity_score", 0.0),
531+
score=self._score_value(ev.get("similarity_score", 0.0)),
526532
metadata=ev,
527533
)
528534
)
@@ -555,7 +561,7 @@ async def _search_summary(
555561
SourceRecord(
556562
domain="summary",
557563
content=r.content,
558-
score=r.score,
564+
score=self._score_value(r.score),
559565
metadata={"id": r.id, **r.metadata},
560566
)
561567
)
@@ -602,7 +608,7 @@ async def _search_code(
602608
SourceRecord(
603609
domain="code",
604610
content=f"{prefix}{r.content}",
605-
score=r.score,
611+
score=self._score_value(r.score),
606612
metadata={"id": r.id, **metadata},
607613
)
608614
)
@@ -650,7 +656,7 @@ async def _search_snippet(
650656
SourceRecord(
651657
domain="snippet",
652658
content=content,
653-
score=r.score,
659+
score=self._score_value(r.score),
654660
metadata={"id": r.id, **r.metadata},
655661
)
656662
)
@@ -765,6 +771,19 @@ def _trim_cache(self, cache: OrderedDict, limit: int) -> None:
765771
while len(cache) > limit:
766772
cache.popitem(last=False)
767773

774+
def _log_search_error(self, domain_results: Any) -> bool:
775+
if isinstance(domain_results, Exception):
776+
logger.warning("Raw search domain failed: %s", domain_results)
777+
return True
778+
return False
779+
780+
def _score_value(self, score: Any) -> float:
781+
try:
782+
value = float(score)
783+
except (TypeError, ValueError):
784+
return 0.0
785+
return value if math.isfinite(value) else 0.0
786+
768787
def _coerce_answer(self, answer: Any) -> str:
769788
if isinstance(answer, list):
770789
parts = []
@@ -805,7 +824,8 @@ def _format_tool_results(self, records: List[SourceRecord]) -> str:
805824

806825
lines = []
807826
for i, rec in enumerate(records, 1):
808-
score_str = f" (score: {rec.score:.2f})" if rec.score > 0 else ""
827+
score = self._score_value(rec.score)
828+
score_str = f" (score: {score:.2f})" if score > 0 else ""
809829
lines.append(f"{i}. [{rec.domain}]{score_str} {rec.content}")
810830
return "\n".join(lines)
811831

tests/api/test_memory_search_routes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ async def search_raw(
4040
"profile": SourceRecord(
4141
domain="profile",
4242
content="work / company = XMem",
43-
score=0.7,
43+
score=None,
4444
),
4545
"code": SourceRecord(
4646
domain="code",
@@ -102,6 +102,7 @@ def test_memory_search_route_returns_raw_hits_without_answer(memory_search_app):
102102

103103
assert response.status_code == 200
104104
assert payload["data"]["total"] == 2
105+
assert payload["data"]["results"][0]["score"] == 0.0
105106
assert payload["data"]["answer"] == ""
106107
assert payload["data"]["latency"]["raw"]["count"] == 1
107108
assert pipeline.search_calls[0]["domains"] == ["profile", "summary"]

tests/integration/test_retrieval_pipeline.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,35 @@ async def fake_domain(name: str, score: float):
204204
assert [record.domain for record in results] == ["summary", "temporal", "profile"]
205205

206206

207+
@pytest.mark.asyncio
208+
async def test_raw_search_skips_failed_domains_and_normalizes_scores(
209+
vector_store, neo4j_client
210+
):
211+
model = FakeChatModel()
212+
pipeline = RetrievalPipeline(
213+
model=model, vector_store=vector_store, neo4j_client=neo4j_client
214+
)
215+
216+
async def profile_domain(*_args):
217+
return [SourceRecord(domain="profile", content="No backend score", score=None)]
218+
219+
async def summary_domain(*_args):
220+
raise RuntimeError("summary backend offline")
221+
222+
pipeline._search_profile_raw = profile_domain
223+
pipeline._search_summary = summary_domain
224+
225+
results = await pipeline.search_raw(
226+
"latency",
227+
"alice",
228+
["profile", "summary"],
229+
top_k=5,
230+
)
231+
232+
assert [(record.domain, record.score) for record in results] == [("profile", 0.0)]
233+
assert pipeline._format_tool_results(results) == "1. [profile] No backend score"
234+
235+
207236
@pytest.mark.asyncio
208237
async def test_profile_catalog_fetch_does_not_block_event_loop(
209238
vector_store, neo4j_client

0 commit comments

Comments
 (0)