Skip to content
This repository was archived by the owner on Jun 3, 2026. It is now read-only.

Commit 44c885e

Browse files
committed
Address raw search review feedback
1 parent 2863756 commit 44c885e

2 files changed

Lines changed: 59 additions & 21 deletions

File tree

src/api/routes/memory.py

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353

5454
_ingest_semaphore = asyncio.Semaphore(5)
5555
_latency_samples: dict[str, deque[float]] = defaultdict(lambda: deque(maxlen=200))
56+
_latency_lock = threading.Lock()
5657

5758
router = APIRouter(
5859
prefix="/v1/memory",
@@ -112,7 +113,8 @@ def _error(request: Request, detail: str, code: int, elapsed_ms: float = 0) -> J
112113

113114

114115
def _record_latency(mode: str, elapsed_ms: float) -> None:
    """Append one latency sample (in milliseconds) to the bucket for ``mode``.

    The bucket lookup happens inside the lock as well, because indexing the
    defaultdict may insert a new deque and therefore mutates shared state.
    """
    with _latency_lock:
        bucket = _latency_samples[mode]
        bucket.append(elapsed_ms)
116118

117119

118120
def _percentile(sorted_values: List[float], percentile: float) -> float:
@@ -123,8 +125,11 @@ def _percentile(sorted_values: List[float], percentile: float) -> float:
123125

124126

125127
def _latency_stats() -> Dict[str, Dict[str, float]]:
128+
with _latency_lock:
129+
snapshot = {mode: list(samples) for mode, samples in _latency_samples.items()}
130+
126131
stats: Dict[str, Dict[str, float]] = {}
127-
for mode, samples in _latency_samples.items():
132+
for mode, samples in snapshot.items():
128133
values = sorted(samples)
129134
stats[mode] = {
130135
"count": len(values),
@@ -135,11 +140,14 @@ def _latency_stats() -> Dict[str, Dict[str, float]]:
135140
return stats
136141

137142

138-
async def _timed(mode: str, func, *args, threaded: bool = False, **kwargs):
    """Invoke ``func`` and record its wall-clock duration under ``mode``.

    Args:
        mode: Latency bucket name forwarded to ``_record_latency``.
        func: Callable to run; may be synchronous or return an awaitable.
        *args: Positional arguments forwarded to ``func``.
        threaded: When true, ``func`` is executed through
            ``asyncio.to_thread`` so a blocking call does not stall the
            event loop. Keyword-only, so it never collides with ``*args``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        A ``(result, elapsed_ms)`` tuple; ``elapsed_ms`` is rounded to two
        decimal places.

    Note:
        If ``func`` raises, the exception propagates and no sample is
        recorded.
    """
    start = time.perf_counter()
    if threaded:
        result = await asyncio.to_thread(func, *args, **kwargs)
    else:
        result = func(*args, **kwargs)
    # Checked on both paths: an async callable passed with threaded=True
    # would otherwise hand back an un-awaited coroutine as the "result".
    if hasattr(result, "__await__"):
        result = await result
    elapsed_ms = round((time.perf_counter() - start) * 1000, 2)
    _record_latency(mode, elapsed_ms)
    return result, elapsed_ms
@@ -727,27 +735,39 @@ async def search_memory(req: SearchRequest, request: Request, user: dict = Depen
727735
all_results: List[SourceRecord] = []
728736
latency_ms: Dict[str, float] = {}
729737
plan = pipeline.raw_retrieval_plan(req.domains, answer=req.answer)
738+
raw_tasks = []
730739

731740
if "profile" in plan:
732-
results, elapsed = await _timed("profile", _search_profile, pipeline, user_id)
733-
latency_ms["profile"] = elapsed
734-
all_results.extend(results)
741+
raw_tasks.append((
742+
"profile",
743+
_timed("profile", _search_profile, pipeline, user_id, threaded=True),
744+
))
735745
if "temporal" in plan:
736-
results, elapsed = await _timed("temporal", _search_temporal, pipeline, req.query, user_id, req.top_k)
737-
latency_ms["temporal"] = elapsed
738-
all_results.extend(results)
746+
raw_tasks.append((
747+
"temporal",
748+
_timed("temporal", _search_temporal, pipeline, req.query, user_id, req.top_k, threaded=True),
749+
))
739750
if "summary" in plan:
740-
results, elapsed = await _timed("summary", _search_summary, pipeline, req.query, user_id, req.top_k)
741-
latency_ms["summary"] = elapsed
742-
all_results.extend(results)
751+
raw_tasks.append((
752+
"summary",
753+
_timed("summary", _search_summary, pipeline, req.query, user_id, req.top_k),
754+
))
743755
if "snippet" in plan:
744-
results, elapsed = await _timed("snippet", _search_snippet, pipeline, req.query, user_id, req.top_k)
745-
latency_ms["snippet"] = elapsed
746-
all_results.extend(results)
756+
raw_tasks.append((
757+
"snippet",
758+
_timed("snippet", _search_snippet, pipeline, req.query, user_id, req.top_k),
759+
))
747760
if "code" in plan:
748-
results, elapsed = await _timed("code", _search_code, pipeline, req.query, user_id, req.top_k)
749-
latency_ms["code"] = elapsed
750-
all_results.extend(results)
761+
raw_tasks.append((
762+
"code",
763+
_timed("code", _search_code, pipeline, req.query, user_id, req.top_k),
764+
))
765+
766+
if raw_tasks:
767+
raw_results = await asyncio.gather(*(task for _, task in raw_tasks))
768+
for (domain, _), (results, elapsed) in zip(raw_tasks, raw_results):
769+
latency_ms[domain] = elapsed
770+
all_results.extend(results)
751771

752772
all_results.sort(key=lambda record: record.score, reverse=True)
753773

src/pipelines/retrieval.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def __init__(
137137
self._profile_catalog_cache: Dict[str, tuple[float, List[Dict[str, str]], list]] = {}
138138
self._raw_retrieval_plan_cache: Dict[tuple[tuple[str, ...], bool], tuple[str, ...]] = {}
139139
self._cache_ttl_seconds = 60.0
140+
self._profile_catalog_cache_max_users = 256
140141

141142
logger.info("RetrievalPipeline initialized")
142143

@@ -499,8 +500,11 @@ def _fetch_profile_catalog(self, user_id: str):
499500
raw_results — the full SearchResult list, cached for _search_profile
500501
"""
501502
now = time.monotonic()
503+
self._prune_profile_catalog_cache(now)
504+
502505
cached = self._profile_catalog_cache.get(user_id)
503506
if cached and now - cached[0] < self._cache_ttl_seconds:
507+
self._profile_catalog_cache[user_id] = (now, cached[1], cached[2])
504508
return cached[1], cached[2]
505509

506510
try:
@@ -536,6 +540,20 @@ def _fetch_profile_catalog(self, user_id: str):
536540
self._profile_catalog_cache[user_id] = (now, catalog, results)
537541
return catalog, results
538542

543+
def _prune_profile_catalog_cache(self, now: float) -> None:
    """Bound profile catalog cache by TTL and number of cached users.

    First drops every entry whose age (``now - cached_at``) has reached
    ``self._cache_ttl_seconds``. Then, while the cache still holds at least
    ``self._profile_catalog_cache_max_users`` entries, evicts the entry with
    the smallest ``cached_at`` so there is room for one new insertion.

    Args:
        now: Current ``time.monotonic()``-style timestamp used for ageing.
    """
    cache = self._profile_catalog_cache
    ttl_seconds = self._cache_ttl_seconds

    # Work from a snapshot instead of the live dict.
    # NOTE(review): the route layer appears to run profile searches in worker
    # threads; a snapshot narrows the race window but is not a substitute for
    # a lock -- confirm whether one is needed.
    for cached_user_id, entry in list(cache.items()):
        if now - entry[0] >= ttl_seconds:
            cache.pop(cached_user_id, None)

    max_users = self._profile_catalog_cache_max_users
    # `cache and ...` guards a non-positive cap: without it an empty cache
    # would still enter the loop and fail looking for an entry to evict.
    while cache and len(cache) >= max_users:
        snapshot = list(cache.items())
        if not snapshot:
            break
        # Evict by timestamp, not insertion order: overwriting an existing
        # key with a fresh timestamp does not move it within the dict, so
        # the first key is not necessarily the stalest entry.
        stalest_user_id = min(snapshot, key=lambda item: item[1][0])[0]
        cache.pop(stalest_user_id, None)
556+
539557
def raw_retrieval_plan(self, domains: List[str], answer: bool = False) -> tuple[str, ...]:
540558
"""Return a cached deterministic raw-search plan for the requested domains."""
541559
ordered_allowed = ("profile", "temporal", "summary", "snippet", "code")

0 commit comments

Comments
 (0)