Skip to content

Commit b42716a

Browse files
committed
chore: move file fetch inside injest
Signed-off-by: Anupam Kumar <kyteinsky@gmail.com>
1 parent c63ddfa commit b42716a

5 files changed

Lines changed: 208 additions & 194 deletions

File tree

context_chat_backend/chain/ingest/injest.py

Lines changed: 179 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,18 @@
22
# SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors
33
# SPDX-License-Identifier: AGPL-3.0-or-later
44
#
5+
import asyncio
56
import logging
67
import re
8+
from collections.abc import Mapping
9+
from io import BytesIO
710

11+
import niquests
812
from langchain.schema import Document
13+
from nc_py_api import AsyncNextcloudApp
914

1015
from ...dyn_loader import VectorDBLoader
11-
from ...types import IndexingError, IndexingException, SourceItem, TConfig
16+
from ...types import IndexingError, IndexingException, ReceivedFileItem, SourceItem, TConfig
1217
from ...vectordb.base import BaseVectorDB
1318
from ...vectordb.types import DbException, SafeDbException, UpdateAccessOp
1419
from ..types import InDocument
@@ -17,15 +22,165 @@
1722

1823
logger = logging.getLogger('ccb.injest')
1924

25+
# max concurrent fetches to avoid overloading the NC server or hitting rate limits
26+
CONCURRENT_FILE_FETCHES = 10 # todo: config?
27+
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100 MB, all loaded in RAM at once, todo: config?
28+
29+
30+
async def __fetch_file_content(
31+
semaphore: asyncio.Semaphore,
32+
file_id: int,
33+
user_id: str,
34+
_rlimit = 3,
35+
) -> BytesIO:
36+
'''
37+
Raises
38+
------
39+
IndexingException
40+
'''
41+
42+
async with semaphore:
43+
nc = AsyncNextcloudApp()
44+
try:
45+
# a file pointer for storing the stream in memory until it is consumed
46+
fp = BytesIO()
47+
await nc._session.download2fp(
48+
url_path=f'/ocs/v2.php/apps/context_chat/files/{file_id}',
49+
fp=fp,
50+
dav=False,
51+
params={ 'userId': user_id },
52+
)
53+
return fp
54+
except niquests.exceptions.RequestException as e:
55+
if e.response is None:
56+
raise
57+
58+
if e.response.status_code == niquests.codes.too_many_requests: # pyright: ignore[reportAttributeAccessIssue]
59+
# todo: implement rate limits in php CC?
60+
wait_for = int(e.response.headers.get('Retry-After', '30'))
61+
if _rlimit <= 0:
62+
raise IndexingException(
63+
f'Rate limited when fetching content for file id {file_id}, user id {user_id},'
64+
' max retries exceeded',
65+
retryable=True,
66+
) from e
67+
logger.warning(
68+
f'Rate limited when fetching content for file id {file_id}, user id {user_id},'
69+
f' waiting {wait_for} before retrying',
70+
exc_info=e,
71+
)
72+
await asyncio.sleep(wait_for)
73+
return await __fetch_file_content(semaphore, file_id, user_id, _rlimit - 1)
74+
75+
raise
76+
except IndexingException:
77+
raise
78+
except Exception as e:
79+
logger.error(f'Error fetching content for file id {file_id}, user id {user_id}: {e}', exc_info=e)
80+
raise IndexingException(f'Error fetching content for file id {file_id}, user id {user_id}: {e}') from e
81+
82+
83+
async def __fetch_files_content(
84+
sources: Mapping[int, SourceItem | ReceivedFileItem]
85+
) -> tuple[Mapping[int, SourceItem], Mapping[int, IndexingError]]:
86+
source_items = {}
87+
error_items = {}
88+
semaphore = asyncio.Semaphore(CONCURRENT_FILE_FETCHES)
89+
tasks = []
90+
91+
for db_id, file in sources.items():
92+
if isinstance(file, SourceItem):
93+
continue
94+
95+
try:
96+
# to detect any validation errors but it should not happen since file.reference is validated
97+
file.file_id # noqa: B018
98+
except ValueError as e:
99+
logger.error(
100+
f'Invalid file reference format for db id {db_id}, file reference {file.reference}: {e}',
101+
exc_info=e,
102+
)
103+
error_items[db_id] = IndexingError(
104+
error=f'Invalid file reference format: {file.reference}',
105+
retryable=False,
106+
)
107+
continue
108+
109+
if file.size > MAX_FILE_SIZE:
110+
logger.info(
111+
f'Skipping db id {db_id}, file id {file.file_id}, source id {file.reference} due to size'
112+
f' {(file.size/(1024*1024)):.2f} MiB exceeding the limit {(MAX_FILE_SIZE/(1024*1024)):.2f} MiB',
113+
)
114+
error_items[db_id] = IndexingError(
115+
error=(
116+
f'File size {(file.size/(1024*1024)):.2f} MiB'
117+
f' exceeds the limit {(MAX_FILE_SIZE/(1024*1024)):.2f} MiB'
118+
),
119+
retryable=False,
120+
)
121+
continue
122+
# any user id from the list should have read access to the file
123+
tasks.append(asyncio.ensure_future(__fetch_file_content(semaphore, file.file_id, file.userIds[0])))
124+
125+
results = await asyncio.gather(*tasks, return_exceptions=True)
126+
for (db_id, file), result in zip(sources.items(), results, strict=True):
127+
if isinstance(file, SourceItem):
128+
continue
129+
130+
if isinstance(result, IndexingException):
131+
logger.error(
132+
f'Error fetching content for db id {db_id}, file id {file.file_id}, reference {file.reference}'
133+
f': {result}',
134+
exc_info=result,
135+
)
136+
error_items[db_id] = IndexingError(
137+
error=str(result),
138+
retryable=result.retryable,
139+
)
140+
elif isinstance(result, str) or isinstance(result, BytesIO):
141+
source_items[db_id] = SourceItem(
142+
**{
143+
**file.model_dump(),
144+
'content': result,
145+
}
146+
)
147+
elif isinstance(result, BaseException):
148+
logger.error(
149+
f'Unexpected error fetching content for db id {db_id}, file id {file.file_id},'
150+
f' reference {file.reference}: {result}',
151+
exc_info=result,
152+
)
153+
error_items[db_id] = IndexingError(
154+
error=f'Unexpected error: {result}',
155+
retryable=True,
156+
)
157+
else:
158+
logger.error(
159+
f'Unknown error fetching content for db id {db_id}, file id {file.file_id}, reference {file.reference}'
160+
f': {result}',
161+
exc_info=True,
162+
)
163+
error_items[db_id] = IndexingError(
164+
error='Unknown error',
165+
retryable=True,
166+
)
167+
168+
# add the content providers from the orginal "sources" to the result unprocessed
169+
for db_id, source in sources.items():
170+
if isinstance(source, SourceItem):
171+
source_items[db_id] = source
172+
173+
return source_items, error_items
174+
20175

21176
def _filter_sources(
22177
vectordb: BaseVectorDB,
23-
sources: dict[int, SourceItem]
24-
) -> tuple[dict[int, SourceItem], dict[int, SourceItem]]:
178+
sources: Mapping[int, SourceItem | ReceivedFileItem]
179+
) -> tuple[Mapping[int, SourceItem | ReceivedFileItem], Mapping[int, SourceItem | ReceivedFileItem]]:
25180
'''
26181
Returns
27182
-------
28-
tuple[list[str], list[UploadFile]]
183+
tuple[Mapping[int, SourceItem | ReceivedFileItem], Mapping[int, SourceItem | ReceivedFileItem]]:
29184
First value is a list of sources that already exist in the vectordb.
30185
Second value is a list of sources that are new and should be embedded.
31186
'''
@@ -49,15 +204,14 @@ def _filter_sources(
49204

50205
def _sources_to_indocuments(
51206
config: TConfig,
52-
sources: dict[int, SourceItem]
53-
) -> tuple[dict[int, InDocument], dict[int, IndexingError]]:
207+
sources: Mapping[int, SourceItem]
208+
) -> tuple[Mapping[int, InDocument], Mapping[int, IndexingError]]:
54209
indocuments = {}
55210
errored_docs = {}
56211

57212
for db_id, source in sources.items():
58213
logger.debug('processing source', extra={ 'source_id': source.reference })
59214

60-
# todo: maybe fetch the content of the files here
61215
# transform the source to have text data
62216
try:
63217
content = decode_source(source)
@@ -121,8 +275,8 @@ def _sources_to_indocuments(
121275

122276
def _increase_access_for_existing_sources(
123277
vectordb: BaseVectorDB,
124-
existing_sources: dict[int, SourceItem]
125-
) -> dict[int, IndexingError | None]:
278+
existing_sources: Mapping[int, SourceItem | ReceivedFileItem]
279+
) -> Mapping[int, IndexingError | None]:
126280
'''
127281
update userIds for existing sources
128282
allow the userIds as additional users, not as the only users
@@ -162,8 +316,8 @@ def _increase_access_for_existing_sources(
162316
def _process_sources(
163317
vectordb: BaseVectorDB,
164318
config: TConfig,
165-
sources: dict[int, SourceItem]
166-
) -> dict[int, IndexingError | None]:
319+
sources: Mapping[int, SourceItem | ReceivedFileItem]
320+
) -> Mapping[int, IndexingError | None]:
167321
'''
168322
Processes the sources and adds them to the vectordb.
169323
Returns the list of source ids that were successfully added and those that need to be retried.
@@ -178,27 +332,34 @@ def _process_sources(
178332

179333
source_proc_results = _increase_access_for_existing_sources(vectordb, existing_sources)
180334

181-
if len(to_embed_sources) == 0:
335+
populated_to_embed_sources, errored_sources = asyncio.run(__fetch_files_content(to_embed_sources))
336+
source_proc_results.update(errored_sources) # pyright: ignore[reportAttributeAccessIssue]
337+
338+
if len(populated_to_embed_sources) == 0:
182339
# no new sources to embed
183340
logger.debug('Filtered all sources, nothing to embed')
184341
return source_proc_results
185342

186343
logger.debug('Filtered sources:', extra={
187-
'source_ids': [source.reference for source in to_embed_sources.values()]
344+
'source_ids': [source.reference for source in populated_to_embed_sources.values()]
188345
})
189346
# invalid/empty sources are filtered out here and not counted in loaded/retryable
190-
indocuments, errored_docs = _sources_to_indocuments(config, to_embed_sources)
347+
indocuments, errored_docs = _sources_to_indocuments(config, populated_to_embed_sources)
191348

192-
source_proc_results.update(errored_docs)
349+
source_proc_results.update(errored_docs) # pyright: ignore[reportAttributeAccessIssue]
193350
logger.debug('Converted sources to documents')
194351

195352
if len(indocuments) == 0:
196353
# filtered document(s) were invalid/empty, not an error
197354
logger.debug('All documents were found empty after being processed')
198355
return source_proc_results
199356

357+
logger.debug('Adding documents to vectordb', extra={
358+
'source_ids': [indoc.source_id for indoc in indocuments.values()]
359+
})
360+
200361
doc_add_results = vectordb.add_indocuments(indocuments)
201-
source_proc_results.update(doc_add_results)
362+
source_proc_results.update(doc_add_results) # pyright: ignore[reportAttributeAccessIssue]
202363
logger.debug('Added documents to vectordb')
203364

204365
return source_proc_results
@@ -215,8 +376,8 @@ def _decode_latin_1(s: str) -> str:
215376
def embed_sources(
216377
vectordb_loader: VectorDBLoader,
217378
config: TConfig,
218-
sources: dict[int, SourceItem]
219-
) -> dict[int, IndexingError | None]:
379+
sources: Mapping[int, SourceItem | ReceivedFileItem]
380+
) -> Mapping[int, IndexingError | None]:
220381
logger.debug('Embedding sources:', extra={
221382
'source_ids': [
222383
f'{source.reference} ({_decode_latin_1(source.title)})'

0 commit comments

Comments
 (0)