|
16 | 16 |
|
17 | 17 | from src.api.dependencies import ( |
18 | 18 | enforce_rate_limit, |
| 19 | + get_code_pipeline, |
19 | 20 | get_ingest_pipeline, |
20 | 21 | get_retrieval_pipeline, |
21 | 22 | require_api_key, |
@@ -689,16 +690,77 @@ async def search_memory(req: SearchRequest, request: Request, user: dict = Depen |
689 | 690 | user_id = user.get("username") or user.get("name") or user["id"] |
690 | 691 |
|
691 | 692 | try: |
692 | | - all_results: List[SourceRecord] = [] |
693 | | - |
694 | | - if "profile" in req.domains: |
695 | | - all_results.extend(_search_profile(pipeline, user_id)) |
696 | | - if "temporal" in req.domains: |
697 | | - all_results.extend(_search_temporal(pipeline, req.query, user_id, req.top_k)) |
698 | | - if "summary" in req.domains: |
699 | | - all_results.extend(await _search_summary(pipeline, req.query, user_id, req.top_k)) |
| 693 | + if "code" in req.domains and not req.org_id: |
| 694 | + elapsed = round((time.perf_counter() - start) * 1000, 2) |
| 695 | + return _error(request, "org_id is required when searching the code domain.", 400, elapsed) |
| 696 | + |
| 697 | + memory_domains = [domain for domain in req.domains if domain != "code"] |
| 698 | + result = await pipeline.raw_search( |
| 699 | + query=req.query, |
| 700 | + user_id=user_id, |
| 701 | + domains=memory_domains, |
| 702 | + top_k=req.top_k, |
| 703 | + include_answer=False, |
| 704 | + ) |
| 705 | + records = list(result.sources) |
| 706 | + |
| 707 | + if "code" in req.domains: |
| 708 | + code_pipeline = get_code_pipeline(org_id=req.org_id or "", repo=req.repo) |
| 709 | + code_results = await asyncio.gather( |
| 710 | + code_pipeline._execute_tool( |
| 711 | + tool_name="search_symbols", |
| 712 | + tool_args={"query": req.query, "repo": req.repo}, |
| 713 | + repo=req.repo, |
| 714 | + top_k=req.top_k, |
| 715 | + user_id=user_id, |
| 716 | + ), |
| 717 | + code_pipeline._execute_tool( |
| 718 | + tool_name="search_files", |
| 719 | + tool_args={"query": req.query, "repo": req.repo}, |
| 720 | + repo=req.repo, |
| 721 | + top_k=req.top_k, |
| 722 | + user_id=user_id, |
| 723 | + ), |
| 724 | + return_exceptions=True, |
| 725 | + ) |
| 726 | + for code_records in code_results: |
| 727 | + if isinstance(code_records, Exception): |
| 728 | + logger.warning("Code search subquery failed: %s", code_records) |
| 729 | + continue |
| 730 | + records.extend(code_records) |
| 731 | + |
| 732 | + records = sorted(records, key=lambda s: s.score or 0.0, reverse=True) |
| 733 | + |
| 734 | + answer = "" |
| 735 | + if req.answer: |
| 736 | + answer = await pipeline.answer_from_sources(query=req.query, sources=records) |
| 737 | + pipeline._record_latency( |
| 738 | + "raw_search_answer", |
| 739 | + (time.perf_counter() - start) * 1000, |
| 740 | + ) |
| 741 | + elif "code" in req.domains: |
| 742 | + pipeline._record_latency( |
| 743 | + "raw_search_code", |
| 744 | + (time.perf_counter() - start) * 1000, |
| 745 | + ) |
700 | 746 |
|
701 | | - data = SearchResponse(results=all_results, total=len(all_results)) |
| 747 | + confidence = pipeline.confidence_from_sources(records) |
| 748 | + data = SearchResponse( |
| 749 | + results=[ |
| 750 | + SourceRecord( |
| 751 | + domain=s.domain, |
| 752 | + content=s.content, |
| 753 | + score=round(s.score, 3) if s.score is not None else 0.0, |
| 754 | + metadata=s.metadata, |
| 755 | + ) |
| 756 | + for s in records |
| 757 | + ], |
| 758 | + total=len(records), |
| 759 | + answer=answer, |
| 760 | + model=_model_name(pipeline.model) if req.answer else "", |
| 761 | + confidence=confidence, |
| 762 | + latency=pipeline.latency_snapshot(), |
| 763 | + ) |
702 | 764 | elapsed = round((time.perf_counter() - start) * 1000, 2) |
703 | 765 | return _wrap(request, data, elapsed) |
704 | 766 |
|
|
0 commit comments