Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 43 additions & 6 deletions backend/app/api/endpoints/internal/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ class InternalRetrieveRequest(BaseModel):
default=None,
description="Optional list of document IDs to filter. Only chunks from these documents will be returned.",
)
document_names: Optional[list[str]] = Field(
default=None,
description="Optional exact document names to resolve into document IDs before retrieval.",
)
route_mode: Literal["auto", "direct_injection", "rag_retrieval"] = Field(
default="auto",
description="Routing mode: auto decides in Backend, direct_injection forces all-chunks, rag_retrieval forces standard retrieval",
Expand Down Expand Up @@ -164,6 +168,22 @@ class InternalRetrieveResponse(BaseModel):
records: list[RetrieveRecord]
total: int
total_estimated_tokens: int = 0
message: Optional[str] = None


def _resolve_document_names(
    db: Session,
    knowledge_base_ids: list[int],
    document_names: list[str],
) -> list[int]:
    """Translate exact document names into document IDs, limited to the given KBs.

    This is a module-level seam in front of
    ``KnowledgeService.resolve_document_ids_by_names``; all matching logic
    lives in the service. The endpoint tests patch this function by name.

    Args:
        db: Database session forwarded to the service.
        knowledge_base_ids: Knowledge bases the lookup is restricted to.
        document_names: Document names to look up.

    Returns:
        Whatever the service returns for the lookup (a list of document IDs).
    """
    # NOTE(review): the function-scope import is presumably there to avoid an
    # import cycle or import-time cost at module load -- confirm before hoisting.
    from app.services.knowledge import KnowledgeService

    lookup = KnowledgeService.resolve_document_ids_by_names
    matched_ids = lookup(
        db=db,
        knowledge_base_ids=knowledge_base_ids,
        document_names=document_names,
    )
    return matched_ids


@router.post(
Expand Down Expand Up @@ -193,11 +213,27 @@ async def internal_retrieve(
if request.knowledge_base_id is not None:
knowledge_base_ids = [request.knowledge_base_id]

if request.document_ids:
resolved_document_ids = request.document_ids or []
if not resolved_document_ids and request.document_names:
resolved_document_ids = _resolve_document_names(
db=db,
knowledge_base_ids=knowledge_base_ids,
document_names=request.document_names,
)
if not resolved_document_ids:
return InternalRetrieveResponse(
mode="rag_retrieval",
records=[],
total=0,
total_estimated_tokens=0,
message="Document names not found in the selected knowledge bases. Use kb_ls to inspect available documents first.",
)

if resolved_document_ids:
logger.info(
"[internal_rag] Filtering by %d documents: %s",
len(request.document_ids),
request.document_ids,
len(resolved_document_ids),
resolved_document_ids,
)

runtime_context = request.runtime_context
Expand All @@ -210,7 +246,7 @@ async def internal_retrieve(
knowledge_base_ids=knowledge_base_ids,
query=request.query,
max_results=request.max_results,
document_ids=request.document_ids,
document_ids=resolved_document_ids or None,
route_mode=request.route_mode,
user_id=persistence_context.user_id if persistence_context else None,
user_name=request.user_name,
Expand Down Expand Up @@ -277,8 +313,8 @@ async def internal_retrieve(
available_injection_tokens,
request.query[:50],
(
f", filtered by {len(request.document_ids)} docs"
if request.document_ids
f", filtered by {len(resolved_document_ids)} docs"
if resolved_document_ids
else ""
),
)
Expand Down Expand Up @@ -328,6 +364,7 @@ async def internal_retrieve(
],
total=len(records),
total_estimated_tokens=total_estimated_tokens,
message=result.get("message"),
)

except ValueError as e:
Expand Down
22 changes: 22 additions & 0 deletions backend/app/services/chat/preprocessing/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1319,13 +1319,17 @@ def _build_kb_meta_prompt(
format_restricted_kb_meta_prompt,
select_kb_summary_text,
)
from app.services.knowledge import KnowledgeService
from app.services.knowledge.task_knowledge_base_service import (
task_knowledge_base_service,
)

kb_map = task_knowledge_base_service.get_knowledge_bases_by_ids(
db, knowledge_base_ids
)
kb_prompt_stats = KnowledgeService.get_document_prompt_stats(
db, knowledge_base_ids
)

kb_meta_list: list[dict[str, Any]] = []
for kb_id in knowledge_base_ids:
Expand All @@ -1335,6 +1339,10 @@ def _build_kb_meta_prompt(
{
"kb_id": kb_id,
"kb_name": "Unknown",
"search_available": False,
"total_document_count": 0,
"searchable_document_count": 0,
"spreadsheet_document_count": 0,
"summary_text": "",
"topics": [],
}
Expand Down Expand Up @@ -1365,10 +1373,24 @@ def _build_kb_meta_prompt(
exc_info=True,
)

stats = kb_prompt_stats.get(
kb_id,
{
"total_document_count": 0,
"searchable_document_count": 0,
"spreadsheet_document_count": 0,
},
)
retrieval_config = kb_spec.get("retrievalConfig") or {}

kb_meta_list.append(
{
"kb_id": kb_id,
"kb_name": kb_name,
"search_available": bool(retrieval_config.get("retriever_name")),
"total_document_count": stats["total_document_count"],
"searchable_document_count": stats["searchable_document_count"],
"spreadsheet_document_count": stats["spreadsheet_document_count"],
"summary_text": summary_text,
"topics": topics,
}
Expand Down
19 changes: 18 additions & 1 deletion backend/app/services/chat/preprocessing/kb_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,24 @@ def format_kb_meta_prompt(kb_meta_list: list[dict[str, Any]]) -> str:
kb_meta.get("kb_name", "Unknown"), "Unknown"
)
kb_id = sanitize_prompt_identifier(kb_meta.get("kb_id", "N/A"), "N/A")
kb_lines.append(f"- KB Name: {kb_name}, KB ID: {kb_id}")
search_available = (
"available" if kb_meta.get("search_available") else "unavailable"
)
total_document_count = int(kb_meta.get("total_document_count", 0) or 0)
searchable_document_count = int(
kb_meta.get("searchable_document_count", 0) or 0
)
spreadsheet_document_count = int(
kb_meta.get("spreadsheet_document_count", 0) or 0
)

kb_lines.append(
f"- KB Name: {kb_name}, KB ID: {kb_id}, "
f"Search: {search_available}, "
f"Total Docs: {total_document_count}, "
f"Searchable Docs: {searchable_document_count}, "
f"Spreadsheets: {spreadsheet_document_count}"
)

summary_text = kb_meta.get("summary_text") or ""
topics = kb_meta.get("topics") or []
Expand Down
76 changes: 75 additions & 1 deletion backend/app/services/knowledge/knowledge_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dataclasses import dataclass
from typing import Optional

from sqlalchemy import and_, func
from sqlalchemy import and_, case, func
from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import flag_modified

Expand Down Expand Up @@ -845,6 +845,80 @@ def get_active_document_counts(

return {kb_id: count for kb_id, count in results}

@staticmethod
def get_document_prompt_stats(
    db: Session,
    knowledge_base_ids: list[int],
) -> dict[int, dict[str, int]]:
    """Get prompt-oriented document stats for multiple knowledge bases.

    Issues a single grouped query over ``KnowledgeDocument`` and computes
    three counters per knowledge base:

    - ``total_document_count``: every document row in the KB, active or not.
    - ``searchable_document_count``: rows whose ``is_active`` flag is true.
    - ``spreadsheet_document_count``: active rows whose lower-cased
      ``file_extension`` is one of ``csv``, ``xls`` or ``xlsx``.

    Args:
        db: Database session used to run the query.
        knowledge_base_ids: Knowledge base IDs, matched against
            ``KnowledgeDocument.kind_id``.

    Returns:
        Mapping of knowledge base ID to its counter dict. A knowledge base
        with no document rows yields no group and is therefore absent from
        the mapping; callers must supply their own zero defaults. An empty
        ``knowledge_base_ids`` returns ``{}`` without querying.
    """
    if not knowledge_base_ids:
        return {}

    spreadsheet_exts = ["csv", "xls", "xlsx"]

    results = (
        db.query(
            KnowledgeDocument.kind_id,
            func.count(KnowledgeDocument.id).label("total_count"),
            # Conditional sums: each row contributes 1 when it matches, else 0.
            func.sum(
                case((KnowledgeDocument.is_active == True, 1), else_=0)  # noqa: E712
            ).label("searchable_count"),
            func.sum(
                case(
                    (
                        and_(
                            KnowledgeDocument.is_active == True,  # noqa: E712
                            # Lower-case so "XLSX" and "xlsx" are counted alike.
                            func.lower(KnowledgeDocument.file_extension).in_(
                                spreadsheet_exts
                            ),
                        ),
                        1,
                    ),
                    else_=0,
                )
            ).label("spreadsheet_count"),
        )
        .filter(KnowledgeDocument.kind_id.in_(knowledge_base_ids))
        .group_by(KnowledgeDocument.kind_id)
        .all()
    )

    # SUM can come back as NULL/None or a Decimal depending on the backend,
    # so every counter is coerced to a plain int.
    return {
        kb_id: {
            "total_document_count": int(total_count or 0),
            "searchable_document_count": int(searchable_count or 0),
            "spreadsheet_document_count": int(spreadsheet_count or 0),
        }
        for kb_id, total_count, searchable_count, spreadsheet_count in results
    }

@staticmethod
def resolve_document_ids_by_names(
    db: Session,
    knowledge_base_ids: list[int],
    document_names: list[str],
) -> list[int]:
    """Look up active documents by exact name inside the given knowledge bases.

    Names are stripped of surrounding whitespace before matching; empty and
    whitespace-only entries are discarded. Matching uses an ``IN`` filter on
    ``KnowledgeDocument.name``, so it is exact after stripping.

    Args:
        db: Database session used to run the query.
        knowledge_base_ids: Knowledge base IDs, matched against
            ``KnowledgeDocument.kind_id``.
        document_names: Candidate document names.

    Returns:
        IDs of active documents whose name matches, or ``[]`` when either
        input is empty or no usable name remains after stripping. No
        ordering is applied to the result.
    """
    if not knowledge_base_ids:
        return []
    if not document_names:
        return []

    stripped_names = (raw.strip() for raw in document_names if raw)
    wanted_names = [candidate for candidate in stripped_names if candidate]
    if not wanted_names:
        return []

    lookup = db.query(KnowledgeDocument.id).filter(
        KnowledgeDocument.kind_id.in_(knowledge_base_ids),
        KnowledgeDocument.name.in_(wanted_names),
        KnowledgeDocument.is_active == True,  # noqa: E712
    )
    return [match.id for match in lookup.all()]
Comment thread
coderabbitai[bot] marked this conversation as resolved.

@staticmethod
def get_active_document_text_length_stats(
db: Session,
Expand Down
51 changes: 51 additions & 0 deletions backend/tests/api/endpoints/internal/test_rag_retrieve_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,54 @@ def test_internal_retrieve_keeps_user_subtask_id_out_of_gateway(test_client):
assert response.status_code == 200
mock_query.assert_awaited_once_with(ANY, db=ANY)
mock_persist.assert_called_once()


def test_internal_retrieve_resolves_document_names_before_query(test_client):
    """Requested document names are handed to the resolver and the query still runs."""
    gateway_result = {
        "mode": "rag_retrieval",
        "records": [],
        "total": 0,
        "total_estimated_tokens": 0,
    }
    with (
        patch(
            "app.api.endpoints.internal.rag._resolve_document_names",
            return_value=[101, 102],
        ) as mock_resolve,
        patch(
            "app.api.endpoints.internal.rag.LocalRagGateway.query",
            new_callable=AsyncMock,
            return_value=gateway_result,
        ) as mock_query,
    ):
        response = test_client.post(
            "/api/internal/rag/retrieve",
            json={
                "query": "release checklist",
                "knowledge_base_ids": [12],
                "document_names": ["release.md"],
            },
        )

    assert response.status_code == 200
    # Pin the names forwarded to the resolver, not merely that it was called;
    # db and the KB scope are built by the endpoint, so they are matched loosely.
    mock_resolve.assert_called_once_with(
        db=ANY,
        knowledge_base_ids=ANY,
        document_names=["release.md"],
    )
    mock_query.assert_awaited_once()


def test_internal_retrieve_returns_error_when_document_names_not_found(test_client):
    """Unresolvable document names yield an empty result and skip retrieval."""
    with (
        patch(
            "app.api.endpoints.internal.rag._resolve_document_names",
            return_value=[],
        ) as mock_resolve,
        # Patched so the test proves the endpoint returns before querying,
        # and so a regression cannot reach the real gateway.
        patch(
            "app.api.endpoints.internal.rag.LocalRagGateway.query",
            new_callable=AsyncMock,
        ) as mock_query,
    ):
        response = test_client.post(
            "/api/internal/rag/retrieve",
            json={
                "query": "release checklist",
                "knowledge_base_ids": [12],
                "document_names": ["missing.md"],
            },
        )

    assert response.status_code == 200
    mock_resolve.assert_called_once()
    mock_query.assert_not_awaited()

    payload = response.json()
    assert payload["mode"] == "rag_retrieval"
    assert payload["records"] == []
    assert payload["total"] == 0
    assert payload["message"].startswith("Document names not found")
Loading
Loading