diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 10e2d61b..d30073ab 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -89,7 +89,7 @@ jobs: POSTGRES_USER: root POSTGRES_PASSWORD: rootpassword POSTGRES_DB: nextcloud - options: --health-cmd pg_isready --health-interval 5s --health-timeout 2s --health-retries 5 + options: --health-cmd pg_isready --health-interval 5s --health-timeout 2s --health-retries 5 --name postgres --hostname postgres steps: - name: Checkout server @@ -113,6 +113,8 @@ jobs: repository: nextcloud/context_chat path: apps/context_chat persist-credentials: false + # todo: remove later + ref: feat/reverse-content-flow - name: Checkout backend uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 @@ -167,6 +169,10 @@ jobs: cd .. rm -rf documentation + - name: Run files scan + run: | + ./occ files:scan --all + - name: Setup python 3.11 uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5 with: @@ -195,28 +201,91 @@ jobs: timeout 10 ./occ app_api:daemon:register --net host manual_install "Manual Install" manual-install http localhost http://localhost:8080 timeout 120 ./occ app_api:app:register context_chat_backend manual_install --json-info "{\"appid\":\"context_chat_backend\",\"name\":\"Context Chat Backend\",\"daemon_config_name\":\"manual_install\",\"version\":\"${{ fromJson(steps.appinfo.outputs.result).version }}\",\"secret\":\"12345\",\"port\":10034,\"scopes\":[],\"system_app\":0}" --force-scopes --wait-finish ls -la context_chat_backend/persistent_storage/* - sleep 30 # Wait for the em server to get ready - - name: Scan files, baseline - run: | - ./occ files:scan admin - ./occ context_chat:scan admin -m text/plain - - - name: Check python memory usage + - name: Initial memory usage check run: | ps -p $(cat pid.txt) -o pid,cmd,%mem,rss --sort=-%mem ps -p $(cat pid.txt) -o %mem --no-headers > initial_mem.txt - - name: 
Scan files + - name: Run cron jobs run: | - ./occ files:scan admin - ./occ context_chat:scan admin -m text/markdown & - ./occ context_chat:scan admin -m text/x-rst + # every 10 seconds indefinitely + while true; do + php cron.php + sleep 10 + done & + sleep 30 + # list all the bg jobs + ./occ background-job:list + + - name: Initial dump of DB with context_chat_queue populated + run: | + docker exec postgres pg_dump nextcloud > /tmp/0_pgdump_nextcloud - - name: Check python memory usage + - name: Periodically check context_chat stats for 15 minutes to allow the backend to index the files run: | - ps -p $(cat pid.txt) -o pid,cmd,%mem,rss --sort=-%mem - ps -p $(cat pid.txt) -o %mem --no-headers > after_scan_mem.txt + success=0 + echo "::group::Checking stats periodically for 15 minutes to allow the backend to index the files" + for i in {1..90}; do + echo "Checking stats, attempt $i..." + + stats_err=$(mktemp) + stats=$(timeout 5 ./occ context_chat:stats --json 2>"$stats_err") + stats_exit=$? + echo "Stats output:" + echo "$stats" + if [ -s "$stats_err" ]; then + echo "Stderr:" + cat "$stats_err" + fi + echo "---" + rm -f "$stats_err" + + # Check for critical errors in output + if [ $stats_exit -ne 0 ] || echo "$stats" | grep -q "Error during request"; then + echo "Backend connection error detected (exit=$stats_exit), retrying..." 
+ sleep 10 + continue + fi + + # Extract total eligible files + total_eligible_files=$(echo "$stats" | jq '.eligible_files_count' || echo "") + + # Extract indexed documents count (files__default) + indexed_count=$(echo "$stats" | jq '.vectordb_document_counts.files__default' || echo "") + + echo "Total eligible files: $total_eligible_files" + echo "Indexed documents (files__default): $indexed_count" + + diff=$((total_eligible_files - indexed_count)) + threshold=$((total_eligible_files * 3 / 100)) + + # Check if difference is within tolerance + if [ $diff -le $threshold ]; then + echo "Indexing within 3% tolerance (diff=$diff, threshold=$threshold)" + success=1 + break + else + progress=$((diff * 100 / total_eligible_files)) + echo "Outside 3% tolerance: diff=$diff (${progress}%), threshold=$threshold" + fi + + # Check if backend is still alive + ccb_alive=$(ps -p $(cat pid.txt) -o cmd= | grep -c "main.py" || echo "0") + if [ "$ccb_alive" -eq 0 ]; then + echo "Error: Context Chat Backend process is not running. Exiting." + exit 1 + fi + + sleep 10 + done + + echo "::endgroup::" + + if [ $success -ne 1 ]; then + echo "Max attempts reached" + exit 1 + fi - name: Run the prompts run: | @@ -250,18 +319,9 @@ jobs: echo "Memory usage during scan is stable. No memory leak detected." fi - - name: Compare memory usage and detect leak + - name: Final dump of DB with vectordb populated run: | - initial_mem=$(cat after_scan_mem.txt | tr -d ' ') - final_mem=$(cat after_prompt_mem.txt | tr -d ' ') - echo "Initial Memory Usage: $initial_mem%" - echo "Memory Usage after prompt: $final_mem%" - - if (( $(echo "$final_mem > $initial_mem" | bc -l) )); then - echo "Memory usage has increased during prompt. Possible memory leak detected!" - else - echo "Memory usage during prompt is stable. No memory leak detected." 
- fi + docker exec postgres pg_dump nextcloud > /tmp/1_pgdump_nextcloud - name: Show server logs if: always() @@ -298,6 +358,19 @@ jobs: run: | tail -v -n +1 context_chat_backend/persistent_storage/logs/em_server.log* || echo "No logs in logs directory" + - name: Upload database dumps + uses: actions/upload-artifact@v4 + with: + name: database-dumps-${{ matrix.server-versions }}-php@${{ matrix.php-versions }} + path: | + /tmp/0_pgdump_nextcloud + /tmp/1_pgdump_nextcloud + + - name: Final stats log + run: | + ./occ context_chat:stats + ./occ context_chat:stats --json + summary: permissions: contents: none diff --git a/appinfo/info.xml b/appinfo/info.xml index 9760cd29..30194baa 100644 --- a/appinfo/info.xml +++ b/appinfo/info.xml @@ -82,5 +82,19 @@ Setup background job workers as described here: https://docs.nextcloud.com/serve Password to be used for authenticating requests to the OpenAI-compatible endpoint set in CC_EM_BASE_URL. + + + rp + Request Processing Mode + APP_ROLE=rp + true + + + indexing + Indexing Mode + APP_ROLE=indexing + false + + diff --git a/config.cpu.yaml b/config.cpu.yaml index 1512ea07..6ceac915 100644 --- a/config.cpu.yaml +++ b/config.cpu.yaml @@ -7,7 +7,10 @@ verify_ssl: true use_colors: true uvicorn_workers: 1 embedding_chunk_size: 2000 -doc_parser_worker_limit: 10 +doc_indexing_batch_size: 32 # theoretical max RAM usage: 32 * 100 MiB +actions_batch_size: 512 +file_parsing_cpu_count: -1 # divides the batch into these many chunks, -1 = auto +concurrent_file_fetches: 10 # maximum number of files to fetch concurrently to not overload the NC server vectordb: @@ -43,6 +46,9 @@ embedding: llm: nc_texttotext: + # template: + # n_ctx: + # max_tokens: llama: # all options: https://python.langchain.com/api_reference/community/llms/langchain_community.llms.llamacpp.LlamaCpp.html @@ -52,14 +58,12 @@ llm: max_tokens: 4096 template: "<|im_start|> system \nYou're an AI assistant named Nextcloud Assistant, good at finding relevant context from documents 
to answer questions provided by the user. <|im_end|>\n<|im_start|> user\nUse the following documents as context to answer the question at the end. REMEMBER to excersice source critisicm as the documents are returned by a search provider that can return unrelated documents.\n\nSTART OF CONTEXT: \n{context} \n\nEND OF CONTEXT!\n\nIf you don't know the answer or are unsure, just say that you don't know, don't try to make up an answer. Don't mention the context in your answer but rather just answer the question directly. Detect the language of the question and make sure to use the same language that was used in the question to answer the question. Don't mention which language was used, but just answer the question directly in the same langauge. \nQuestion: {question} Let's think this step-by-step. \n<|im_end|>\n<|im_start|> assistant\n" no_ctx_template: "<|im_start|> system \nYou're an AI assistant named Nextcloud Assistant.<|im_end|>\n<|im_start|> user\n{question}<|im_end|>\n<|im_start|> assistant\n" - end_separator: "<|im_end|>" ctransformer: # all options: https://python.langchain.com/api_reference/community/llms/langchain_community.llms.ctransformers.CTransformers.html model: dolphin-2.2.1-mistral-7b.Q5_K_M.gguf template: "<|im_start|> system \nYou're an AI assistant named Nextcloud Assistant, good at finding relevant context from documents to answer questions provided by the user. <|im_end|>\n<|im_start|> user\nUse the following documents as context to answer the question at the end. REMEMBER to excersice source critisicm as the documents are returned by a search provider that can return unrelated documents.\n\nSTART OF CONTEXT: \n{context} \n\nEND OF CONTEXT!\n\nIf you don't know the answer or are unsure, just say that you don't know, don't try to make up an answer. Don't mention the context in your answer but rather just answer the question directly. 
Detect the language of the question and make sure to use the same language that was used in the question to answer the question. Don't mention which language was used, but just answer the question directly in the same langauge. \nQuestion: {question} Let's think this step-by-step. \n<|im_end|>\n<|im_start|> assistant\n" no_ctx_template: "<|im_start|> system \nYou're an AI assistant named Nextcloud Assistant.<|im_end|>\n<|im_start|> user\n{question}<|im_end|>\n<|im_start|> assistant\n" - end_separator: "<|im_end|>" config: context_length: 8192 max_new_tokens: 4096 diff --git a/config.gpu.yaml b/config.gpu.yaml index fc3acaf2..a12fd1be 100644 --- a/config.gpu.yaml +++ b/config.gpu.yaml @@ -7,7 +7,10 @@ verify_ssl: true use_colors: true uvicorn_workers: 1 embedding_chunk_size: 2000 -doc_parser_worker_limit: 10 +doc_indexing_batch_size: 32 # theoretical max RAM usage: 32 * 100 MiB +actions_batch_size: 512 +file_parsing_cpu_count: -1 # divides the batch into these many chunks, -1 = auto +concurrent_file_fetches: 10 # maximum number of files to fetch concurrently to not overload the NC server vectordb: @@ -44,6 +47,9 @@ embedding: llm: nc_texttotext: + # template: + # n_ctx: + # max_tokens: llama: # all options: https://python.langchain.com/api_reference/community/llms/langchain_community.llms.llamacpp.LlamaCpp.html @@ -53,7 +59,6 @@ llm: max_tokens: 4096 template: "<|im_start|> system \nYou're an AI assistant named Nextcloud Assistant, good at finding relevant context from documents to answer questions provided by the user. <|im_end|>\n<|im_start|> user\nUse the following documents as context to answer the question at the end. REMEMBER to excersice source critisicm as the documents are returned by a search provider that can return unrelated documents.\n\nSTART OF CONTEXT: \n{context} \n\nEND OF CONTEXT!\n\nIf you don't know the answer or are unsure, just say that you don't know, don't try to make up an answer. 
Don't mention the context in your answer but rather just answer the question directly. Detect the language of the question and make sure to use the same language that was used in the question to answer the question. Don't mention which language was used, but just answer the question directly in the same langauge. \nQuestion: {question} Let's think this step-by-step. \n<|im_end|>\n<|im_start|> assistant\n" no_ctx_template: "<|im_start|> system \nYou're an AI assistant named Nextcloud Assistant.<|im_end|>\n<|im_start|> user\n{question}<|im_end|>\n<|im_start|> assistant\n" - end_separator: "<|im_end|>" n_gpu_layers: -1 model_kwargs: device: cuda @@ -63,7 +68,6 @@ llm: model: dolphin-2.2.1-mistral-7b.Q5_K_M.gguf template: "<|im_start|> system \nYou're an AI assistant named Nextcloud Assistant, good at finding relevant context from documents to answer questions provided by the user. <|im_end|>\n<|im_start|> user\nUse the following documents as context to answer the question at the end. REMEMBER to excersice source critisicm as the documents are returned by a search provider that can return unrelated documents.\n\nSTART OF CONTEXT: \n{context} \n\nEND OF CONTEXT!\n\nIf you don't know the answer or are unsure, just say that you don't know, don't try to make up an answer. Don't mention the context in your answer but rather just answer the question directly. Detect the language of the question and make sure to use the same language that was used in the question to answer the question. Don't mention which language was used, but just answer the question directly in the same langauge. \nQuestion: {question} Let's think this step-by-step. 
\n<|im_end|>\n<|im_start|> assistant\n" no_ctx_template: "<|im_start|> system \nYou're an AI assistant named Nextcloud Assistant.<|im_end|>\n<|im_start|> user\n{question}<|im_end|>\n<|im_start|> assistant\n" - end_separator: "<|im_end|>" config: context_length: 8192 max_new_tokens: 4096 diff --git a/context_chat_backend/chain/ingest/doc_loader.py b/context_chat_backend/chain/ingest/doc_loader.py index efb81b6d..832c8331 100644 --- a/context_chat_backend/chain/ingest/doc_loader.py +++ b/context_chat_backend/chain/ingest/doc_loader.py @@ -3,15 +3,13 @@ # SPDX-License-Identifier: AGPL-3.0-or-later # -import logging import re import tempfile from collections.abc import Callable -from typing import BinaryIO +from io import BytesIO import docx2txt from epub2txt import epub2txt -from fastapi import UploadFile from langchain_unstructured import UnstructuredLoader from odfdo import Document from pandas import read_csv, read_excel @@ -19,9 +17,10 @@ from pypdf.errors import FileNotDecryptedError as PdfFileNotDecryptedError from striprtf import striprtf -logger = logging.getLogger('ccb.doc_loader') +from ...types import IndexingException, SourceItem -def _temp_file_wrapper(file: BinaryIO, loader: Callable, sep: str = '\n') -> str: + +def _temp_file_wrapper(file: BytesIO, loader: Callable, sep: str = '\n') -> str: raw_bytes = file.read() with tempfile.NamedTemporaryFile(mode='wb') as tmp: tmp.write(raw_bytes) @@ -35,49 +34,49 @@ def _temp_file_wrapper(file: BinaryIO, loader: Callable, sep: str = '\n') -> str # -- LOADERS -- # -def _load_pdf(file: BinaryIO) -> str: +def _load_pdf(file: BytesIO) -> str: pdf_reader = PdfReader(file) return '\n\n'.join([page.extract_text().strip() for page in pdf_reader.pages]) -def _load_csv(file: BinaryIO) -> str: +def _load_csv(file: BytesIO) -> str: return read_csv(file).to_string(header=False, na_rep='') -def _load_epub(file: BinaryIO) -> str: +def _load_epub(file: BytesIO) -> str: return _temp_file_wrapper(file, epub2txt).strip() -def 
_load_docx(file: BinaryIO) -> str: +def _load_docx(file: BytesIO) -> str: return docx2txt.process(file).strip() -def _load_odt(file: BinaryIO) -> str: +def _load_odt(file: BytesIO) -> str: return _temp_file_wrapper(file, lambda fp: Document(fp).get_formatted_text()).strip() -def _load_ppt_x(file: BinaryIO) -> str: +def _load_ppt_x(file: BytesIO) -> str: return _temp_file_wrapper(file, lambda fp: UnstructuredLoader(fp).load()).strip() -def _load_rtf(file: BinaryIO) -> str: +def _load_rtf(file: BytesIO) -> str: return striprtf.rtf_to_text(file.read().decode('utf-8', 'ignore')).strip() -def _load_xml(file: BinaryIO) -> str: +def _load_xml(file: BytesIO) -> str: data = file.read().decode('utf-8', 'ignore') data = re.sub(r'', '', data) return data.strip() -def _load_xlsx(file: BinaryIO) -> str: +def _load_xlsx(file: BytesIO) -> str: return read_excel(file, na_filter=False).to_string(header=False, na_rep='') -def _load_email(file: BinaryIO, ext: str = 'eml') -> str | None: +def _load_email(file: BytesIO, ext: str = 'eml') -> str: # NOTE: msg format is not tested if ext not in ['eml', 'msg']: - return None + raise IndexingException(f'Unsupported email format: {ext}') # TODO: implement attachment partitioner using unstructured.partition.partition_{email,msg} # since langchain does not pass through the attachment_partitioner kwarg @@ -115,30 +114,36 @@ def attachment_partitioner( } -def decode_source(source: UploadFile) -> str | None: +def decode_source(source: SourceItem) -> str: + ''' + Raises + ------ + IndexingException + ''' + + io_obj: BytesIO | None = None try: # .pot files are powerpoint templates but also plain text files, # so we skip them to prevent decoding errors - if source.headers['title'].endswith('.pot'): - return None - - mimetype = source.headers['type'] - if mimetype is None: - return None - - if _loader_map.get(mimetype): - result = _loader_map[mimetype](source.file) - source.file.close() - return result.encode('utf-8', 'ignore').decode('utf-8', 
'ignore') - - result = source.file.read().decode('utf-8', 'ignore') - source.file.close() - return result - except PdfFileNotDecryptedError: - logger.warning(f'PDF file ({source.filename}) is encrypted and cannot be read') - return None - except Exception: - logger.exception(f'Error decoding source file ({source.filename})', stack_info=True) - return None + if source.title.endswith('.pot'): + raise IndexingException('PowerPoint template files (.pot) are not supported') + + if isinstance(source.content, str): + io_obj = BytesIO(source.content.encode('utf-8', 'ignore')) + else: + io_obj = source.content + + if _loader_map.get(source.type): + result = _loader_map[source.type](io_obj) + return result.encode('utf-8', 'ignore').decode('utf-8', 'ignore').strip() + + return io_obj.read().decode('utf-8', 'ignore').strip() + except IndexingException: + raise + except PdfFileNotDecryptedError as e: + raise IndexingException('PDF file is encrypted and cannot be read') from e + except Exception as e: + raise IndexingException(f'Error decoding source file: {e}') from e finally: - source.file.close() # Ensure file is closed after processing + if io_obj is not None: + io_obj.close() diff --git a/context_chat_backend/chain/ingest/injest.py b/context_chat_backend/chain/ingest/injest.py index 5871ebb8..ad2777ed 100644 --- a/context_chat_backend/chain/ingest/injest.py +++ b/context_chat_backend/chain/ingest/injest.py @@ -2,65 +2,240 @@ # SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors # SPDX-License-Identifier: AGPL-3.0-or-later # +import asyncio import logging import re +from collections.abc import Mapping +from io import BytesIO +from time import perf_counter_ns -from fastapi.datastructures import UploadFile +import niquests from langchain.schema import Document +from nc_py_api import AsyncNextcloudApp from ...dyn_loader import VectorDBLoader -from ...types import TConfig -from ...utils import is_valid_source_id, to_int +from ...types import IndexingError, 
IndexingException, ReceivedFileItem, SourceItem, TConfig from ...vectordb.base import BaseVectorDB from ...vectordb.types import DbException, SafeDbException, UpdateAccessOp from ..types import InDocument from .doc_loader import decode_source from .doc_splitter import get_splitter_for -from .mimetype_list import SUPPORTED_MIMETYPES logger = logging.getLogger('ccb.injest') -def _allowed_file(file: UploadFile) -> bool: - return file.headers['type'] in SUPPORTED_MIMETYPES +MAX_FILE_SIZE = 100 * 1024 * 1024 # 100 MB, all loaded in RAM at once + + +async def __fetch_file_content( + semaphore: asyncio.Semaphore, + file_id: int, + user_id: str, + _rlimit = 3, +) -> BytesIO: + ''' + Raises + ------ + IndexingException + ''' + + async with semaphore: + nc = AsyncNextcloudApp() + try: + # a file pointer for storing the stream in memory until it is consumed + fp = BytesIO() + await nc._session.download2fp( + url_path=f'/ocs/v2.php/apps/context_chat/files/{file_id}', + fp=fp, + dav=False, + params={ 'userId': user_id }, + ) + fp.seek(0) + return fp + except niquests.exceptions.RequestException as e: + if e.response is None: + raise + + if e.response.status_code == niquests.codes.too_many_requests: # pyright: ignore[reportAttributeAccessIssue] + # todo: implement rate limits in php CC? 
+ wait_for = int(e.response.headers.get('Retry-After', '30')) + if _rlimit <= 0: + raise IndexingException( + f'Rate limited when fetching content for file id {file_id}, user id {user_id},' + ' max retries exceeded', + retryable=True, + ) from e + logger.warning( + f'Rate limited when fetching content for file id {file_id}, user id {user_id},' + f' waiting {wait_for} before retrying', + exc_info=e, + ) + await asyncio.sleep(wait_for) + return await __fetch_file_content(semaphore, file_id, user_id, _rlimit - 1) + + raise + except IndexingException: + raise + except Exception as e: + logger.error(f'Error fetching content for file id {file_id}, user id {user_id}: {e}', exc_info=e) + raise IndexingException(f'Error fetching content for file id {file_id}, user id {user_id}: {e}') from e + + +async def __fetch_files_content( + sources: Mapping[int, SourceItem | ReceivedFileItem], + concurrent_file_fetches: int, +) -> tuple[Mapping[int, SourceItem], Mapping[int, IndexingError]]: + source_items = {} + error_items = {} + tasks = [] + task_sources = {} + semaphore = asyncio.Semaphore(concurrent_file_fetches) + + file_count = sum(1 for s in sources.values() if isinstance(s, ReceivedFileItem)) + logger.debug('Fetching content for %d file(s) (max %d concurrent)', file_count, concurrent_file_fetches) + + for db_id, file in sources.items(): + if isinstance(file, SourceItem): + continue + + try: + # to detect any validation errors but it should not happen since file.reference is validated + file.file_id # noqa: B018 + except ValueError as e: + logger.error( + f'Invalid file reference format for db id {db_id}, file reference {file.reference}: {e}', + exc_info=e, + ) + error_items[db_id] = IndexingError( + error=f'Invalid file reference format: {file.reference}', + retryable=False, + ) + continue + + if file.size > MAX_FILE_SIZE: + logger.info( + f'Skipping db id {db_id}, file id {file.file_id}, source id {file.reference} due to size' + f' {(file.size/(1024*1024)):.2f} MiB exceeding 
 the limit {(MAX_FILE_SIZE/(1024*1024)):.2f} MiB', + ) + error_items[db_id] = IndexingError( + error=( + f'File size {(file.size/(1024*1024)):.2f} MiB' + f' exceeds the limit {(MAX_FILE_SIZE/(1024*1024)):.2f} MiB' + ), + retryable=False, + ) + continue + # any user id from the list should have read access to the file + tasks.append(asyncio.ensure_future(__fetch_file_content(semaphore, file.file_id, file.userIds[0]))) + task_sources[db_id] = file + + results = await asyncio.gather(*tasks, return_exceptions=True) + for (db_id, file), result in zip(task_sources.items(), results, strict=True): + if isinstance(result, str) or isinstance(result, BytesIO): + source_items[db_id] = SourceItem( + **{ + **file.model_dump(), + 'content': result, + } + ) + elif isinstance(result, IndexingException): + logger.error( + f'Error fetching content for db id {db_id}, file id {file.file_id}, reference {file.reference}' + f': {result}', + exc_info=result, + ) + error_items[db_id] = IndexingError( + error=str(result), + retryable=result.retryable, + ) + elif isinstance(result, BaseException): + logger.error( + f'Unexpected error fetching content for db id {db_id}, file id {file.file_id},' + f' reference {file.reference}: {result}', + exc_info=result, + ) + error_items[db_id] = IndexingError( + error=f'Unexpected error: {result}', + retryable=True, + ) + else: + logger.error( + f'Unknown error fetching content for db id {db_id}, file id {file.file_id}, reference {file.reference}' + f': {result}', + exc_info=True, + ) + error_items[db_id] = IndexingError( + error='Unknown error', + retryable=True, + ) + + # add the content providers from the original "sources" to the result unprocessed + for db_id, source in sources.items(): + if isinstance(source, SourceItem): + source_items[db_id] = source + + return source_items, error_items def _filter_sources( vectordb: BaseVectorDB, - sources: list[UploadFile] -) -> tuple[list[UploadFile], list[UploadFile]]: + sources: Mapping[int, SourceItem | 
ReceivedFileItem] +) -> tuple[Mapping[int, SourceItem | ReceivedFileItem], Mapping[int, SourceItem | ReceivedFileItem]]: ''' Returns ------- - tuple[list[str], list[UploadFile]] + tuple[Mapping[int, SourceItem | ReceivedFileItem], Mapping[int, SourceItem | ReceivedFileItem]]: First value is a list of sources that already exist in the vectordb. Second value is a list of sources that are new and should be embedded. ''' try: - existing_sources, new_sources = vectordb.check_sources(sources) + existing_source_ids, to_embed_source_ids = vectordb.check_sources(sources) except Exception as e: - raise DbException('Error: Vectordb sources_to_embed error') from e + raise DbException('Error: Vectordb error while checking existing sources in indexing') from e + + existing_sources = {} + to_embed_sources = {} + + for db_id, source in sources.items(): + if source.reference in existing_source_ids: + existing_sources[db_id] = source + elif source.reference in to_embed_source_ids: + to_embed_sources[db_id] = source - return ([ - source for source in sources - if source.filename in existing_sources - ], [ - source for source in sources - if source.filename in new_sources - ]) + return existing_sources, to_embed_sources -def _sources_to_indocuments(config: TConfig, sources: list[UploadFile]) -> list[InDocument]: - indocuments = [] +def _sources_to_indocuments( + config: TConfig, + sources: Mapping[int, SourceItem] +) -> tuple[Mapping[int, InDocument], Mapping[int, IndexingError]]: + indocuments = {} + errored_docs = {} - for source in sources: - logger.debug('processing source', extra={ 'source_id': source.filename }) + for db_id, source in sources.items(): + logger.debug('processing source', extra={ 'source_id': source.reference }) # transform the source to have text data - content = decode_source(source) + try: + logger.debug('Decoding source %s (type: %s)', source.reference, source.type) + t0 = perf_counter_ns() + content = decode_source(source) + elapsed_ms = (perf_counter_ns() - 
t0) / 1e6 + logger.debug('Decoded source %s in %.2f ms (%d chars)', source.reference, elapsed_ms, len(content)) + except IndexingException as e: + logger.error(f'Error decoding source ({source.reference}): {e}', exc_info=e) + errored_docs[db_id] = IndexingError( + error=str(e), + retryable=False, + ) + continue - if content is None or (content := content.strip()) == '': - logger.debug('decoded empty source', extra={ 'source_id': source.filename }) + if content == '': + logger.debug('decoded empty source', extra={ 'source_id': source.reference }) + errored_docs[db_id] = IndexingError( + error='Decoded content is empty', + retryable=False, + ) continue # replace more than two newlines with two newlines (also blank spaces, more than 4) @@ -68,97 +243,151 @@ def _sources_to_indocuments(config: TConfig, sources: list[UploadFile]) -> list[ # NOTE: do not use this with all docs when programming files are added content = re.sub(r'(\s){5,}', r'\g<1>', content) # filter out null bytes - content = content.replace('\0', '') - - if content is None or content == '': - logger.debug('decoded empty source after cleanup', extra={ 'source_id': source.filename }) + content = content.replace('\0', '').strip() + + if content == '': + logger.debug('decoded empty source after cleanup', extra={ 'source_id': source.reference }) + errored_docs[db_id] = IndexingError( + error='Cleaned up content is empty', + retryable=False, + ) continue - logger.debug('decoded non empty source', extra={ 'source_id': source.filename }) + logger.debug('decoded non empty source', extra={ 'source_id': source.reference }) metadata = { - 'source': source.filename, - 'title': _decode_latin_1(source.headers['title']), - 'type': source.headers['type'], + 'source': source.reference, + 'title': _decode_latin_1(source.title), + 'type': source.type, } doc = Document(page_content=content, metadata=metadata) - splitter = get_splitter_for(config.embedding_chunk_size, source.headers['type']) + splitter = 
get_splitter_for(config.embedding_chunk_size, source.type) split_docs = splitter.split_documents([doc]) logger.debug('split document into chunks', extra={ - 'source_id': source.filename, + 'source_id': source.reference, 'len(split_docs)': len(split_docs), }) - indocuments.append(InDocument( + indocuments[db_id] = InDocument( documents=split_docs, - userIds=list(map(_decode_latin_1, source.headers['userIds'].split(','))), - source_id=source.filename, # pyright: ignore[reportArgumentType] - provider=source.headers['provider'], - modified=to_int(source.headers['modified']), - )) + userIds=list(map(_decode_latin_1, source.userIds)), + source_id=source.reference, + provider=source.provider, + modified=source.modified, # pyright: ignore[reportArgumentType] + ) - return indocuments + return indocuments, errored_docs + + +def _increase_access_for_existing_sources( + vectordb: BaseVectorDB, + existing_sources: Mapping[int, SourceItem | ReceivedFileItem] +) -> Mapping[int, IndexingError | None]: + ''' + update userIds for existing sources + allow the userIds as additional users, not as the only users + ''' + if len(existing_sources) == 0: + return {} + + results = {} + logger.debug('Increasing access for existing sources', extra={ + 'source_ids': [source.reference for source in existing_sources.values()] + }) + for db_id, source in existing_sources.items(): + try: + vectordb.update_access( + UpdateAccessOp.ALLOW, + list(map(_decode_latin_1, source.userIds)), + source.reference, + ) + results[db_id] = None + except SafeDbException as e: + logger.error(f'Failed to update access for source ({source.reference}): {e.args[0]}') + results[db_id] = IndexingError( + error=str(e), + retryable=False, + ) + continue + except Exception as e: + logger.error(f'Unexpected error while updating access for source ({source.reference}): {e}') + results[db_id] = IndexingError( + error='Unexpected error while updating access', + retryable=True, + ) + continue + return results def _process_sources( 
vectordb: BaseVectorDB, config: TConfig, - sources: list[UploadFile], -) -> tuple[list[str],list[str]]: + sources: Mapping[int, SourceItem | ReceivedFileItem] +) -> Mapping[int, IndexingError | None]: ''' Processes the sources and adds them to the vectordb. Returns the list of source ids that were successfully added and those that need to be retried. ''' - existing_sources, filtered_sources = _filter_sources(vectordb, sources) + existing_sources, to_embed_sources = _filter_sources(vectordb, sources) logger.debug('db filter source results', extra={ 'len(existing_sources)': len(existing_sources), 'existing_sources': existing_sources, - 'len(filtered_sources)': len(filtered_sources), - 'filtered_sources': filtered_sources, + 'len(to_embed_sources)': len(to_embed_sources), + 'to_embed_sources': to_embed_sources, }) - loaded_source_ids = [source.filename for source in existing_sources] - # update userIds for existing sources - # allow the userIds as additional users, not as the only users - if len(existing_sources) > 0: - logger.debug('Increasing access for existing sources', extra={ - 'source_ids': [source.filename for source in existing_sources] - }) - for source in existing_sources: - try: - vectordb.update_access( - UpdateAccessOp.allow, - list(map(_decode_latin_1, source.headers['userIds'].split(','))), - source.filename, # pyright: ignore[reportArgumentType] - ) - except SafeDbException as e: - logger.error(f'Failed to update access for source ({source.filename}): {e.args[0]}') - continue - - if len(filtered_sources) == 0: + source_proc_results = _increase_access_for_existing_sources(vectordb, existing_sources) + + logger.debug( + 'Fetching file contents for %d source(s) from Nextcloud', + len(to_embed_sources), + ) + t0 = perf_counter_ns() + populated_to_embed_sources, errored_sources = asyncio.run( + __fetch_files_content(to_embed_sources, config.concurrent_file_fetches) + ) + elapsed_ms = (perf_counter_ns() - t0) / 1e6 + logger.debug( + 'File content fetch 
complete in %.2f ms: %d fetched, %d errored', + elapsed_ms, len(populated_to_embed_sources), len(errored_sources), + ) + source_proc_results.update(errored_sources) # pyright: ignore[reportAttributeAccessIssue] + + if len(populated_to_embed_sources) == 0: # no new sources to embed logger.debug('Filtered all sources, nothing to embed') - return loaded_source_ids, [] # pyright: ignore[reportReturnType] + return source_proc_results logger.debug('Filtered sources:', extra={ - 'source_ids': [source.filename for source in filtered_sources] + 'source_ids': [source.reference for source in populated_to_embed_sources.values()] }) # invalid/empty sources are filtered out here and not counted in loaded/retryable - indocuments = _sources_to_indocuments(config, filtered_sources) + indocuments, errored_docs = _sources_to_indocuments(config, populated_to_embed_sources) - logger.debug('Converted all sources to documents') + source_proc_results.update(errored_docs) # pyright: ignore[reportAttributeAccessIssue] + logger.debug('Converted sources to documents') if len(indocuments) == 0: # filtered document(s) were invalid/empty, not an error logger.debug('All documents were found empty after being processed') - return loaded_source_ids, [] # pyright: ignore[reportReturnType] + return source_proc_results + + logger.debug('Adding documents to vectordb', extra={ + 'source_ids': [indoc.source_id for indoc in indocuments.values()] + }) - added_source_ids, retry_source_ids = vectordb.add_indocuments(indocuments) - loaded_source_ids.extend(added_source_ids) + t0 = perf_counter_ns() + doc_add_results = vectordb.add_indocuments(indocuments) + elapsed_ms = (perf_counter_ns() - t0) / 1e6 + logger.info( + 'vectordb.add_indocuments completed in %.2f ms for %d document(s)', + elapsed_ms, len(indocuments), + ) + source_proc_results.update(doc_add_results) # pyright: ignore[reportAttributeAccessIssue] logger.debug('Added documents to vectordb') - return loaded_source_ids, retry_source_ids # pyright: 
ignore[reportReturnType] + return source_proc_results def _decode_latin_1(s: str) -> str: @@ -172,31 +401,15 @@ def _decode_latin_1(s: str) -> str: def embed_sources( vectordb_loader: VectorDBLoader, config: TConfig, - sources: list[UploadFile], -) -> tuple[list[str],list[str]]: - # either not a file or a file that is allowed - sources_filtered = [ - source for source in sources - if is_valid_source_id(source.filename) # pyright: ignore[reportArgumentType] - or _allowed_file(source) - ] - + sources: Mapping[int, SourceItem | ReceivedFileItem] +) -> Mapping[int, IndexingError | None]: logger.debug('Embedding sources:', extra={ 'source_ids': [ - f'{source.filename} ({_decode_latin_1(source.headers["title"])})' - for source in sources_filtered - ], - 'invalid_source_ids': [ - source.filename for source in sources - if not is_valid_source_id(source.filename) # pyright: ignore[reportArgumentType] - ], - 'not_allowed_file_ids': [ - source.filename for source in sources - if not _allowed_file(source) + f'{source.reference} ({_decode_latin_1(source.title)})' + for source in sources.values() ], - 'len(source_ids)': len(sources_filtered), - 'len(total_source_ids)': len(sources), + 'len(source_ids)': len(sources), }) vectordb = vectordb_loader.load() - return _process_sources(vectordb, config, sources_filtered) + return _process_sources(vectordb, config, sources) diff --git a/context_chat_backend/chain/one_shot.py b/context_chat_backend/chain/one_shot.py index 1c0521bf..c3876217 100644 --- a/context_chat_backend/chain/one_shot.py +++ b/context_chat_backend/chain/one_shot.py @@ -10,39 +10,27 @@ from ..types import TConfig from .context import get_context_chunks, get_context_docs from .query_proc import get_pruned_query -from .types import ContextException, LLMOutput, ScopeType +from .types import ContextException, LLMOutput, ScopeType, SearchResult -_LLM_TEMPLATE = '''Answer based only on this context and do not add any imaginative details. 
Make sure to use the same language as the question in your answer. +_LLM_TEMPLATE = '''You're an AI assistant named Nextcloud Assistant, good at finding relevant context from documents to answer questions provided by the user. +Use the following documents as context to answer the question at the end. REMEMBER to exercise source criticism as the documents are returned by a search provider that can return unrelated documents. + +START OF CONTEXT: {context} -{question} -''' # noqa: E501 +END OF CONTEXT! -logger = logging.getLogger('ccb.chain') +If you don't know the answer or are unsure, just say that you don't know, don't try to make up an answer. +Don't mention the context in your answer but rather just answer the question directly. +Detect the language of the question and make sure to use the same language that was used in the question to answer the question. +Don't mention which language was used, but just answer the question directly in the same language. -def process_query( - user_id: str, - llm: LLM, - app_config: TConfig, - query: str, - no_ctx_template: str | None = None, - end_separator: str = '', -): - """ - Raises - ------ - ValueError - If the context length is too small to fit the query - """ - stop = [end_separator] if end_separator else None - output = llm.invoke( - (query, get_pruned_query(llm, app_config, query, no_ctx_template, []))[no_ctx_template is not None], # pyright: ignore[reportArgumentType] - stop=stop, - userid=user_id, - ).strip() +Question: {question} - return LLMOutput(output=output, sources=[]) +Let's think through this step-by-step. 
+''' # noqa: E501 +logger = logging.getLogger('ccb.chain') def process_context_query( user_id: str, @@ -54,7 +42,6 @@ def process_context_query( scope_type: ScopeType | None = None, scope_list: list[str] | None = None, template: str | None = None, - end_separator: str = '', ): """ Raises @@ -75,9 +62,11 @@ def process_context_query( output = llm.invoke( get_pruned_query(llm, app_config, query, template or _LLM_TEMPLATE, context_chunks), - stop=[end_separator], userid=user_id, ).strip() - unique_sources: list[str] = list({source for d in context_docs if (source := d.metadata.get('source'))}) + unique_sources = [SearchResult( + source_id=source, + title=d.metadata.get('title', ''), + ) for d in context_docs if (source := d.metadata.get('source'))] return LLMOutput(output=output, sources=unique_sources) diff --git a/context_chat_backend/chain/types.py b/context_chat_backend/chain/types.py index b006ad1a..3afdf297 100644 --- a/context_chat_backend/chain/types.py +++ b/context_chat_backend/chain/types.py @@ -33,12 +33,24 @@ class ContextException(Exception): ... 
+class SearchResult(TypedDict): + source_id: str + title: str + + class LLMOutput(TypedDict): output: str - sources: list[str] - # todo: add "titles" field + sources: list[SearchResult] -class SearchResult(TypedDict): - source_id: str - title: str +class EnrichedSource(BaseModel): + id: str + label: str + icon: str + url: str + +class EnrichedSourceList(BaseModel): + sources: list[EnrichedSource] + +class ScopeList(BaseModel): + source_ids: list[str] diff --git a/context_chat_backend/config_parser.py b/context_chat_backend/config_parser.py index dafef75f..0a62019a 100644 --- a/context_chat_backend/config_parser.py +++ b/context_chat_backend/config_parser.py @@ -103,17 +103,11 @@ def get_config(file_path: str) -> TConfig: except Exception as e: raise AssertionError('Error: could not create embedding config from config file') from e - return TConfig( - debug=config.get('debug', False), - uvicorn_log_level=config.get('uvicorn_log_level', 'info'), - disable_aaa=config.get('disable_aaa', False), - verify_ssl=config.get('verify_ssl', config.get('httpx_verify_ssl', True)), - use_colors=config.get('use_colors', True), - uvicorn_workers=config.get('uvicorn_workers', 1), - embedding_chunk_size=config.get('embedding_chunk_size', 1000), - doc_parser_worker_limit=config.get('doc_parser_worker_limit', 10), - - vectordb=vectordb, - embedding=embedding_config, - llm=llm, - ) + config['verify_ssl'] = config.get('verify_ssl', config.get('httpx_verify_ssl', True)) + config.pop('httpx_verify_ssl', None) + + config['llm'] = llm + config['vectordb'] = vectordb + config['embedding'] = embedding_config + + return TConfig(**config) diff --git a/context_chat_backend/controller.py b/context_chat_backend/controller.py index c26b930a..9c3812e9 100644 --- a/context_chat_backend/controller.py +++ b/context_chat_backend/controller.py @@ -2,11 +2,12 @@ # SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors # SPDX-License-Identifier: AGPL-3.0-or-later # +from 
nc_py_api.ex_app.providers.task_processing import TaskProcessingProvider # isort: off -from .chain.types import ContextException, LLMOutput, ScopeType, SearchResult +from .chain.types import ContextException from .types import LoaderException, EmbeddingException -from .vectordb.types import DbException, SafeDbException, UpdateAccessOp +from .vectordb.types import DbException, SafeDbException from .setup_functions import ensure_config_file, repair_run, setup_env_vars # setup env vars before importing other modules @@ -23,39 +24,29 @@ from collections.abc import Callable from contextlib import asynccontextmanager from functools import wraps -from threading import Event, Thread -from time import sleep -from typing import Annotated, Any -from fastapi import Body, FastAPI, Request, UploadFile -from langchain.llms.base import LLM +from fastapi import FastAPI, Request from nc_py_api import AsyncNextcloudApp, NextcloudApp from nc_py_api.ex_app import persistent_storage, set_handlers -from pydantic import BaseModel, ValidationInfo, field_validator from starlette.responses import FileResponse -from .chain.context import do_doc_search -from .chain.ingest.injest import embed_sources -from .chain.one_shot import process_context_query, process_query from .config_parser import get_config -from .dyn_loader import LLMModelLoader, VectorDBLoader +from .dyn_loader import VectorDBLoader from .models.types import LlmException from nc_py_api.ex_app import AppAPIAuthMiddleware -from .utils import JSONResponse, exec_in_proc, is_valid_provider_id, is_valid_source_id, value_of -from .vectordb.service import ( - count_documents_by_provider, - decl_update_access, - delete_by_provider, - delete_by_source, - delete_user, - update_access, -) +from .utils import JSONResponse, exec_in_proc +from .task_fetcher import start_bg_threads, trigger_handler, wait_for_bg_threads +from .vectordb.service import count_documents_by_provider # setup -repair_run() -ensure_config_file() +# only run once +if 
mp.current_process().name == 'MainProcess': + repair_run() + ensure_config_file() + logger = logging.getLogger('ccb.controller') +app_config = get_config(os.environ['CC_CONFIG_PATH']) __download_models_from_hf = os.environ.get('CC_DOWNLOAD_MODELS_FROM_HF', 'true').lower() in ('1', 'true', 'yes') models_to_fetch = { @@ -70,13 +61,33 @@ 'revision': '607a30d783dfa663caf39e06633721c8d4cfcd7e', } } if __download_models_from_hf else {} -app_enabled = Event() +app_enabled = threading.Event() -def enabled_handler(enabled: bool, _: NextcloudApp | AsyncNextcloudApp) -> str: - if enabled: - app_enabled.set() - else: - app_enabled.clear() +def enabled_handler(enabled: bool, nc: NextcloudApp | AsyncNextcloudApp) -> str: + try: + if enabled: + provider = TaskProcessingProvider( + id="context_chat-context_chat_search", + name="Context Chat", + task_type="context_chat:context_chat_search", + expected_runtime=30, + ) + nc.providers.task_processing.register(provider) + provider = TaskProcessingProvider( + id="context_chat-context_chat", + name="Context Chat", + task_type="context_chat:context_chat", + expected_runtime=30, + ) + nc.providers.task_processing.register(provider) + app_enabled.set() + start_bg_threads(app_config, app_enabled) + else: + app_enabled.clear() + wait_for_bg_threads() + except Exception as e: + logger.exception('Error in enabled handler:', exc_info=e) + return f'Error in enabled handler: {e}' logger.info(f'App {("disabled", "enabled")[enabled]}') return '' @@ -84,19 +95,17 @@ def enabled_handler(enabled: bool, _: NextcloudApp | AsyncNextcloudApp) -> str: @asynccontextmanager async def lifespan(app: FastAPI): - set_handlers(app, enabled_handler, models_to_fetch=models_to_fetch) + set_handlers(app, enabled_handler, models_to_fetch=models_to_fetch, trigger_handler=trigger_handler) nc = NextcloudApp() if nc.enabled_state: app_enabled.set() + start_bg_threads(app_config, app_enabled) logger.info(f'App enable state at startup: {app_enabled.is_set()}') - t = 
Thread(target=background_thread_task, args=()) - t.start() yield vectordb_loader.offload() - llm_loader.offload() + wait_for_bg_threads() -app_config = get_config(os.environ['CC_CONFIG_PATH']) app = FastAPI(debug=app_config.debug, lifespan=lifespan) # pyright: ignore[reportArgumentType] app.extra['CONFIG'] = app_config @@ -105,7 +114,6 @@ async def lifespan(app: FastAPI): # loaders vectordb_loader = VectorDBLoader(app_config) -llm_loader = LLMModelLoader(app, app_config) # locks and semaphores @@ -117,22 +125,12 @@ async def lifespan(app: FastAPI): index_lock = threading.Lock() _indexing = {} -# limit the number of concurrent document parsing -doc_parse_semaphore = mp.Semaphore(app_config.doc_parser_worker_limit) - # middlewares if not app_config.disable_aaa: app.add_middleware(AppAPIAuthMiddleware) -# logger background thread - -def background_thread_task(): - while(True): - logger.info(f'Currently indexing {len(_indexing)} documents (filename, size): ', extra={'_indexing': _indexing}) - sleep(10) - # exception handlers @app.exception_handler(DbException) @@ -213,121 +211,6 @@ def _(): return JSONResponse(content={'enabled': app_enabled.is_set()}, status_code=200) -@app.post('/updateAccessDeclarative') -@enabled_guard(app) -def _( - userIds: Annotated[list[str], Body()], - sourceId: Annotated[str, Body()], -): - logger.debug('Update access declarative request:', extra={ - 'user_ids': userIds, - 'source_id': sourceId, - }) - - if len(userIds) == 0: - return JSONResponse('Empty list of user ids', 400) - - if not is_valid_source_id(sourceId): - return JSONResponse('Invalid source id', 400) - - exec_in_proc(target=decl_update_access, args=(vectordb_loader, userIds, sourceId)) - - return JSONResponse('Access updated') - - -@app.post('/updateAccess') -@enabled_guard(app) -def _( - op: Annotated[UpdateAccessOp, Body()], - userIds: Annotated[list[str], Body()], - sourceId: Annotated[str, Body()], -): - logger.debug('Update access request', extra={ - 'op': op, - 
'user_ids': userIds, - 'source_id': sourceId, - }) - - if len(userIds) == 0: - return JSONResponse('Empty list of user ids', 400) - - if not is_valid_source_id(sourceId): - return JSONResponse('Invalid source id', 400) - - exec_in_proc(target=update_access, args=(vectordb_loader, op, userIds, sourceId)) - - return JSONResponse('Access updated') - - -@app.post('/updateAccessProvider') -@enabled_guard(app) -def _( - op: Annotated[UpdateAccessOp, Body()], - userIds: Annotated[list[str], Body()], - providerId: Annotated[str, Body()], -): - logger.debug('Update access by provider request', extra={ - 'op': op, - 'user_ids': userIds, - 'provider_id': providerId, - }) - - if len(userIds) == 0: - return JSONResponse('Empty list of user ids', 400) - - if not is_valid_provider_id(providerId): - return JSONResponse('Invalid provider id', 400) - - exec_in_proc(target=update_access, args=(vectordb_loader, op, userIds, providerId)) - - return JSONResponse('Access updated') - - -@app.post('/deleteSources') -@enabled_guard(app) -def _(sourceIds: Annotated[list[str], Body(embed=True)]): - logger.debug('Delete sources request', extra={ - 'source_ids': sourceIds, - }) - - sourceIds = [source.strip() for source in sourceIds if source.strip() != ''] - - if len(sourceIds) == 0: - return JSONResponse('No sources provided', 400) - - res = exec_in_proc(target=delete_by_source, args=(vectordb_loader, sourceIds)) - if res is False: - return JSONResponse('Error: VectorDB delete failed, check vectordb logs for more info.', 400) - - return JSONResponse('All valid sources deleted') - - -@app.post('/deleteProvider') -@enabled_guard(app) -def _(providerKey: str = Body(embed=True)): - logger.debug('Delete sources by provider for all users request', extra={ 'provider_key': providerKey }) - - if value_of(providerKey) is None: - return JSONResponse('Invalid provider key provided', 400) - - exec_in_proc(target=delete_by_provider, args=(vectordb_loader, providerKey)) - - return JSONResponse('All valid 
sources deleted') - - -@app.post('/deleteUser') -@enabled_guard(app) -def _(userId: str = Body(embed=True)): - logger.debug('Remove access list for user, and orphaned sources', extra={ 'user_id': userId }) - - if value_of(userId) is None: - return JSONResponse('Invalid userId provided', 400) - - exec_in_proc(target=delete_user, args=(vectordb_loader, userId)) - - return JSONResponse('User deleted') - - @app.post('/countIndexedDocuments') @enabled_guard(app) def _(): @@ -335,177 +218,6 @@ def _(): return JSONResponse(counts) -@app.put('/loadSources') -@enabled_guard(app) -def _(sources: list[UploadFile]): - global _indexing - - if len(sources) == 0: - return JSONResponse('No sources provided', 400) - - filtered_sources = [] - - for source in sources: - if not value_of(source.filename): - logger.warning('Skipping source with invalid source_id', extra={ - 'source_id': source.filename, - 'title': source.headers.get('title'), - }) - continue - - with index_lock: - if source.filename in _indexing: - # this request will be retried by the client - return JSONResponse( - f'This source ({source.filename}) is already being processed in another request, try again later', - 503, - headers={'cc-retry': 'true'}, - ) - - if not ( - value_of(source.headers.get('userIds')) - and source.headers.get('title', None) is not None - and value_of(source.headers.get('type')) - and value_of(source.headers.get('modified')) - and source.headers['modified'].isdigit() - and value_of(source.headers.get('provider')) - ): - logger.warning('Skipping source with invalid/missing headers', extra={ - 'source_id': source.filename, - 'title': source.headers.get('title'), - 'headers': source.headers, - }) - continue - - filtered_sources.append(source) - - # wait for 10 minutes before failing the request - semres = doc_parse_semaphore.acquire(block=True, timeout=10*60) - if not semres: - return JSONResponse( - 'Document parser worker limit reached, try again in some time or consider increasing the limit', - 
503, - headers={'cc-retry': 'true'} - ) - - with index_lock: - for source in filtered_sources: - _indexing[source.filename] = source.size - - try: - loaded_sources, not_added_sources = exec_in_proc( - target=embed_sources, - args=(vectordb_loader, app.extra['CONFIG'], filtered_sources) - ) - except (DbException, EmbeddingException): - raise - except Exception as e: - raise DbException('Error: failed to load sources') from e - finally: - with index_lock: - for source in filtered_sources: - _indexing.pop(source.filename, None) - doc_parse_semaphore.release() - - if len(loaded_sources) != len(filtered_sources): - logger.debug('Some sources were not loaded', extra={ - 'Count of loaded sources': f'{len(loaded_sources)}/{len(filtered_sources)}', - 'source_ids': loaded_sources, - }) - - # loaded sources include the existing sources that may only have their access updated - return JSONResponse({'loaded_sources': loaded_sources, 'sources_to_retry': not_added_sources}) - - -class Query(BaseModel): - userId: str - query: str - useContext: bool = True - scopeType: ScopeType | None = None - scopeList: list[str] | None = None - ctxLimit: int = 20 - - @field_validator('userId', 'query', 'ctxLimit') - @classmethod - def check_empty_values(cls, value: Any, info: ValidationInfo): - if value_of(value) is None: - raise ValueError('Empty value for field', info.field_name) - - return value - - @field_validator('ctxLimit') - @classmethod - def at_least_one_context(cls, value: int): - if value < 1: - raise ValueError('Invalid context chunk limit') - - return value - - -def execute_query(query: Query, in_proc: bool = True) -> LLMOutput: - llm: LLM = llm_loader.load() - template = app.extra.get('LLM_TEMPLATE') - no_ctx_template = app.extra['LLM_NO_CTX_TEMPLATE'] - # todo: array - end_separator = app.extra.get('LLM_END_SEPARATOR', '') - - if query.useContext: - target = process_context_query - args=( - query.userId, - vectordb_loader, - llm, - app_config, - query.query, - query.ctxLimit, - 
query.scopeType, - query.scopeList, - template, - end_separator, - ) - else: - target=process_query - args=( - query.userId, - llm, - app_config, - query.query, - no_ctx_template, - end_separator, - ) - - if in_proc: - return exec_in_proc(target=target, args=args) - - return target(*args) # pyright: ignore - - -@app.post('/query') -@enabled_guard(app) -def _(query: Query) -> LLMOutput: - logger.debug('received query request', extra={ 'query': query.dict() }) - - if app_config.llm[0] == 'nc_texttotext': - return execute_query(query) - - with llm_lock: - return execute_query(query, in_proc=False) - - -@app.post('/docSearch') -@enabled_guard(app) -def _(query: Query) -> list[SearchResult]: - # useContext from Query is not used here - return exec_in_proc(target=do_doc_search, args=( - query.userId, - query.query, - vectordb_loader, - query.ctxLimit, - query.scopeType, - query.scopeList, - )) - - @app.get('/downloadLogs') def download_logs() -> FileResponse: with tempfile.NamedTemporaryFile('wb', delete=False) as tmp: diff --git a/context_chat_backend/dyn_loader.py b/context_chat_backend/dyn_loader.py index d67310ff..47b19575 100644 --- a/context_chat_backend/dyn_loader.py +++ b/context_chat_backend/dyn_loader.py @@ -7,11 +7,9 @@ import gc import logging from abc import ABC, abstractmethod -from time import time from typing import Any import torch -from fastapi import FastAPI from langchain.llms.base import LLM from .models.loader import init_model @@ -54,19 +52,11 @@ def offload(self) -> None: class LLMModelLoader(Loader): - def __init__(self, app: FastAPI, config: TConfig) -> None: + def __init__(self, config: TConfig) -> None: self.config = config - self.app = app def load(self) -> LLM: - if self.app.extra.get('LLM_MODEL') is not None: - self.app.extra['LLM_LAST_ACCESSED'] = time() - return self.app.extra['LLM_MODEL'] - llm_name, llm_config = self.config.llm - self.app.extra['LLM_TEMPLATE'] = llm_config.pop('template', '') - self.app.extra['LLM_NO_CTX_TEMPLATE'] = 
llm_config.pop('no_ctx_template', '') - self.app.extra['LLM_END_SEPARATOR'] = llm_config.pop('end_separator', '') try: model = init_model('llm', (llm_name, llm_config)) @@ -75,13 +65,9 @@ def load(self) -> LLM: if not isinstance(model, LLM): raise LoaderException(f'Error: {model} does not implement "llm" type or has returned an invalid object') - self.app.extra['LLM_MODEL'] = model - self.app.extra['LLM_LAST_ACCESSED'] = time() return model def offload(self) -> None: - if self.app.extra.get('LLM_MODEL') is not None: - del self.app.extra['LLM_MODEL'] clear_cache() diff --git a/context_chat_backend/chain/ingest/mimetype_list.py b/context_chat_backend/mimetype_list.py similarity index 100% rename from context_chat_backend/chain/ingest/mimetype_list.py rename to context_chat_backend/mimetype_list.py diff --git a/context_chat_backend/network_em.py b/context_chat_backend/network_em.py index 18bb11f4..5ba8faf5 100644 --- a/context_chat_backend/network_em.py +++ b/context_chat_backend/network_em.py @@ -3,14 +3,16 @@ # SPDX-License-Identifier: AGPL-3.0-or-later # import logging +import socket from time import sleep from typing import Literal, TypedDict +from urllib.parse import urlparse import niquests from langchain_core.embeddings import Embeddings -from pydantic import BaseModel from .types import ( + DocErrorEmbeddingException, EmbeddingException, FatalEmbeddingException, RetryableEmbeddingException, @@ -20,6 +22,7 @@ ) logger = logging.getLogger('ccb.nextwork_em') +TCP_CONNECT_TIMEOUT = 2.0 # seconds # Copied from llama_cpp/llama_types.py @@ -41,8 +44,35 @@ class CreateEmbeddingResponse(TypedDict): usage: EmbeddingUsage -class NetworkEmbeddings(Embeddings, BaseModel): - app_config: TConfig +class NetworkEmbeddings(Embeddings): + def __init__(self, app_config: TConfig): + self.app_config = app_config + + def _get_host_and_port(self) -> tuple[str, int]: + parsed = urlparse(self.app_config.embedding.base_url) + host = parsed.hostname + + if not host: + raise 
ValueError("Invalid URL: Missing hostname") + + if parsed.port: + port = parsed.port + else: + port = 443 if parsed.scheme == "https" else 80 + + return host, port + + def check_connection(self, check_origin: str) -> bool: + try: + host, port = self._get_host_and_port() + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(TCP_CONNECT_TIMEOUT) + sock.connect((host, port)) + sock.close() + return True + except (ValueError, TimeoutError, ConnectionRefusedError, socket.gaierror) as e: + logger.warning(f'[{check_origin}] Embedding server is not reachable, retrying after some time: {e}') + return False def _get_embedding(self, input_: str | list[str], try_: int = 3) -> list[float] | list[list[float]]: emconf = self.app_config.embedding @@ -76,13 +106,27 @@ def _get_embedding(self, input_: str | list[str], try_: int = 3) -> list[float] if response.status_code is None: raise EmbeddingException('Error: no response from embedding service') if response.status_code // 100 == 4: - raise FatalEmbeddingException(response.text) + raise FatalEmbeddingException( + response.text or f'Error: embedding request returned non-2xx status code {response.status_code}', + ) if response.status_code // 100 != 2: - raise EmbeddingException(response.text) + raise EmbeddingException( + response.text or f'Error: embedding request returned non-2xx status code {response.status_code}', + response, + ) except FatalEmbeddingException as e: logger.error('Fatal error while getting embeddings: %s', str(e), exc_info=e) raise e except EmbeddingException as e: + try: + if e.response is not None: + err_msg = e.response.json().get('error', {}).get('message', '') + if err_msg == 'llama_decode returned -1': + # the document could not be processed + raise DocErrorEmbeddingException(f'Failed to embed the document: {err_msg}') from e + except niquests.exceptions.JSONDecodeError: + ... 
+ if try_ > 0: logger.debug('Retrying embedding request in 5 secs', extra={'try': try_}) sleep(5) @@ -108,10 +152,14 @@ def _get_embedding(self, input_: str | list[str], try_: int = 3) -> list[float] logger.error('Unexpected error while getting embeddings', exc_info=e) raise EmbeddingException('Error: unexpected error while getting embeddings') from e - # converts TypedDict to a pydantic model - resp = CreateEmbeddingResponse(**response.json()) - if isinstance(input_, str): - return resp['data'][0]['embedding'] + try: + # converts TypedDict to a pydantic model + resp = CreateEmbeddingResponse(**response.json()) + if isinstance(input_, str): + return resp['data'][0]['embedding'] + except Exception as e: + logger.error('Error parsing embedding response', exc_info=e) + raise EmbeddingException('Error: failed to parse embedding response') from e # only one embedding in d['embedding'] since truncate is True return [d['embedding'] for d in resp['data']] # pyright: ignore[reportReturnType] diff --git a/context_chat_backend/task_fetcher.py b/context_chat_backend/task_fetcher.py new file mode 100644 index 00000000..d8fee7c0 --- /dev/null +++ b/context_chat_backend/task_fetcher.py @@ -0,0 +1,752 @@ +# +# SPDX-FileCopyrightText: 2026 Nextcloud GmbH and Nextcloud contributors +# SPDX-License-Identifier: AGPL-3.0-or-later +# +import logging +import math +import os +from collections.abc import Mapping +from concurrent.futures import ThreadPoolExecutor +from contextlib import suppress +from enum import Enum +from threading import Event, Thread +from time import sleep +from typing import Any + +import niquests +from langchain.llms.base import LLM +from nc_py_api import NextcloudApp, NextcloudException +from niquests import JSONDecodeError, RequestException +from pydantic import ValidationError + +from .chain.context import do_doc_search +from .chain.ingest.injest import embed_sources +from .chain.one_shot import process_context_query +from .chain.types import ContextException, 
EnrichedSourceList, LLMOutput, ScopeList, SearchResult +from .dyn_loader import LLMModelLoader, VectorDBLoader +from .network_em import NetworkEmbeddings +from .types import ( + ActionsQueueItems, + ActionType, + AppRole, + EmbeddingException, + FilesQueueItems, + IndexingError, + LoaderException, + ReceivedFileItem, + SourceItem, + TConfig, +) +from .utils import SubprocessKilledError, exec_in_proc, get_app_role +from .vectordb.service import ( + decl_update_access, + delete_by_provider, + delete_by_source, + delete_user, + update_access, + update_access_provider, +) +from .vectordb.types import DbException, SafeDbException + +APP_ROLE = get_app_role() +THREADS = {} +THREAD_STOP_EVENT = Event() +LOGGER = logging.getLogger('ccb.task_fetcher') +MIN_FILES_PER_CPU = 4 +POLLING_COOLDOWN = 30 + +# task processing or request processing +TP_TRIGGER = Event() +TP_CHECK_INTERVAL = 5 +TP_CHECK_INTERVAL_WITH_TRIGGER = 5 * 60 +TP_CHECK_INTERVAL_ON_ERROR = 15 +CONTEXT_LIMIT=20 + + +class ThreadType(Enum): + FILES_INDEXING = 'files_indexing' + UPDATES_PROCESSING = 'updates_processing' + REQUEST_PROCESSING = 'request_processing' + + +def files_indexing_thread(app_config: TConfig, app_enabled: Event) -> None: + try: + network_em = NetworkEmbeddings(app_config) + vectordb_loader = VectorDBLoader(app_config) + except LoaderException as e: + LOGGER.error('Error initializing vector DB loader, files indexing thread will not start:', exc_info=e) + return + + def _load_sources(source_items: Mapping[int, SourceItem | ReceivedFileItem]) -> Mapping[int, IndexingError | None]: + source_refs = [s.reference for s in source_items.values()] + LOGGER.info('Starting embed_sources subprocess for %d source(s)', len(source_items), extra={ + 'source_ids': source_refs, + }) + try: + result = exec_in_proc( + target=embed_sources, + args=(vectordb_loader, app_config, source_items), + ) + errors = {k: v for k, v in result.items() if isinstance(v, IndexingError)} + LOGGER.info( + 'embed_sources finished 
for %d source(s): %d succeeded, %d errored', + len(source_items), len(result) - len(errors), len(errors), + extra={'errors': errors}, + ) + return result + except SubprocessKilledError as e: + LOGGER.error( + 'embed_sources subprocess was killed for %d source(s) with exitcode %s', + len(source_items), e.exitcode, exc_info=e, extra={ + 'source_ids': source_refs, + }, + ) + if len(source_items) == 1: + return dict.fromkeys( + source_items, + IndexingError(error=f'Subprocess killed with exitcode {e.exitcode}: {e}', retryable=False), + ) + + # Fall back to one-by-one to isolate the problematic file. + LOGGER.warning( + 'Falling back to individual processing for %d sources', + len(source_items), + ) + fallback: dict[int, IndexingError | None] = {} + for db_id, item in source_items.items(): + fallback.update(_load_sources({db_id: item})) + return fallback + except Exception as e: + err = IndexingError( + error=f'{e.__class__.__name__}: {e}', + retryable=True, + ) + LOGGER.error( + 'embed_sources subprocess raised a %s error for %d sources, marking all as retryable', + e.__class__.__name__, len(source_refs), exc_info=e, extra={ + 'source_ids': source_refs, + } + ) + return dict.fromkeys(source_items, err) + + + # divides the batch into these many chunks + file_parsing_cpu_count = ( + app_config.file_parsing_cpu_count, # when set to a positive value + max(1, (os.cpu_count() or 2) - 1), # when set to auto (-1) + )[app_config.file_parsing_cpu_count == -1] + LOGGER.info(f'Using {file_parsing_cpu_count} parallel file parsing workers') + + while True: + if THREAD_STOP_EVENT.is_set(): + LOGGER.info('Files indexing thread is stopping due to stop event being set') + return + + try: + if not network_em.check_connection(ThreadType.FILES_INDEXING.value): + sleep(POLLING_COOLDOWN) + continue + + nc = NextcloudApp() + q_items_res = nc.ocs( + 'GET', + '/ocs/v2.php/apps/context_chat/queues/documents', + params={ 'n': app_config.doc_indexing_batch_size } + ) + + try: + q_items: 
FilesQueueItems = FilesQueueItems.model_validate(q_items_res) + except ValidationError as e: + raise Exception(f'Error validating queue items response: {e}\nResponse content: {q_items_res}') from e + + if not q_items.files and not q_items.content_providers: + LOGGER.debug('No documents to index') + sleep(POLLING_COOLDOWN) + continue + + files_result = {} + providers_result = {} + + # chunk file parsing for better file operation parallelism + file_chunk_size = max(MIN_FILES_PER_CPU, math.ceil(len(q_items.files) / file_parsing_cpu_count)) + file_chunks = [ + dict(list(q_items.files.items())[i:i+file_chunk_size]) + for i in range(0, len(q_items.files), file_chunk_size) + ] + provider_chunk_size = max( + MIN_FILES_PER_CPU, + math.ceil(len(q_items.content_providers) / file_parsing_cpu_count), + ) + provider_chunks = [ + dict(list(q_items.content_providers.items())[i:i+provider_chunk_size]) + for i in range(0, len(q_items.content_providers), provider_chunk_size) + ] + + with ThreadPoolExecutor( + max_workers=file_parsing_cpu_count, + thread_name_prefix='IndexingPool', + ) as executor: + LOGGER.info( + 'Dispatching %d file chunk(s) and %d provider chunk(s)', + len(file_chunks), len(provider_chunks), + ) + file_futures = [executor.submit(_load_sources, chunk) for chunk in file_chunks] + provider_futures = [executor.submit(_load_sources, chunk) for chunk in provider_chunks] + + for i, future in enumerate(file_futures): + LOGGER.debug('Waiting for file chunk %d/%d future to complete', i + 1, len(file_futures)) + files_result.update(future.result()) + LOGGER.debug('File chunk %d/%d future completed', i + 1, len(file_futures)) + for i, future in enumerate(provider_futures): + LOGGER.debug('Waiting for provider chunk %d/%d future to complete', i + 1, len(provider_futures)) + providers_result.update(future.result()) + LOGGER.debug('Provider chunk %d/%d future completed', i + 1, len(provider_futures)) + + if ( + any(isinstance(res, IndexingError) for res in files_result.values()) 
+ or any(isinstance(res, IndexingError) for res in providers_result.values()) + ): + LOGGER.error('Some sources failed to index', extra={ + 'file_errors': { + db_id: error + for db_id, error in files_result.items() + if isinstance(error, IndexingError) + }, + 'provider_errors': { + provider_id: error + for provider_id, error in providers_result.items() + if isinstance(error, IndexingError) + }, + }) + except ( + niquests.exceptions.ConnectionError, + niquests.exceptions.Timeout, + ) as e: + LOGGER.info('Temporary error fetching documents to index, will retry:', exc_info=e) + sleep(5) + continue + except Exception as e: + LOGGER.exception('Error fetching documents to index:', exc_info=e) + sleep(5) + continue + + # delete the entries from the PHP side queue where indexing succeeded or the error is not retryable + to_delete_files_db_ids = [ + db_id for db_id, result in files_result.items() + if result is None or (isinstance(result, IndexingError) and not result.retryable) + ] + to_delete_provider_db_ids = [ + db_id for db_id, result in providers_result.items() + if result is None or (isinstance(result, IndexingError) and not result.retryable) + ] + + try: + nc.ocs( + 'DELETE', + '/ocs/v2.php/apps/context_chat/queues/documents/', + json={ + 'files': to_delete_files_db_ids, + 'content_providers': to_delete_provider_db_ids, + }, + ) + except ( + niquests.exceptions.ConnectionError, + niquests.exceptions.Timeout, + ) as e: + LOGGER.info('Temporary error reporting indexing results, will retry:', exc_info=e) + sleep(5) + with suppress(Exception): + nc = NextcloudApp() + nc.ocs( + 'DELETE', + '/ocs/v2.php/apps/context_chat/queues/documents/', + json={ + 'files': to_delete_files_db_ids, + 'content_providers': to_delete_provider_db_ids, + }, + ) + continue + except Exception as e: + LOGGER.exception('Error reporting indexing results:', exc_info=e) + sleep(5) + continue + + + +def updates_processing_thread(app_config: TConfig, app_enabled: Event) -> None: + try: + 
vectordb_loader = VectorDBLoader(app_config) + except LoaderException as e: + LOGGER.error('Error initializing vector DB loader, files indexing thread will not start:', exc_info=e) + return + + while True: + if THREAD_STOP_EVENT.is_set(): + LOGGER.info('Updates processing thread is stopping due to stop event being set') + return + + try: + nc = NextcloudApp() + q_items_res = nc.ocs( + 'GET', + '/ocs/v2.php/apps/context_chat/queues/actions', + params={ 'n': app_config.actions_batch_size } + ) + + try: + q_items: ActionsQueueItems = ActionsQueueItems.model_validate(q_items_res) + except ValidationError as e: + raise Exception(f'Error validating queue items response: {e}\nResponse content: {q_items_res}') from e + except ( + niquests.exceptions.ConnectionError, + niquests.exceptions.Timeout, + ) as e: + LOGGER.info('Temporary error fetching updates to process, will retry:', exc_info=e) + sleep(5) + continue + except Exception as e: + LOGGER.exception('Error fetching updates to process:', exc_info=e) + sleep(5) + continue + + if not q_items.actions: + LOGGER.debug('No updates to process') + sleep(POLLING_COOLDOWN) + continue + + processed_event_ids = [] + errored_events = {} + for i, (db_id, action_item) in enumerate(q_items.actions.items()): + try: + match action_item.type: + case ActionType.DELETE_SOURCE_IDS: + exec_in_proc(target=delete_by_source, args=(vectordb_loader, action_item.payload.sourceIds)) + + case ActionType.DELETE_PROVIDER_ID: + exec_in_proc(target=delete_by_provider, args=(vectordb_loader, action_item.payload.providerId)) + + case ActionType.DELETE_USER_ID: + exec_in_proc(target=delete_user, args=(vectordb_loader, action_item.payload.userId)) + + case ActionType.UPDATE_ACCESS_SOURCE_ID: + exec_in_proc( + target=update_access, + args=( + vectordb_loader, + action_item.payload.op, + action_item.payload.userIds, + action_item.payload.sourceId, + ), + ) + + case ActionType.UPDATE_ACCESS_PROVIDER_ID: + exec_in_proc( + target=update_access_provider, + 
args=( + vectordb_loader, + action_item.payload.op, + action_item.payload.userIds, + action_item.payload.providerId, + ), + ) + + case ActionType.UPDATE_ACCESS_DECL_SOURCE_ID: + exec_in_proc( + target=decl_update_access, + args=( + vectordb_loader, + action_item.payload.userIds, + action_item.payload.sourceId, + ), + ) + + case _: + LOGGER.warning( + f'Unknown action type {action_item.type} for action id {db_id},' + f' type {action_item.type}, skipping and marking as processed', + extra={ 'action_item': action_item }, + ) + continue + + processed_event_ids.append(db_id) + except SafeDbException as e: + LOGGER.debug( + f'Safe DB error thrown while processing action id {db_id}, type {action_item.type},' + " it's safe to ignore and mark as processed.", + exc_info=e, + extra={ 'action_item': action_item }, + ) + processed_event_ids.append(db_id) + continue + + except (LoaderException, DbException) as e: + LOGGER.error( + f'Error deleting source for action id {db_id}, type {action_item.type}: {e}', + exc_info=e, + extra={ 'action_item': action_item }, + ) + errored_events[db_id] = str(e) + continue + + except Exception as e: + LOGGER.error( + f'Unexpected error processing action id {db_id}, type {action_item.type}: {e}', + exc_info=e, + extra={ 'action_item': action_item }, + ) + errored_events[db_id] = f'Unexpected error: {e}' + continue + + if (i + 1) % 20 == 0: + LOGGER.debug(f'Processed {i + 1} updates, sleeping for a bit to allow other operations to proceed') + sleep(2) + + LOGGER.info(f'Processed {len(processed_event_ids)} updates with {len(errored_events)} errors', extra={ + 'errored_events': errored_events, + }) + + if len(processed_event_ids) == 0: + LOGGER.debug('No updates processed, skipping reporting to the server') + continue + + try: + nc.ocs( + 'DELETE', + '/ocs/v2.php/apps/context_chat/queues/actions/', + json={ 'actions': processed_event_ids }, + ) + except ( + niquests.exceptions.ConnectionError, + niquests.exceptions.Timeout, + ) as e: + 
LOGGER.info('Temporary error reporting processed updates, will retry:', exc_info=e) + sleep(5) + with suppress(Exception): + nc = NextcloudApp() + nc.ocs( + 'DELETE', + '/ocs/v2.php/apps/context_chat/queues/actions/', + json={ 'ids': processed_event_ids }, + ) + continue + except Exception as e: + LOGGER.exception('Error reporting processed updates:', exc_info=e) + sleep(5) + continue + + +def resolve_scope_list(source_ids: list[str], userId: str) -> list[str]: + """ + + Parameters + ---------- + source_ids + + Returns + ------- + source_ids with only files, no folders (or source_ids in case of non-file provider) + """ + nc = NextcloudApp() + data = nc.ocs('POST', '/ocs/v2.php/apps/context_chat/resolve_scope_list', json={ + 'source_ids': source_ids, + 'userId': userId, + }) + return ScopeList.model_validate(data).source_ids + + +def request_processing_thread(app_config: TConfig, app_enabled: Event) -> None: + LOGGER.info('Starting request processing thread') + + try: + network_em = NetworkEmbeddings(app_config) + vectordb_loader = VectorDBLoader(app_config) + llm_loader = LLMModelLoader(app_config) + except LoaderException as e: + LOGGER.error('Error initializing vector DB loader, files indexing thread will not start:', exc_info=e) + return + + nc = NextcloudApp() + llm: LLM = llm_loader.load() + + while True: + if THREAD_STOP_EVENT.is_set(): + LOGGER.info('Request processing thread is stopping due to stop event being set') + return + + if not network_em.check_connection(ThreadType.REQUEST_PROCESSING.value): + sleep(POLLING_COOLDOWN) + continue + + try: + # Fetch pending task + try: + response = nc.providers.task_processing.next_task( + ['context_chat-context_chat', 'context_chat-context_chat_search'], + ['context_chat:context_chat', 'context_chat:context_chat_search'], + ) + if not response: + wait_for_tasks() + continue + except (NextcloudException, RequestException, JSONDecodeError) as e: + LOGGER.error(f"Network error fetching the next task {e}", exc_info=e) + 
wait_for_tasks(TP_CHECK_INTERVAL_ON_ERROR) + continue + + # Process task + task = response["task"] + userId = task['userId'] + + try: + LOGGER.debug(f'Processing task {task["id"]}') + + if task['input'].get('scopeType') == 'source': + # Resolve scope list to only files, no folders + task['input']['scopeList'] = resolve_scope_list(task['input'].get('scopeList'), userId) + + if task['type'] == 'context_chat:context_chat': + result: LLMOutput = process_normal_task(task, vectordb_loader, llm, app_config) + # Return result to Nextcloud + success = return_result_to_nextcloud(task['id'], userId, { + 'output': result['output'], + 'sources': enrich_sources(result['sources'], userId), + }) + elif task['type'] == 'context_chat:context_chat_search': + search_result: list[SearchResult] = process_search_task(task, vectordb_loader) + # Return result to Nextcloud + success = return_result_to_nextcloud(task['id'], userId, { + 'sources': enrich_sources(search_result, userId), + }) + else: + LOGGER.error(f'Unknown task type {task["type"]}') + success = return_error_to_nextcloud(task['id'], Exception(f'Unknown task type {task["type"]}')) + + if success: + LOGGER.info(f'Task {task["id"]} completed successfully') + else: + LOGGER.error(f'Failed to return result for task {task["id"]}') + + except EmbeddingException as e: + LOGGER.warning(f'Embedding server error for task {task["id"]}: {e}') + return_error_to_nextcloud(task['id'], e) + except ContextException as e: + LOGGER.warning(f'Context error for task {task["id"]}: {e}') + return_error_to_nextcloud(task['id'], e) + except ValueError as e: + LOGGER.warning(f'Validation error for task {task["id"]}: {e}') + return_error_to_nextcloud(task['id'], e) + except Exception as e: + LOGGER.exception(f'Unexpected error processing task {task["id"]}', exc_info=e) + return_error_to_nextcloud(task['id'], e) + + except Exception as e: + LOGGER.exception('Error in task fetcher loop', exc_info=e) + wait_for_tasks(TP_CHECK_INTERVAL_ON_ERROR) + +def 
trigger_handler(provider_id: str): + global TP_TRIGGER + LOGGER.debug('Task processing trigger received', extra={'provider_id': provider_id}) + TP_TRIGGER.set() + +def wait_for_tasks(interval = None): + global TP_TRIGGER + global TP_CHECK_INTERVAL + global TP_CHECK_INTERVAL_WITH_TRIGGER + actual_interval = TP_CHECK_INTERVAL if interval is None else interval + if TP_TRIGGER.wait(timeout=actual_interval): + TP_CHECK_INTERVAL = TP_CHECK_INTERVAL_WITH_TRIGGER + TP_TRIGGER.clear() + + +def enrich_sources(results: list[SearchResult], userId: str) -> list[str]: + nc = NextcloudApp() + data = nc.ocs('POST', '/ocs/v2.php/apps/context_chat/enrich_sources', json={'sources': results, 'userId': userId}) + sources = EnrichedSourceList.model_validate(data).sources + return [s.model_dump_json() for s in sources] + + +def return_result_to_nextcloud(task_id: int, userId: str, result: dict[str, Any]) -> bool: + """ + Return query result back to Nextcloud. + + Args: + result: dict[str, Any] + + Returns: + True if successful, False otherwise + """ + LOGGER.debug('Returning result to Nextcloud', extra={ + 'task_id': task_id, + 'result': result, + }) + + nc = NextcloudApp() + + try: + nc.providers.task_processing.report_result(task_id, result) + except (NextcloudException, RequestException, JSONDecodeError) as e: + LOGGER.error(f"Network error reporting task result {e}", exc_info=e) + return False + + return True + + +def return_error_to_nextcloud(task_id: int, e: Exception) -> bool: + """ + Return error result back to Nextcloud. 
+ + Args: + task_id: Unique task identifier + e: error object + + Returns: + True if successful, False otherwise + """ + LOGGER.debug('Returning error to Nextcloud', exc_info=e) + + nc = NextcloudApp() + + if isinstance(e, ValueError): + message = "Validation error: " + str(e) + elif isinstance(e, ContextException): + message = "Context error" + str(e) + else: + message = "Unexpected error" + str(e) + + try: + nc.providers.task_processing.report_result(task_id, None, message) + except (NextcloudException, RequestException, JSONDecodeError) as e: + LOGGER.error(f"Network error reporting task result {e}", exc_info=e) + return False + + return True + + +def process_normal_task( + task: dict[str, Any], + vectordb_loader: VectorDBLoader, + llm: LLM, + app_config: TConfig, +) -> LLMOutput: + """ + Process a single query task. + + Args: + task: Task dictionary from fetch_query_tasks_from_nextcloud + vectordb_loader: Vector database loader instance + llm: Language model instance + app_config: Application configuration + + Returns: + LLMOutput with generated text and sources + + Raises: + Various exceptions from query execution + """ + user_id = task['userId'] + task_input = task['input'] + if task_input.get('scopeType') == 'none': + task_input['scopeType'] = None + + return exec_in_proc(target=process_context_query, + args=( + user_id, + vectordb_loader, + llm, + app_config, + task_input.get('prompt'), + CONTEXT_LIMIT, + task_input.get('scopeType'), + task_input.get('scopeList'), + app_config.llm[1].get('template'), + ) + ) + +def process_search_task( + task: dict[str, Any], + vectordb_loader: VectorDBLoader, +) -> list[SearchResult]: + """ + Process a single search task. 
+ + Args: + task: Task dictionary from fetch_query_tasks_from_nextcloud + vectordb_loader: Vector database loader instance + + Returns: + list of Search results + + Raises: + Various exceptions from query execution + """ + user_id = task['userId'] + task_input = task['input'] + if task_input.get('scopeType') == 'none': + task_input['scopeType'] = None + + return exec_in_proc(target=do_doc_search, + args=( + user_id, + task_input.get('prompt'), + vectordb_loader, + CONTEXT_LIMIT, + task_input.get('scopeType'), + task_input.get('scopeList'), + ) + ) + + +def start_bg_threads(app_config: TConfig, app_enabled: Event): + if APP_ROLE == AppRole.INDEXING or APP_ROLE == AppRole.NORMAL: + if ( + ThreadType.FILES_INDEXING in THREADS + or ThreadType.UPDATES_PROCESSING in THREADS + ): + LOGGER.info('Background threads already running, skipping start') + return + + THREAD_STOP_EVENT.clear() + THREADS[ThreadType.FILES_INDEXING] = Thread( + target=files_indexing_thread, + args=(app_config, app_enabled), + name='FilesIndexingThread', + ) + THREADS[ThreadType.UPDATES_PROCESSING] = Thread( + target=updates_processing_thread, + args=(app_config, app_enabled), + name='UpdatesProcessingThread', + ) + THREADS[ThreadType.FILES_INDEXING].start() + THREADS[ThreadType.UPDATES_PROCESSING].start() + + if APP_ROLE == AppRole.RP or APP_ROLE == AppRole.NORMAL: + if ThreadType.REQUEST_PROCESSING in THREADS: + LOGGER.info('Background threads already running, skipping start') + return + + THREAD_STOP_EVENT.clear() + THREADS[ThreadType.REQUEST_PROCESSING] = Thread( + target=request_processing_thread, + args=(app_config, app_enabled), + name='RequestProcessingThread', + ) + THREADS[ThreadType.REQUEST_PROCESSING].start() + + +def wait_for_bg_threads(): + if APP_ROLE == AppRole.INDEXING or APP_ROLE == AppRole.NORMAL: + if (ThreadType.FILES_INDEXING not in THREADS or ThreadType.UPDATES_PROCESSING not in THREADS): + return + + THREAD_STOP_EVENT.set() + THREADS[ThreadType.FILES_INDEXING].join() + 
THREADS[ThreadType.UPDATES_PROCESSING].join() + THREADS.pop(ThreadType.FILES_INDEXING) + THREADS.pop(ThreadType.UPDATES_PROCESSING) + + if APP_ROLE == AppRole.RP or APP_ROLE == AppRole.NORMAL: + if (ThreadType.REQUEST_PROCESSING not in THREADS): + return + + THREAD_STOP_EVENT.set() + THREADS[ThreadType.REQUEST_PROCESSING].join() + THREADS.pop(ThreadType.REQUEST_PROCESSING) diff --git a/context_chat_backend/types.py b/context_chat_backend/types.py index 500a97d0..2694998d 100644 --- a/context_chat_backend/types.py +++ b/context_chat_backend/types.py @@ -2,7 +2,17 @@ # SPDX-FileCopyrightText: 2024 Nextcloud GmbH and Nextcloud contributors # SPDX-License-Identifier: AGPL-3.0-or-later # -from pydantic import BaseModel +import re +from collections.abc import Mapping +from enum import Enum +from io import BytesIO +from typing import Annotated, Literal, Self + +import niquests +from pydantic import AfterValidator, BaseModel, Discriminator, computed_field, field_validator, model_validator + +from .mimetype_list import SUPPORTED_MIMETYPES +from .vectordb.types import UpdateAccessOp __all__ = [ 'DEFAULT_EM_MODEL_ALIAS', @@ -15,6 +25,65 @@ ] DEFAULT_EM_MODEL_ALIAS = 'em_model' +FILES_PROVIDER_ID = 'files__default' + + +def is_valid_source_id(source_id: str) -> bool: + # note the ":" in the item id part + return re.match(r'^[a-zA-Z0-9_-]+__[a-zA-Z0-9_-]+: [a-zA-Z0-9:-]+$', source_id) is not None + + +def is_valid_provider_id(provider_id: str) -> bool: + return re.match(r'^[a-zA-Z0-9_-]+__[a-zA-Z0-9_-]+$', provider_id) is not None + + +def _validate_source_ids(source_ids: list[str]) -> list[str]: + if ( + not isinstance(source_ids, list) + or not all(isinstance(sid, str) and sid.strip() != '' for sid in source_ids) + or len(source_ids) == 0 + ): + raise ValueError('sourceIds must be a non-empty list of non-empty strings') + return [sid.strip() for sid in source_ids] + + +def _validate_source_id(source_id: str) -> str: + return _validate_source_ids([source_id])[0] + + +def 
_validate_provider_id(provider_id: str) -> str: + if not isinstance(provider_id, str) or not is_valid_provider_id(provider_id): + raise ValueError('providerId must be a valid provider ID string') + return provider_id + + +def _validate_user_ids(user_ids: list[str]) -> list[str]: + if ( + not isinstance(user_ids, list) + or not all(isinstance(uid, str) and uid.strip() != '' for uid in user_ids) + or len(user_ids) == 0 + ): + raise ValueError('userIds must be a non-empty list of non-empty strings') + return [uid.strip() for uid in user_ids] + + +def _validate_user_id(user_id: str) -> str: + return _validate_user_ids([user_id])[0] + + +def _get_file_id_from_source_ref(source_ref: str) -> int: + ''' + source reference is in the format "FILES_PROVIDER_ID: ". + ''' + if not source_ref.startswith(f'{FILES_PROVIDER_ID}: '): + raise ValueError(f'Source reference does not start with expected prefix: {source_ref}') + + try: + return int(source_ref[len(f'{FILES_PROVIDER_ID}: '):]) + except ValueError as e: + raise ValueError( + f'Invalid source reference format for extracting file_id: {source_ref}' + ) from e class TEmbeddingAuthApiKey(BaseModel): @@ -36,14 +105,17 @@ class TEmbeddingConfig(BaseModel): class TConfig(BaseModel): - debug: bool - uvicorn_log_level: str - disable_aaa: bool - verify_ssl: bool - use_colors: bool - uvicorn_workers: int - embedding_chunk_size: int - doc_parser_worker_limit: int + debug: bool = False + uvicorn_log_level: str = 'info' + disable_aaa: bool = False + verify_ssl: bool = True + use_colors: bool = True + uvicorn_workers: int = 1 + embedding_chunk_size: int = 2000 + doc_indexing_batch_size: int = 32 + actions_batch_size: int = 512 + file_parsing_cpu_count: int = -1 + concurrent_file_fetches: int = 10 vectordb: tuple[str, dict] embedding: TEmbeddingConfig @@ -55,7 +127,9 @@ class LoaderException(Exception): class EmbeddingException(Exception): - ... 
+ def __init__(self, msg: str, response: niquests.Response | None = None): + super().__init__(msg) + self.response = response class RetryableEmbeddingException(EmbeddingException): """ @@ -71,3 +145,214 @@ class FatalEmbeddingException(EmbeddingException): Either malformed request, authentication error, or other non-retryable error. """ + +class DocErrorEmbeddingException(EmbeddingException): + """ + Exception that indicates a fatal error for the document, this document should not be retried. + """ + + +class AppRole(str, Enum): + NORMAL = 'normal' + INDEXING = 'indexing' + RP = 'rp' + + +class CommonSourceItem(BaseModel): + userIds: Annotated[list[str], AfterValidator(_validate_user_ids)] + # source_id of the form "appId__providerId: itemId" + reference: Annotated[str, AfterValidator(_validate_source_id)] + title: str + modified: int + type: str + provider: Annotated[str, AfterValidator(_validate_provider_id)] + size: float + + @field_validator('modified', mode='before') + @classmethod + def validate_modified(cls, v): + if isinstance(v, int): + return v + if isinstance(v, str): + try: + return int(v) + except ValueError as e: + raise ValueError(f'Invalid modified value: {v}') from e + raise ValueError(f'Invalid modified type: {type(v)}') + + @field_validator('reference', 'title', 'type', 'provider') + @classmethod + def validate_strings_non_empty(cls, v): + if not isinstance(v, str) or v.strip() == '': + raise ValueError('Must be a non-empty string') + return v.strip() + + @field_validator('size') + @classmethod + def validate_size(cls, v): + if isinstance(v, int | float) and v >= 0: + return float(v) + raise ValueError(f'Invalid size value: {v}, must be a non-negative number') + + @model_validator(mode='after') + def validate_type(self) -> Self: + if self.reference.startswith(FILES_PROVIDER_ID) and self.type not in SUPPORTED_MIMETYPES: + raise ValueError(f'Unsupported file type: {self.type} for reference {self.reference}') + return self + + +class 
ReceivedFileItem(CommonSourceItem): + content: None + + @computed_field + @property + def file_id(self) -> int: + return _get_file_id_from_source_ref(self.reference) + + +class SourceItem(CommonSourceItem): + ''' + Used for the unified queue of items to process, after fetching the content for files + and for directly fetched content providers. + ''' + content: str | BytesIO + + @field_validator('content') + @classmethod + def validate_content(cls, v): + if isinstance(v, str): + if v.strip() == '': + raise ValueError('Content must be a non-empty string') + return v.strip() + if isinstance(v, BytesIO): + if v.getbuffer().nbytes == 0: + raise ValueError('Content must be a non-empty BytesIO') + return v + raise ValueError('Content must be either a non-empty string or a non-empty BytesIO') + + class Config: + # to allow BytesIO in content field + arbitrary_types_allowed = True + + +class FilesQueueItems(BaseModel): + files: Mapping[int, ReceivedFileItem] # [db id]: FileItem + content_providers: Mapping[int, SourceItem] # [db id]: SourceItem + + +class IndexingException(Exception): + retryable: bool = False + + def __init__(self, message: str, retryable: bool = False): + super().__init__(message) + self.retryable = retryable + + +class IndexingError(BaseModel): + error: str + retryable: bool = False + + +# PHP equivalent for reference: + +# class ActionType { +# // { sourceIds: array } +# public const DELETE_SOURCE_IDS = 'delete_source_ids'; +# // { providerId: string } +# public const DELETE_PROVIDER_ID = 'delete_provider_id'; +# // { userId: string } +# public const DELETE_USER_ID = 'delete_user_id'; +# // { op: string, userIds: array, sourceId: string } +# public const UPDATE_ACCESS_SOURCE_ID = 'update_access_source_id'; +# // { op: string, userIds: array, providerId: string } +# public const UPDATE_ACCESS_PROVIDER_ID = 'update_access_provider_id'; +# // { userIds: array, sourceId: string } +# public const UPDATE_ACCESS_DECL_SOURCE_ID = 'update_access_decl_source_id'; 
+# } + + +class ActionPayloadDeleteSourceIds(BaseModel): + sourceIds: Annotated[list[str], AfterValidator(_validate_source_ids)] + + +class ActionPayloadDeleteProviderId(BaseModel): + providerId: Annotated[str, AfterValidator(_validate_provider_id)] + + +class ActionPayloadDeleteUserId(BaseModel): + userId: Annotated[str, AfterValidator(_validate_user_id)] + + +class ActionPayloadUpdateAccessSourceId(BaseModel): + op: UpdateAccessOp + userIds: Annotated[list[str], AfterValidator(_validate_user_ids)] + sourceId: Annotated[str, AfterValidator(_validate_source_id)] + + +class ActionPayloadUpdateAccessProviderId(BaseModel): + op: UpdateAccessOp + userIds: Annotated[list[str], AfterValidator(_validate_user_ids)] + providerId: Annotated[str, AfterValidator(_validate_provider_id)] + + +class ActionPayloadUpdateAccessDeclSourceId(BaseModel): + userIds: Annotated[list[str], AfterValidator(_validate_user_ids)] + sourceId: Annotated[str, AfterValidator(_validate_source_id)] + + +class ActionType(str, Enum): + DELETE_SOURCE_IDS = 'delete_source_ids' + DELETE_PROVIDER_ID = 'delete_provider_id' + DELETE_USER_ID = 'delete_user_id' + UPDATE_ACCESS_SOURCE_ID = 'update_access_source_id' + UPDATE_ACCESS_PROVIDER_ID = 'update_access_provider_id' + UPDATE_ACCESS_DECL_SOURCE_ID = 'update_access_decl_source_id' + + +class CommonActionsQueueItem(BaseModel): + id: int + + +class ActionsQueueItemDeleteSourceIds(CommonActionsQueueItem): + type: Literal[ActionType.DELETE_SOURCE_IDS] + payload: ActionPayloadDeleteSourceIds + + +class ActionsQueueItemDeleteProviderId(CommonActionsQueueItem): + type: Literal[ActionType.DELETE_PROVIDER_ID] + payload: ActionPayloadDeleteProviderId + + +class ActionsQueueItemDeleteUserId(CommonActionsQueueItem): + type: Literal[ActionType.DELETE_USER_ID] + payload: ActionPayloadDeleteUserId + + +class ActionsQueueItemUpdateAccessSourceId(CommonActionsQueueItem): + type: Literal[ActionType.UPDATE_ACCESS_SOURCE_ID] + payload: ActionPayloadUpdateAccessSourceId + + 
+class ActionsQueueItemUpdateAccessProviderId(CommonActionsQueueItem): + type: Literal[ActionType.UPDATE_ACCESS_PROVIDER_ID] + payload: ActionPayloadUpdateAccessProviderId + + +class ActionsQueueItemUpdateAccessDeclSourceId(CommonActionsQueueItem): + type: Literal[ActionType.UPDATE_ACCESS_DECL_SOURCE_ID] + payload: ActionPayloadUpdateAccessDeclSourceId + + +ActionsQueueItem = Annotated[ + ActionsQueueItemDeleteSourceIds + | ActionsQueueItemDeleteProviderId + | ActionsQueueItemDeleteUserId + | ActionsQueueItemUpdateAccessSourceId + | ActionsQueueItemUpdateAccessProviderId + | ActionsQueueItemUpdateAccessDeclSourceId, + Discriminator('type'), +] + + +class ActionsQueueItems(BaseModel): + actions: Mapping[int, ActionsQueueItem] diff --git a/context_chat_backend/utils.py b/context_chat_backend/utils.py index f6d6e672..c7939781 100644 --- a/context_chat_backend/utils.py +++ b/context_chat_backend/utils.py @@ -2,11 +2,16 @@ # SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors # SPDX-License-Identifier: AGPL-3.0-or-later # +import faulthandler +import io import logging import multiprocessing as mp -import re +import os +import signal +import sys import traceback from collections.abc import Callable +from contextlib import suppress from functools import partial, wraps from multiprocessing.connection import Connection from time import perf_counter_ns @@ -14,10 +19,11 @@ from fastapi.responses import JSONResponse as FastAPIJSONResponse -from .types import TConfig, TEmbeddingAuthApiKey, TEmbeddingAuthBasic, TEmbeddingConfig +from .types import AppRole, TConfig, TEmbeddingAuthApiKey, TEmbeddingAuthBasic, TEmbeddingConfig T = TypeVar('T') _logger = logging.getLogger('ccb.utils') +_MAX_STD_CAPTURE_CHARS = 64 * 1024 def not_none(value: T | None) -> TypeGuard[T]: @@ -69,19 +75,105 @@ def JSONResponse( return FastAPIJSONResponse(content, status_code, **kwargs) -def exception_wrap(fun: Callable | None, *args, resconn: Connection, **kwargs): - try: - if fun is 
None: - return resconn.send({ 'value': None, 'error': None }) - resconn.send({ 'value': fun(*args, **kwargs), 'error': None }) - except Exception as e: - tb = traceback.format_exc() - resconn.send({ 'value': None, 'error': e, 'traceback': tb }) +class SubprocessKilledError(RuntimeError): + """Raised when a subprocess is terminated by a signal (for example SIGKILL).""" + + def __init__(self, pid: int | None, target_name: str, exitcode: int): + super().__init__( + f'Subprocess PID {pid} for {target_name} exited with signal {abs(exitcode)} ' + f'(raw exit code: {exitcode})' + ) + self.exitcode = exitcode + + +class SubprocessExecutionError(RuntimeError): + """Raised when a subprocess exits without a recoverable Python exception payload.""" + + def __init__(self, pid: int | None, target_name: str, exitcode: int, details: str = ''): + msg = f'Subprocess PID {pid} for {target_name} exited with exit code {exitcode}' + if details: + msg = f'{msg}: {details}' + super().__init__(msg) + self.exitcode = exitcode + + +def _truncate_capture(text: str) -> str: + if len(text) <= _MAX_STD_CAPTURE_CHARS: + return text + + head = _MAX_STD_CAPTURE_CHARS // 2 + tail = _MAX_STD_CAPTURE_CHARS - head + omitted = len(text) - _MAX_STD_CAPTURE_CHARS + return ( + f'[truncated {omitted} chars]\n' + f'{text[:head]}\n' + '[...snip...]\n' + f'{text[-tail:]}' + ) + + +def exception_wrap(fun: Callable | None, *args, resconn: Connection, stdconn: Connection, **kwargs): + # ignore SIGINT and SIGTERM in child processes these signals don't immediately stop these processes + # the handling is done in the fastapi lifetime to do a graceful shutdown + # SIGKILL is not ignored + signal.signal(signal.SIGINT, signal.SIG_IGN) + signal.signal(signal.SIGTERM, signal.SIG_IGN) + + # Preserve real stderr FD for faulthandler before we redirect sys.stderr. 
+ _faulthandler_fd = os.dup(2) + with suppress(Exception): + faulthandler.enable( + file=os.fdopen(_faulthandler_fd, 'w', closefd=False), + all_threads=True, + ) + stdout_capture = io.StringIO() + stderr_capture = io.StringIO() + orig_stdout = sys.stdout + orig_stderr = sys.stderr + sys.stdout = stdout_capture + sys.stderr = stderr_capture -def exec_in_proc(group=None, target=None, name=None, args=(), kwargs={}, *, daemon=None): # noqa: B006 + try: + value = None if fun is None else fun(*args, **kwargs) + try: + resconn.send({ 'value': value, 'error': None }) + except (BrokenPipeError, OSError, EOFError): + ... # parent closed the pipe during shutdown, exit cleanly + except BaseException as e: + tb = traceback.format_exc() + payload = { + 'value': None, + 'error': e, + 'traceback': tb, + } + try: + resconn.send(payload) + except Exception as send_err: + stderr_capture.write(f'Original error: {e}, pipe send error: {send_err}') + finally: + sys.stdout = orig_stdout + sys.stderr = orig_stderr + stdout_text = _truncate_capture(stdout_capture.getvalue()) + stderr_text = _truncate_capture(stderr_capture.getvalue()) + with suppress(Exception): + stdconn.send({ + 'stdout': stdout_text, + 'stderr': stderr_text, + }) + with suppress(Exception): + os.close(_faulthandler_fd) + + +def exec_in_proc(group=None, target=None, name=None, args=(), kwargs=None, *, daemon=None): + if not kwargs: + kwargs = {} + + # parent, child pconn, cconn = mp.Pipe() + std_pconn, std_cconn = mp.Pipe() kwargs['resconn'] = cconn + kwargs['stdconn'] = std_cconn p = mp.Process( group=group, target=partial(exception_wrap, target), @@ -90,24 +182,92 @@ def exec_in_proc(group=None, target=None, name=None, args=(), kwargs={}, *, daem kwargs=kwargs, daemon=daemon, ) + target_name = getattr(target, '__name__', str(target)) + start = perf_counter_ns() p.start() + _logger.debug('Subprocess PID %d started for %s', p.pid, target_name) + + result = None + stdobj = { 'stdout': '', 'stderr': '' } + got_result = 
False + got_std = False + + # Drain result/std pipes while child is still alive to avoid deadlock on full pipe buffers. + # Pipe's buffer size is 64 KiB + while p.is_alive() and (not got_result or not got_std): + if not got_result and pconn.poll(0.1): + with suppress(EOFError, OSError, BrokenPipeError): + result = pconn.recv() + got_result = True + if not got_std and std_pconn.poll(): + with suppress(EOFError, OSError, BrokenPipeError): + stdobj = std_pconn.recv() + got_std = True + p.join() + elapsed_ms = (perf_counter_ns() - start) / 1e6 + _logger.debug( + 'Subprocess PID %d for %s finished in %.2f ms (exit code: %s)', + p.pid, target_name, elapsed_ms, p.exitcode, + ) - result = pconn.recv() - if result['error'] is not None: - _logger.error('original traceback: %s', result['traceback']) + if not got_std: + with suppress(EOFError, OSError, BrokenPipeError): + if std_pconn.poll(): + stdobj = std_pconn.recv() + # no need to update got_std here + if stdobj.get('stdout') or stdobj.get('stderr'): + _logger.info('std info for %s', target_name, extra={ + 'stdout': stdobj.get('stdout', ''), + 'stderr': stdobj.get('stderr', ''), + }) + + if not got_result: + with suppress(EOFError, OSError, BrokenPipeError): + if pconn.poll(): + result = pconn.recv() + # no need to update got_result here + + if result is not None and result.get('error') is not None: + _logger.error( + 'original traceback of %s (PID %d, exitcode: %s): %s', + target_name, + p.pid, + p.exitcode, + result.get('traceback', ''), + ) raise result['error'] - return result['value'] - - -def is_valid_source_id(source_id: str) -> bool: - # note the ":" in the item id part - return re.match(r'^[a-zA-Z0-9_-]+__[a-zA-Z0-9_-]+: [a-zA-Z0-9:-]+$', source_id) is not None + if result is not None and 'value' in result: + if p.exitcode not in (None, 0): + _logger.warning( + 'Subprocess PID %d for %s exited with code %s after %.2f ms' + ' but returned a valid result', + p.pid, target_name, p.exitcode, elapsed_ms, + ) + return 
result['value'] + if p.exitcode and p.exitcode < 0: + _logger.warning( + 'Subprocess PID %d for %s exited due to signal %d, exitcode %d after %.2f ms', + p.pid, target_name, abs(p.exitcode), p.exitcode, elapsed_ms, + ) + raise SubprocessKilledError(p.pid, target_name, p.exitcode) + + if p.exitcode not in (None, 0): + raise SubprocessExecutionError( + p.pid, + target_name, + p.exitcode, + f'No structured exception payload received from child process: {result}', + ) -def is_valid_provider_id(provider_id: str) -> bool: - return re.match(r'^[a-zA-Z0-9_-]+__[a-zA-Z0-9_-]+$', provider_id) is not None + raise SubprocessExecutionError( + p.pid, + target_name, + 0, + f'Subprocess exited successfully but returned no result payload: {result}', + ) def timed(func: Callable): @@ -144,3 +304,13 @@ def redact_config(config: TConfig | TEmbeddingConfig) -> TConfig | TEmbeddingCon em_conf.auth.password = '***REDACTED***' # noqa: S105 return config_copy + + +def get_app_role() -> AppRole: + role = os.getenv('APP_ROLE', '').lower() + if role == '': + return AppRole.NORMAL + if role not in ['indexing', 'rp']: + _logger.warning(f'Invalid app role: {role}, defaulting to all roles') + return AppRole.NORMAL + return AppRole(role) diff --git a/context_chat_backend/vectordb/base.py b/context_chat_backend/vectordb/base.py index 0bf10200..2b4aa35e 100644 --- a/context_chat_backend/vectordb/base.py +++ b/context_chat_backend/vectordb/base.py @@ -3,14 +3,15 @@ # SPDX-License-Identifier: AGPL-3.0-or-later # from abc import ABC, abstractmethod +from collections.abc import Mapping from typing import Any -from fastapi import UploadFile from langchain.schema import Document from langchain.schema.embeddings import Embeddings from langchain.schema.vectorstore import VectorStore from ..chain.types import InDocument, ScopeType +from ..types import IndexingError, ReceivedFileItem, SourceItem from ..utils import timed from .types import UpdateAccessOp @@ -62,7 +63,7 @@ def get_instance(self) -> 
VectorStore: ''' @abstractmethod - def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str],list[str]]: + def add_indocuments(self, indocuments: Mapping[int, InDocument]) -> Mapping[int, IndexingError | None]: ''' Adds the given indocuments to the vectordb and updates the docs + access tables. @@ -79,10 +80,7 @@ def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str],list @timed @abstractmethod - def check_sources( - self, - sources: list[UploadFile], - ) -> tuple[list[str], list[str]]: + def check_sources(self, sources: Mapping[int, SourceItem | ReceivedFileItem]) -> tuple[list[str], list[str]]: ''' Checks the sources in the vectordb if they are already embedded and are up to date. @@ -91,8 +89,8 @@ def check_sources( Args ---- - sources: list[UploadFile] - List of source ids to check. + sources: Mapping[int, SourceItem | ReceivedFileItem] + Dict of sources to check. Returns ------- diff --git a/context_chat_backend/vectordb/pgvector.py b/context_chat_backend/vectordb/pgvector.py index 2b7fc060..9d880245 100644 --- a/context_chat_backend/vectordb/pgvector.py +++ b/context_chat_backend/vectordb/pgvector.py @@ -4,21 +4,30 @@ # import logging import os +from collections.abc import Mapping from datetime import datetime +from time import perf_counter_ns import psycopg import sqlalchemy as sa import sqlalchemy.dialects.postgresql as postgresql_dialects import sqlalchemy.orm as orm from dotenv import load_dotenv -from fastapi import UploadFile from langchain.schema import Document from langchain.vectorstores import VectorStore from langchain_core.embeddings import Embeddings from langchain_postgres.vectorstores import Base, PGVector from ..chain.types import InDocument, ScopeType -from ..types import EmbeddingException, RetryableEmbeddingException +from ..types import ( + DocErrorEmbeddingException, + EmbeddingException, + FatalEmbeddingException, + IndexingError, + ReceivedFileItem, + RetryableEmbeddingException, + SourceItem, +) 
from ..utils import timed from .base import BaseVectorDB from .types import DbException, SafeDbException, UpdateAccessOp @@ -112,7 +121,15 @@ def __init__(self, embedding: Embeddings | None = None, **kwargs): kwargs['connection'] = os.environ['CCB_DB_URL'] # setup langchain db + our access list table - self.client = PGVector(embedding, collection_name=COLLECTION_NAME, **kwargs) + try: + self.client = PGVector(embedding, collection_name=COLLECTION_NAME, **kwargs) + except sa.exc.IntegrityError as ie: # pyright: ignore[reportAttributeAccessIssue] + if not isinstance(ie.orig, psycopg.errors.UniqueViolation): + raise + + # tried to create the tables but it was already created in another process + # init the client again to detect it already exists, and continue from there + self.client = PGVector(embedding, collection_name=COLLECTION_NAME, **kwargs) def get_instance(self) -> VectorStore: return self.client @@ -130,24 +147,40 @@ def get_users(self) -> list[str]: except Exception as e: raise DbException('Error: getting a list of all users from access list') from e - def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str], list[str]]: + def add_indocuments(self, indocuments: Mapping[int, InDocument]) -> Mapping[int, IndexingError | None]: """ Raises EmbeddingException: if the embedding request definitively fails """ - added_sources = [] - retry_sources = [] + results = {} batch_size = PG_BATCH_SIZE // 5 with self.session_maker() as session: - for indoc in indocuments: + for php_db_id, indoc in indocuments.items(): try: # query paramerters limitation in postgres is 65535 (https://www.postgresql.org/docs/current/limits.html) # so we chunk the documents into (5 values * 10k) chunks # change the chunk size when there are more inserted values per document chunk_ids = [] - for i in range(0, len(indoc.documents), batch_size): + total_chunks = len(indoc.documents) + num_batches = max(1, -(-total_chunks // batch_size)) # ceiling division + logger.debug( + 
'Embedding source %s: %d chunk(s) in %d batch(es)', + indoc.source_id, total_chunks, num_batches, + ) + for i in range(0, total_chunks, batch_size): + batch_num = i // batch_size + 1 + logger.debug( + 'Sending embedding batch %d/%d (%d chunk(s)) for source %s', + batch_num, num_batches, len(indoc.documents[i:i+batch_size]), indoc.source_id, + ) + t0 = perf_counter_ns() chunk_ids.extend(self.client.add_documents(indoc.documents[i:i+batch_size])) + elapsed_ms = (perf_counter_ns() - t0) / 1e6 + logger.debug( + 'Embedding batch %d/%d for source %s completed in %.2f ms', + batch_num, num_batches, indoc.source_id, elapsed_ms, + ) doc = DocumentsStore( source_id=indoc.source_id, @@ -170,7 +203,7 @@ def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str], lis ) self.decl_update_access(indoc.userIds, indoc.source_id, session) - added_sources.append(indoc.source_id) + results[php_db_id] = None session.commit() except SafeDbException as e: # for when the source_id is not found. 
This here can be an error in the DB @@ -178,51 +211,73 @@ def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str], lis logger.exception('Error adding documents to vectordb', exc_info=e, extra={ 'source_id': indoc.source_id, }) - retry_sources.append(indoc.source_id) + results[php_db_id] = IndexingError( + error=str(e), + retryable=True, + ) + continue + except DocErrorEmbeddingException as e: + logger.warning( + 'Error adding documents to vectordb, server failed to index it, it will not be retried', + exc_info=e, + extra={ 'source_id': indoc.source_id }, + ) + results[php_db_id] = IndexingError( + error=str(e), + retryable=False, + ) continue - except RetryableEmbeddingException as e: + except FatalEmbeddingException as e: + raise EmbeddingException( + f'Fatal error while embedding documents for source {indoc.source_id}: {e}' + ) from e + except (RetryableEmbeddingException, EmbeddingException) as e: # temporary error, continue with the next document - logger.exception('Error adding documents to vectordb, should be retried later.', exc_info=e, extra={ + logger.warning('Error adding documents to vectordb, should be retried later.', exc_info=e, extra={ 'source_id': indoc.source_id, }) - retry_sources.append(indoc.source_id) + results[php_db_id] = IndexingError( + error=str(e), + retryable=True, + ) continue - except EmbeddingException as e: - logger.exception('Error adding documents to vectordb', exc_info=e, extra={ - 'source_id': indoc.source_id, - }) - raise except Exception as e: logger.exception('Error adding documents to vectordb', exc_info=e, extra={ 'source_id': indoc.source_id, }) - retry_sources.append(indoc.source_id) + results[php_db_id] = IndexingError( + error='An unexpected error occurred while adding documents to the database.', + retryable=True, + ) continue - return added_sources, retry_sources + return results @timed - def check_sources(self, sources: list[UploadFile]) -> tuple[list[str], list[str]]: + def check_sources(self, 
sources: Mapping[int, SourceItem | ReceivedFileItem]) -> tuple[list[str], list[str]]: + ''' + returns a tuple of (existing_source_ids, to_embed_source_ids) + ''' with self.session_maker() as session: try: stmt = ( sa.select(DocumentsStore.source_id) - .filter(DocumentsStore.source_id.in_([source.filename for source in sources])) + .filter(DocumentsStore.source_id.in_([source.reference for source in sources.values()])) .with_for_update() ) results = session.execute(stmt).fetchall() existing_sources = {r.source_id for r in results} - to_embed = [source.filename for source in sources if source.filename not in existing_sources] + to_embed = [source.reference for source in sources.values() if source.reference not in existing_sources] to_delete = [] - for source in sources: + for source in sources.values(): stmt = ( sa.select(DocumentsStore.source_id) - .filter(DocumentsStore.source_id == source.filename) + .filter(DocumentsStore.source_id == source.reference) .filter(DocumentsStore.modified < sa.cast( - datetime.fromtimestamp(int(source.headers['modified'])), + datetime.fromtimestamp(int(source.modified)), sa.DateTime, )) ) @@ -239,14 +294,13 @@ def check_sources(self, sources: list[UploadFile]) -> tuple[list[str], list[str] session.rollback() raise DbException('Error: checking sources in vectordb') from e - still_existing_sources = [ - source - for source in existing_sources - if source not in to_delete + still_existing_source_ids = [ + source_id + for source_id in existing_sources + if source_id not in to_delete ] - # the pyright issue stems from source.filename, which has already been validated - return list(still_existing_sources), to_embed # pyright: ignore[reportReturnType] + return list(still_existing_source_ids), to_embed def decl_update_access(self, user_ids: list[str], source_id: str, session_: orm.Session | None = None): session = session_ or self.session_maker() @@ -325,7 +379,7 @@ def update_access( ) match op: - case UpdateAccessOp.allow: + case 
UpdateAccessOp.ALLOW: for i in range(0, len(user_ids), PG_BATCH_SIZE): batched_uids = user_ids[i:i+PG_BATCH_SIZE] stmt = ( @@ -342,7 +396,7 @@ def update_access( session.execute(stmt) session.commit() - case UpdateAccessOp.deny: + case UpdateAccessOp.DENY: for i in range(0, len(user_ids), PG_BATCH_SIZE): batched_uids = user_ids[i:i+PG_BATCH_SIZE] stmt = ( @@ -435,15 +489,17 @@ def delete_source_ids(self, source_ids: list[str], session_: orm.Session | None # entry from "AccessListStore" is deleted automatically due to the foreign key constraint # batch the deletion to avoid hitting the query parameter limit chunks_to_delete = [] + deleted_source_ids = [] for i in range(0, len(source_ids), PG_BATCH_SIZE): batched_ids = source_ids[i:i+PG_BATCH_SIZE] stmt_doc = ( sa.delete(DocumentsStore) .filter(DocumentsStore.source_id.in_(batched_ids)) - .returning(DocumentsStore.chunks) + .returning(DocumentsStore.chunks, DocumentsStore.source_id) ) doc_result = session.execute(stmt_doc) chunks_to_delete.extend(str(c) for res in doc_result for c in res.chunks) + deleted_source_ids.extend(str(res.source_id) for res in doc_result) for i in range(0, len(chunks_to_delete), PG_BATCH_SIZE): batched_chunks = chunks_to_delete[i:i+PG_BATCH_SIZE] @@ -463,6 +519,14 @@ def delete_source_ids(self, source_ids: list[str], session_: orm.Session | None if session_ is None: session.close() + undeleted_source_ids = set(source_ids) - set(deleted_source_ids) + if len(undeleted_source_ids) > 0: + logger.info( + f'Source ids {undeleted_source_ids} were not deleted from documents store.' + ' This can be due to the source ids not existing in the documents store due to' + ' already being deleted or not having been added yet.' 
+ ) + def delete_provider(self, provider_key: str): with self.session_maker() as session: try: @@ -506,7 +570,16 @@ def delete_user(self, user_id: str): session.rollback() raise DbException('Error: deleting user from access list') from e - self._cleanup_if_orphaned(list(source_ids), session) + try: + self._cleanup_if_orphaned(list(source_ids), session) + except Exception as e: + session.rollback() + logger.error( + 'Error cleaning up orphaned source ids after deleting user, manual cleanup might be required', + exc_info=e, + extra={ 'source_ids': list(source_ids) }, + ) + raise DbException('Error: cleaning up orphaned source ids after deleting user') from e def count_documents_by_provider(self) -> dict[str, int]: try: @@ -554,6 +627,8 @@ def doc_search( # get embeddings return self._similarity_search(session, query, chunk_ids, k) + except EmbeddingException: + raise except Exception as e: raise DbException('Error: performing doc search in vectordb') from e diff --git a/context_chat_backend/vectordb/service.py b/context_chat_backend/vectordb/service.py index 620a0b39..06a8e19e 100644 --- a/context_chat_backend/vectordb/service.py +++ b/context_chat_backend/vectordb/service.py @@ -6,27 +6,42 @@ from ..dyn_loader import VectorDBLoader from .base import BaseVectorDB -from .types import DbException, UpdateAccessOp +from .types import UpdateAccessOp logger = logging.getLogger('ccb.vectordb') -# todo: return source ids that were successfully deleted + def delete_by_source(vectordb_loader: VectorDBLoader, source_ids: list[str]): + ''' + Raises + ------ + DbException + LoaderException + ''' db: BaseVectorDB = vectordb_loader.load() logger.debug('deleting sources by id', extra={ 'source_ids': source_ids }) - try: - db.delete_source_ids(source_ids) - except Exception as e: - raise DbException('Error: Vectordb delete_source_ids error') from e + db.delete_source_ids(source_ids) def delete_by_provider(vectordb_loader: VectorDBLoader, provider_key: str): + ''' + Raises + ------ + 
DbException + LoaderException + ''' db: BaseVectorDB = vectordb_loader.load() logger.debug(f'deleting sources by provider: {provider_key}') db.delete_provider(provider_key) def delete_user(vectordb_loader: VectorDBLoader, user_id: str): + ''' + Raises + ------ + DbException + LoaderException + ''' db: BaseVectorDB = vectordb_loader.load() logger.debug(f'deleting user from db: {user_id}') db.delete_user(user_id) @@ -38,6 +53,13 @@ def update_access( user_ids: list[str], source_id: str, ): + ''' + Raises + ------ + DbException + LoaderException + SafeDbException + ''' db: BaseVectorDB = vectordb_loader.load() logger.debug('updating access', extra={ 'op': op, 'user_ids': user_ids, 'source_id': source_id }) db.update_access(op, user_ids, source_id) @@ -49,6 +71,13 @@ def update_access_provider( user_ids: list[str], provider_id: str, ): + ''' + Raises + ------ + DbException + LoaderException + SafeDbException + ''' db: BaseVectorDB = vectordb_loader.load() logger.debug('updating access by provider', extra={ 'op': op, 'user_ids': user_ids, 'provider_id': provider_id }) db.update_access_provider(op, user_ids, provider_id) @@ -59,11 +88,24 @@ def decl_update_access( user_ids: list[str], source_id: str, ): + ''' + Raises + ------ + DbException + LoaderException + SafeDbException + ''' db: BaseVectorDB = vectordb_loader.load() logger.debug('decl update access', extra={ 'user_ids': user_ids, 'source_id': source_id }) db.decl_update_access(user_ids, source_id) def count_documents_by_provider(vectordb_loader: VectorDBLoader): + ''' + Raises + ------ + DbException + LoaderException + ''' db: BaseVectorDB = vectordb_loader.load() logger.debug('counting documents by provider') return db.count_documents_by_provider() diff --git a/context_chat_backend/vectordb/types.py b/context_chat_backend/vectordb/types.py index df5c6dd7..30811797 100644 --- a/context_chat_backend/vectordb/types.py +++ b/context_chat_backend/vectordb/types.py @@ -14,5 +14,5 @@ class SafeDbException(Exception): 
class UpdateAccessOp(Enum): - allow = 'allow' - deny = 'deny' + ALLOW = 'allow' + DENY = 'deny' diff --git a/main.py b/main.py index c4ffa1fd..076b7db0 100755 --- a/main.py +++ b/main.py @@ -3,9 +3,12 @@ # SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors # SPDX-License-Identifier: AGPL-3.0-or-later # + import logging -from os import getenv +import multiprocessing as mp +from os import cpu_count, getenv +import psutil import uvicorn from nc_py_api.ex_app import run_app @@ -48,6 +51,18 @@ def _setup_log_levels(debug: bool): app_config: TConfig = app.extra['CONFIG'] _setup_log_levels(app_config.debug) + # do forks from a clean process that doesn't have any threads or locks + mp.set_start_method('forkserver') + mp.set_forkserver_preload([ + 'context_chat_backend.chain.ingest.injest', + 'context_chat_backend.vectordb.pgvector', + 'langchain', + 'logging', + 'numpy', + 'sqlalchemy', + ]) + + print(f'CPU count: {cpu_count()}, Memory: {psutil.virtual_memory()}') print('App config:\n' + redact_config(app_config).model_dump_json(indent=2), flush=True) uv_log_config = uvicorn.config.LOGGING_CONFIG # pyright: ignore[reportAttributeAccessIssue]