diff --git a/context_chat_backend/chain/context.py b/context_chat_backend/chain/context.py index 2e0219d..adbac2d 100644 --- a/context_chat_backend/chain/context.py +++ b/context_chat_backend/chain/context.py @@ -6,8 +6,9 @@ from langchain.schema import Document +from ..dyn_loader import VectorDBLoader from ..vectordb.base import BaseVectorDB -from .types import ContextException, ScopeType +from .types import ContextException, ScopeType, SearchResult logger = logging.getLogger('ccb.chain') @@ -39,3 +40,53 @@ def get_context_chunks(context_docs: list[Document]) -> list[str]: context_chunks.append(doc.page_content) return context_chunks + + +def do_doc_search( + user_id: str, + query: str, + vectordb_loader: VectorDBLoader, + ctx_limit: int = 20, + scope_type: ScopeType | None = None, + scope_list: list[str] | None = None, +) -> list[SearchResult]: + """ + Raises + ------ + ContextException + If the scope type is provided but the scope list is empty or not provided + """ + db = vectordb_loader.load() + augmented_limit = ctx_limit * 2 # to account for duplicate sources + docs = get_context_docs(user_id, query, db, augmented_limit, scope_type, scope_list) + if len(docs) == 0: + logger.warning('No documents retrieved, please index a few documents first') + return [] + + sources_cache = {} + results: list[SearchResult] = [] + for doc in docs: + source_id = doc.metadata.get('source') + if not source_id: + logger.warning('Document without source id encountered in doc search, skipping', extra={ + 'doc': doc, + }) + continue + if source_id in sources_cache: + continue + if len(results) >= ctx_limit: + break + + sources_cache[source_id] = None + results.append(SearchResult( + source_id=source_id, + title=doc.metadata.get('title', ''), + )) + + logger.debug('do_doc_search', extra={ + 'len(docs)': len(docs), + 'len(results)': len(results), + 'scope_type': scope_type, + 'scope_list': scope_list, + }) + return results diff --git a/context_chat_backend/chain/types.py b/context_chat_backend/chain/types.py index 4d9b2ab..b006ad1 100644 --- a/context_chat_backend/chain/types.py +++ b/context_chat_backend/chain/types.py @@ -36,3 +36,9 @@ class ContextException(Exception): class LLMOutput(TypedDict): output: str sources: list[str] + # todo: add "titles" field + + +class SearchResult(TypedDict): + source_id: str + title: str diff --git a/context_chat_backend/controller.py b/context_chat_backend/controller.py index a9970d1..19c236d 100644 --- a/context_chat_backend/controller.py +++ b/context_chat_backend/controller.py @@ -4,7 +4,7 @@ # # isort: off -from .chain.types import ContextException, LLMOutput, ScopeType +from .chain.types import ContextException, LLMOutput, ScopeType, SearchResult from .types import LoaderException, EmbeddingException from .vectordb.types import DbException, SafeDbException, UpdateAccessOp # isort: on @@ -26,6 +26,7 @@ from nc_py_api.ex_app import persistent_storage, set_handlers from pydantic import BaseModel, ValidationInfo, field_validator +from .chain.context import do_doc_search from .chain.ingest.injest import embed_sources from .chain.one_shot import process_context_query, process_query from .config_parser import get_config @@ -315,12 +316,14 @@ def _(userId: str = Body(embed=True)): return JSONResponse('User deleted') + @app.post('/countIndexedDocuments') @enabled_guard(app) def _(): counts = exec_in_proc(target=count_documents_by_provider, args=(vectordb_loader,)) return JSONResponse(counts) + @app.put('/loadSources') @enabled_guard(app) def _(sources: list[UploadFile]): @@ -467,3 +470,17 @@ def _(query: Query) -> LLMOutput: with llm_lock: return execute_query(query, in_proc=False) + + +@app.post('/docSearch') +@enabled_guard(app) +def _(query: Query) -> list[SearchResult]: + # useContext from Query is not used here + return exec_in_proc(target=do_doc_search, args=( + query.userId, + query.query, + vectordb_loader, + query.ctxLimit, + query.scopeType, + query.scopeList, + ))