|
23 | 23 | from sqlalchemy import select |
24 | 24 | from sqlalchemy.ext.asyncio import AsyncSession |
25 | 25 |
|
26 | | -from aperag.config import get_async_session, settings |
| 26 | +from aperag.config import settings |
27 | 27 | from aperag.db import models as db_models |
28 | 28 | from aperag.db.ops import AsyncDatabaseOps, async_db_ops |
29 | 29 | from aperag.docparser.doc_parser import DocParser |
@@ -79,52 +79,95 @@ def __init__(self, session: AsyncSession = None): |
79 | 79 | else: |
80 | 80 | self.db_ops = AsyncDatabaseOps(session) # Create custom instance for transaction control |
81 | 81 |
|
82 | | - async def build_document_response( |
83 | | - self, document: db_models.Document, session: AsyncSession |
84 | | - ) -> view_models.Document: |
85 | | - """Build Document response object for API return using new status model.""" |
86 | | - from sqlalchemy import select |
87 | | - |
88 | | - from aperag.db.models import DocumentIndex |
| 82 | + async def _query_documents_with_indexes( |
| 83 | + self, user: str, collection_id: str, document_id: str = None |
| 84 | + ) -> List[db_models.Document]: |
| 85 | + """ |
| 86 | + Common function to query documents with their indexes using JOIN. |
| 87 | + If document_id is provided, query single document, otherwise query all documents. |
| 88 | + """ |
89 | 89 |
|
90 | | - # Get all document indexes for status calculation |
91 | | - document_indexes = await session.execute( |
92 | | - select(DocumentIndex).where( |
93 | | - DocumentIndex.document_id == document.id, |
94 | | - DocumentIndex.status != db_models.DocumentIndexStatus.DELETING, |
95 | | - DocumentIndex.status != db_models.DocumentIndexStatus.DELETION_IN_PROGRESS, |
| 90 | + async def _execute_query(session): |
| 91 | + from sqlalchemy import and_, outerjoin, select |
| 92 | + |
| 93 | + # Create JOIN query between Document and DocumentIndex tables |
| 94 | + # Use outerjoin to get all documents even if they don't have indexes |
| 95 | + query = ( |
| 96 | + select( |
| 97 | + db_models.Document, |
| 98 | + db_models.DocumentIndex.index_type, |
| 99 | + db_models.DocumentIndex.status.label("index_status"), |
| 100 | + db_models.DocumentIndex.gmt_created.label("index_created_at"), |
| 101 | + db_models.DocumentIndex.gmt_updated.label("index_updated_at"), |
| 102 | + db_models.DocumentIndex.error_message.label("index_error_message"), |
| 103 | + ) |
| 104 | + .select_from( |
| 105 | + outerjoin( |
| 106 | + db_models.Document, |
| 107 | + db_models.DocumentIndex, |
| 108 | + db_models.Document.id == db_models.DocumentIndex.document_id, |
| 109 | + ) |
| 110 | + ) |
| 111 | + .where( |
| 112 | + and_( |
| 113 | + db_models.Document.user == user, |
| 114 | + db_models.Document.collection_id == collection_id, |
| 115 | + db_models.Document.status != db_models.DocumentStatus.DELETED, |
| 116 | + ) |
| 117 | + ) |
| 118 | + .order_by(db_models.Document.gmt_created.desc()) |
96 | 119 | ) |
97 | | - ) |
98 | | - indexes = document_indexes.scalars().all() |
99 | | - |
100 | | - # Map index states to API response format |
101 | | - index_status = {} |
102 | | - index_updated = {} |
103 | | - |
104 | | - # Initialize all types as SKIPPED (when no record exists) |
105 | | - all_types = [ |
106 | | - db_models.DocumentIndexType.VECTOR, |
107 | | - db_models.DocumentIndexType.FULLTEXT, |
108 | | - db_models.DocumentIndexType.GRAPH, |
109 | | - ] |
110 | | - for index_type in all_types: |
111 | | - index_status[index_type] = "SKIPPED" |
112 | | - |
113 | | - # Update with actual states from database |
114 | | - for index in indexes: |
115 | | - index_status[index.index_type] = index.status |
116 | | - index_updated[index.index_type] = index.gmt_updated |
| 120 | + |
| 121 | + # Add document_id filter if provided (for single document query) |
| 122 | + if document_id: |
| 123 | + query = query.where(db_models.Document.id == document_id) |
| 124 | + |
| 125 | + result = await session.execute(query) |
| 126 | + rows = result.fetchall() |
| 127 | + |
| 128 | + # Group results by document and attach all index information |
| 129 | + documents_dict = {} |
| 130 | + for row in rows: |
| 131 | + doc = row.Document |
| 132 | + if doc.id not in documents_dict: |
| 133 | + documents_dict[doc.id] = doc |
| 134 | + # Initialize index information for all types |
| 135 | + doc.indexes = {"VECTOR": None, "FULLTEXT": None, "GRAPH": None} |
| 136 | + |
| 137 | + # Add index information if exists |
| 138 | + if row.index_type: |
| 139 | + doc.indexes[row.index_type] = { |
| 140 | + "index_type": row.index_type, |
| 141 | + "status": row.index_status, |
| 142 | + "created_at": row.index_created_at, |
| 143 | + "updated_at": row.index_updated_at, |
| 144 | + "error_message": row.index_error_message, |
| 145 | + } |
| 146 | + |
| 147 | + return list(documents_dict.values()) |
| 148 | + |
| 149 | + return await self.db_ops._execute_query(_execute_query) |
| 150 | + |
| 151 | + async def _build_document_response(self, document: db_models.Document) -> view_models.Document: |
| 152 | + """ |
| 153 | + Build document response object with all index types information. |
| 154 | + """ |
| 155 | + # Get all index information if available |
| 156 | + indexes = getattr(document, "indexes", {"VECTOR": None, "FULLTEXT": None, "GRAPH": None}) |
117 | 157 |
|
118 | 158 | return view_models.Document( |
119 | 159 | id=document.id, |
120 | 160 | name=document.name, |
121 | 161 | status=document.status, |
122 | | - vector_index_status=index_status.get(db_models.DocumentIndexType.VECTOR, "SKIPPED"), |
123 | | - fulltext_index_status=index_status.get(db_models.DocumentIndexType.FULLTEXT, "SKIPPED"), |
124 | | - graph_index_status=index_status.get(db_models.DocumentIndexType.GRAPH, "SKIPPED"), |
125 | | - vector_index_updated=index_updated.get(db_models.DocumentIndexType.VECTOR, None), |
126 | | - fulltext_index_updated=index_updated.get(db_models.DocumentIndexType.FULLTEXT, None), |
127 | | - graph_index_updated=index_updated.get(db_models.DocumentIndexType.GRAPH, None), |
| 162 | + # Vector index information |
| 163 | + vector_index_status=indexes["VECTOR"]["status"] if indexes["VECTOR"] else "SKIPPED", |
| 164 | + vector_index_updated=indexes["VECTOR"]["updated_at"] if indexes["VECTOR"] else None, |
| 165 | + # Fulltext index information |
| 166 | + fulltext_index_status=indexes["FULLTEXT"]["status"] if indexes["FULLTEXT"] else "SKIPPED", |
| 167 | + fulltext_index_updated=indexes["FULLTEXT"]["updated_at"] if indexes["FULLTEXT"] else None, |
| 168 | + # Graph index information |
| 169 | + graph_index_status=indexes["GRAPH"]["status"] if indexes["GRAPH"] else "SKIPPED", |
| 170 | + graph_index_updated=indexes["GRAPH"]["updated_at"] if indexes["GRAPH"] else None, |
128 | 171 | size=document.size, |
129 | 172 | created=document.gmt_created, |
130 | 173 | updated=document.gmt_updated, |
@@ -241,19 +284,24 @@ async def _create_documents_atomically(session): |
241 | 284 | return DocumentList(items=response) |
242 | 285 |
|
243 | 286 | async def list_documents(self, user: str, collection_id: str) -> view_models.DocumentList: |
244 | | - documents = await self.db_ops.query_documents([user], collection_id) |
| 287 | + """List all documents for a user in a collection.""" |
| 288 | + documents = await self._query_documents_with_indexes(user, collection_id) |
| 289 | + |
245 | 290 | response = [] |
246 | | - async for session in get_async_session(): |
247 | | - for document in documents: |
248 | | - response.append(await self.build_document_response(document, session)) |
249 | | - return DocumentList(items=response) |
| 291 | + for document in documents: |
| 292 | + response.append(await self._build_document_response(document)) |
| 293 | + |
| 294 | + return view_models.DocumentList(items=response) |
250 | 295 |
|
251 | 296 | async def get_document(self, user: str, collection_id: str, document_id: str) -> view_models.Document: |
252 | | - document = await self.db_ops.query_document(user, collection_id, document_id) |
253 | | - if document is None: |
254 | | - raise DocumentNotFoundException(document_id) |
255 | | - async for session in get_async_session(): |
256 | | - return await self.build_document_response(document, session) |
| 297 | + """Get a specific document by ID.""" |
| 298 | + documents = await self._query_documents_with_indexes(user, collection_id, document_id) |
| 299 | + |
| 300 | + if not documents: |
| 301 | + raise DocumentNotFoundException(f"Document not found: {document_id}") |
| 302 | + |
| 303 | + document = documents[0] |
| 304 | + return await self._build_document_response(document) |
257 | 305 |
|
258 | 306 | async def _delete_document(self, session: AsyncSession, user: str, collection_id: str, document_id: str): |
259 | 307 | """ |
@@ -390,7 +438,9 @@ async def get_document_chunks(self, user_id: str, collection_id: str, document_i |
390 | 438 | """ |
391 | 439 | Get all chunks of a document. |
392 | 440 | """ |
393 | | - async for session in get_async_session(): |
| 441 | + |
| 442 | + # Use database operations with proper session management |
| 443 | + async def _get_document_chunks(session): |
394 | 444 | # 1. Get the document to verify ownership and get collection_id |
395 | 445 | stmt = select(db_models.Document).filter( |
396 | 446 | db_models.Document.id == document_id, |
@@ -471,11 +521,16 @@ async def get_document_chunks(self, user_id: str, collection_id: str, document_i |
471 | 521 | ) |
472 | 522 | raise HTTPException(status_code=500, detail="Failed to retrieve chunks from vector store") |
473 | 523 |
|
| 524 | + # Execute query with proper session management |
| 525 | + return await self.db_ops._execute_query(_get_document_chunks) |
| 526 | + |
474 | 527 | async def get_document_preview(self, user_id: str, collection_id: str, document_id: str) -> DocumentPreview: |
475 | 528 | """ |
476 | 529 | Get all preview-related information for a document. |
477 | 530 | """ |
478 | | - async for session in get_async_session(): |
| 531 | + |
| 532 | + # Use database operations with proper session management |
| 533 | + async def _get_document_preview(session): |
479 | 534 | # 1. Get document and vector index in one go |
480 | 535 | doc_stmt = select(db_models.Document).filter( |
481 | 536 | db_models.Document.id == document_id, |
@@ -539,11 +594,16 @@ async def get_document_preview(self, user_id: str, collection_id: str, document_ |
539 | 594 | chunks=chunks, |
540 | 595 | ) |
541 | 596 |
|
| 597 | + # Execute query with proper session management |
| 598 | + return await self.db_ops._execute_query(_get_document_preview) |
| 599 | + |
542 | 600 | async def get_document_object(self, user_id: str, collection_id: str, document_id: str, path: str): |
543 | 601 | """ |
544 | 602 | Get a file object associated with a document from the object store. |
545 | 603 | """ |
546 | | - async for session in get_async_session(): |
| 604 | + |
| 605 | + # Use database operations with proper session management |
| 606 | + async def _get_document_object(session): |
547 | 607 | # 1. Verify user has access to the document |
548 | 608 | stmt = select(db_models.Document).filter( |
549 | 609 | db_models.Document.id == document_id, |
@@ -580,6 +640,9 @@ async def get_document_object(self, user_id: str, collection_id: str, document_i |
580 | 640 | logger.error(f"Failed to get object for document {document_id} at path {full_path}: {e}", exc_info=True) |
581 | 641 | raise HTTPException(status_code=500, detail="Failed to get object from store") |
582 | 642 |
|
| 643 | + # Execute query with proper session management |
| 644 | + return await self.db_ops._execute_query(_get_document_object) |
| 645 | + |
583 | 646 |
|
584 | 647 | # Create a global service instance for easy access |
585 | 648 | # This uses the global db_ops instance and doesn't require session management in views |
|
0 commit comments