Skip to content
This repository was archived by the owner on Jun 3, 2026. It is now read-only.

Commit bb08f07

Browse files
Add raw memory search fast path
1 parent c36fefe commit bb08f07

4 files changed

Lines changed: 450 additions & 37 deletions

File tree

src/api/routes/memory.py

Lines changed: 71 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from src.api.dependencies import (
1818
enforce_rate_limit,
19+
get_code_pipeline,
1920
get_ingest_pipeline,
2021
get_retrieval_pipeline,
2122
require_api_key,
@@ -689,16 +690,77 @@ async def search_memory(req: SearchRequest, request: Request, user: dict = Depen
689690
user_id = user.get("username") or user.get("name") or user["id"]
690691

691692
try:
692-
all_results: List[SourceRecord] = []
693-
694-
if "profile" in req.domains:
695-
all_results.extend(_search_profile(pipeline, user_id))
696-
if "temporal" in req.domains:
697-
all_results.extend(_search_temporal(pipeline, req.query, user_id, req.top_k))
698-
if "summary" in req.domains:
699-
all_results.extend(await _search_summary(pipeline, req.query, user_id, req.top_k))
693+
if "code" in req.domains and not req.org_id:
694+
elapsed = round((time.perf_counter() - start) * 1000, 2)
695+
return _error(request, "org_id is required when searching the code domain.", 400, elapsed)
696+
697+
memory_domains = [domain for domain in req.domains if domain != "code"]
698+
result = await pipeline.raw_search(
699+
query=req.query,
700+
user_id=user_id,
701+
domains=memory_domains,
702+
top_k=req.top_k,
703+
include_answer=False,
704+
)
705+
records = list(result.sources)
706+
707+
if "code" in req.domains:
708+
code_pipeline = get_code_pipeline(org_id=req.org_id or "", repo=req.repo)
709+
code_results = await asyncio.gather(
710+
code_pipeline._execute_tool(
711+
tool_name="search_symbols",
712+
tool_args={"query": req.query, "repo": req.repo},
713+
repo=req.repo,
714+
top_k=req.top_k,
715+
user_id=user_id,
716+
),
717+
code_pipeline._execute_tool(
718+
tool_name="search_files",
719+
tool_args={"query": req.query, "repo": req.repo},
720+
repo=req.repo,
721+
top_k=req.top_k,
722+
user_id=user_id,
723+
),
724+
return_exceptions=True,
725+
)
726+
for code_records in code_results:
727+
if isinstance(code_records, Exception):
728+
logger.warning("Code search subquery failed: %s", code_records)
729+
continue
730+
records.extend(code_records)
731+
732+
records = sorted(records, key=lambda s: s.score or 0.0, reverse=True)
733+
734+
answer = ""
735+
if req.answer:
736+
answer = await pipeline.answer_from_sources(query=req.query, sources=records)
737+
pipeline._record_latency(
738+
"raw_search_answer",
739+
(time.perf_counter() - start) * 1000,
740+
)
741+
elif "code" in req.domains:
742+
pipeline._record_latency(
743+
"raw_search_code",
744+
(time.perf_counter() - start) * 1000,
745+
)
700746

701-
data = SearchResponse(results=all_results, total=len(all_results))
747+
confidence = pipeline.confidence_from_sources(records)
748+
data = SearchResponse(
749+
results=[
750+
SourceRecord(
751+
domain=s.domain,
752+
content=s.content,
753+
score=round(s.score, 3) if s.score is not None else 0.0,
754+
metadata=s.metadata,
755+
)
756+
for s in records
757+
],
758+
total=len(records),
759+
answer=answer,
760+
model=_model_name(pipeline.model) if req.answer else "",
761+
confidence=confidence,
762+
latency=pipeline.latency_snapshot(),
763+
)
702764
elapsed = round((time.perf_counter() - start) * 1000, 2)
703765
return _wrap(request, data, elapsed)
704766

src/api/schemas.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,24 +159,48 @@ class SearchRequest(BaseModel):
159159
..., min_length=1, max_length=256, pattern=r"^[\w.\-@]+$",
160160
)
161161
domains: List[str] = Field(
162-
default=["profile", "temporal", "summary"],
162+
default=["profile", "temporal", "summary", "snippet"],
163163
description="Which memory domains to search",
164164
)
165165
top_k: int = Field(default=10, ge=1, le=100)
166+
answer: bool = Field(
167+
default=False,
168+
description="When true, synthesize an LLM answer from the raw hits.",
169+
)
170+
org_id: Optional[str] = Field(
171+
default=None,
172+
min_length=1,
173+
max_length=256,
174+
description="Required when including the code domain.",
175+
)
176+
repo: str = Field(
177+
default="",
178+
max_length=256,
179+
description="Optional repository scope for code search.",
180+
)
166181

167182
@field_validator("domains")
@classmethod
def validate_domains(cls, v: List[str]) -> List[str]:
    """Reject any requested domain that is not a known search domain.

    Args:
        v: Domain names supplied in the request.

    Returns:
        The input list, unchanged, when every entry is allowed.

    Raises:
        ValueError: If an entry is not one of the allowed domain names.
    """
    allowed = {"profile", "temporal", "summary", "snippet", "code"}
    for d in v:
        if d not in allowed:
            # Render the allowed names sorted: a set's repr order for strings
            # varies between interpreter runs, which made the message unstable.
            raise ValueError(f"Invalid domain '{d}'. Allowed: {sorted(allowed)}")
    return v
175190

191+
@field_validator("query")
@classmethod
def strip_search_query(cls, v: str) -> str:
    """Return the query with leading and trailing whitespace removed."""
    trimmed = v.strip()
    return trimmed
195+
176196

177197
class SearchResponse(BaseModel):
    """Response payload for a memory search request."""

    # Matching source records; defaults to an empty list.
    results: List[SourceRecord] = Field(default_factory=list)
    # Number of records reported for the search.
    total: int = Field(default=0)
    # Answer text; empty string by default.
    answer: str = Field(default="")
    # Model name associated with the answer; empty string by default.
    model: str = Field(default="")
    # Confidence value for the returned sources; 0.0 by default.
    confidence: float = Field(default=0.0)
    # Nested mapping of latency figures (name -> metric -> value); empty by default.
    latency: Dict[str, Dict[str, float]] = Field(default_factory=dict)
180204

181205

182206
# ── Scrape (extract from shared chat links) ────────────────────────────────

0 commit comments

Comments
 (0)