Skip to content

Commit 2b4f2bf

Browse files
authored
feat: enhance database connection management and document handling (#1091)
- Added database connection pool settings to `config.py` for better resource management.
- Updated `async_engine` and `sync_engine` creation to utilize new pool settings.
- Improved session management in `AuditService` and `DocumentService` for database operations.
- Refactored document retrieval methods to use eager loading for indexes, optimizing performance.
- Added relationship mapping for `Document` and `DocumentIndex` models to streamline data access.
1 parent 36bc39c commit 2b4f2bf

5 files changed

Lines changed: 198 additions & 82 deletions

File tree

aperag/config.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,13 @@ class Config(BaseSettings):
5151
# Database
5252
database_url: str = Field(f"sqlite:///{BASE_DIR}/db.sqlite3", alias="DATABASE_URL")
5353

54+
# Database connection pool settings
55+
db_pool_size: int = Field(20, alias="DB_POOL_SIZE")
56+
db_max_overflow: int = Field(40, alias="DB_MAX_OVERFLOW")
57+
db_pool_timeout: int = Field(60, alias="DB_POOL_TIMEOUT")
58+
db_pool_recycle: int = Field(3600, alias="DB_POOL_RECYCLE")
59+
db_pool_pre_ping: bool = Field(True, alias="DB_POOL_PRE_PING")
60+
5461
# Auth
5562
auth_type: str = Field("none", alias="AUTH_TYPE")
5663
auth0_domain: str = Field("aperag-dev.auting.cn", alias="AUTH0_DOMAIN")
@@ -174,8 +181,25 @@ def get_async_database_url(url: str):
174181

175182
settings = Config()
176183

177-
async_engine = create_async_engine(get_async_database_url(settings.database_url), echo=settings.debug)
178-
sync_engine = create_engine(get_sync_database_url(settings.database_url), echo=settings.debug)
184+
# Database connection pool settings from configuration
185+
async_engine = create_async_engine(
186+
get_async_database_url(settings.database_url),
187+
echo=settings.debug,
188+
pool_size=settings.db_pool_size,
189+
max_overflow=settings.db_max_overflow,
190+
pool_timeout=settings.db_pool_timeout,
191+
pool_recycle=settings.db_pool_recycle,
192+
pool_pre_ping=settings.db_pool_pre_ping,
193+
)
194+
sync_engine = create_engine(
195+
get_sync_database_url(settings.database_url),
196+
echo=settings.debug,
197+
pool_size=settings.db_pool_size,
198+
max_overflow=settings.db_max_overflow,
199+
pool_timeout=settings.db_pool_timeout,
200+
pool_recycle=settings.db_pool_recycle,
201+
pool_pre_ping=settings.db_pool_pre_ping,
202+
)
179203

180204

181205
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:

aperag/service/audit_service.py

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,16 @@ async def log_audit(
159159
request_id=request_id or str(uuid.uuid4()),
160160
)
161161

162-
# Save to database asynchronously
163-
async for session in get_async_session():
162+
# Save to database with proper session management
163+
async def _save_audit_log(session):
164164
session.add(audit_log)
165165
await session.commit()
166+
return audit_log
167+
168+
# Use get_async_session with proper session management
169+
async for session in get_async_session():
170+
await _save_audit_log(session)
171+
break # Only process one session
166172

167173
except Exception as e:
168174
logger.error(f"Failed to log audit: {e}")
@@ -179,7 +185,9 @@ async def list_audit_logs(
179185
limit: int = 1000,
180186
) -> List[AuditLog]:
181187
"""List audit logs with filtering"""
182-
async for session in get_async_session():
188+
189+
# Use proper session management
190+
async def _list_audit_logs(session):
183191
# Build query
184192
stmt = select(AuditLog)
185193

@@ -206,33 +214,39 @@ async def list_audit_logs(
206214
# Order by creation time (newest first) and limit
207215
stmt = stmt.order_by(desc(AuditLog.gmt_created)).limit(limit)
208216

209-
# Execute query
217+
# Execute query and return results immediately
210218
result = await session.execute(stmt)
211-
audit_logs = result.scalars().all()
212-
213-
# Extract resource_id for each log during query time
214-
for log in audit_logs:
215-
if log.resource_type and log.path:
216-
# Convert string to enum if needed
217-
resource_type_enum = log.resource_type
218-
if isinstance(log.resource_type, str):
219-
try:
220-
resource_type_enum = AuditResource(log.resource_type)
221-
except ValueError:
222-
resource_type_enum = None
223-
224-
if resource_type_enum:
225-
log.resource_id = self.extract_resource_id_from_path(log.path, resource_type_enum)
226-
else:
227-
log.resource_id = None
228-
229-
# Calculate duration if both times are available
230-
if log.start_time and log.end_time:
231-
log.duration_ms = log.end_time - log.start_time
219+
return result.scalars().all()
220+
221+
# Execute query with proper session management
222+
audit_logs = None
223+
async for session in get_async_session():
224+
audit_logs = await _list_audit_logs(session)
225+
break # Only process one session
226+
227+
# Post-process audit logs outside of session to avoid long session occupation
228+
for log in audit_logs:
229+
if log.resource_type and log.path:
230+
# Convert string to enum if needed
231+
resource_type_enum = log.resource_type
232+
if isinstance(log.resource_type, str):
233+
try:
234+
resource_type_enum = AuditResource(log.resource_type)
235+
except ValueError:
236+
resource_type_enum = None
237+
238+
if resource_type_enum:
239+
log.resource_id = self.extract_resource_id_from_path(log.path, resource_type_enum)
232240
else:
233-
log.duration_ms = None
241+
log.resource_id = None
242+
243+
# Calculate duration if both times are available
244+
if log.start_time and log.end_time:
245+
log.duration_ms = log.end_time - log.start_time
246+
else:
247+
log.duration_ms = None
234248

235-
return audit_logs
249+
return audit_logs
236250

237251

238252
# Global audit service instance

aperag/service/document_service.py

Lines changed: 116 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from sqlalchemy import select
2424
from sqlalchemy.ext.asyncio import AsyncSession
2525

26-
from aperag.config import get_async_session, settings
26+
from aperag.config import settings
2727
from aperag.db import models as db_models
2828
from aperag.db.ops import AsyncDatabaseOps, async_db_ops
2929
from aperag.docparser.doc_parser import DocParser
@@ -79,52 +79,95 @@ def __init__(self, session: AsyncSession = None):
7979
else:
8080
self.db_ops = AsyncDatabaseOps(session) # Create custom instance for transaction control
8181

82-
async def build_document_response(
83-
self, document: db_models.Document, session: AsyncSession
84-
) -> view_models.Document:
85-
"""Build Document response object for API return using new status model."""
86-
from sqlalchemy import select
87-
88-
from aperag.db.models import DocumentIndex
82+
async def _query_documents_with_indexes(
83+
self, user: str, collection_id: str, document_id: str = None
84+
) -> List[db_models.Document]:
85+
"""
86+
Common function to query documents with their indexes using JOIN.
87+
If document_id is provided, query single document, otherwise query all documents.
88+
"""
8989

90-
# Get all document indexes for status calculation
91-
document_indexes = await session.execute(
92-
select(DocumentIndex).where(
93-
DocumentIndex.document_id == document.id,
94-
DocumentIndex.status != db_models.DocumentIndexStatus.DELETING,
95-
DocumentIndex.status != db_models.DocumentIndexStatus.DELETION_IN_PROGRESS,
90+
async def _execute_query(session):
91+
from sqlalchemy import and_, outerjoin, select
92+
93+
# Create JOIN query between Document and DocumentIndex tables
94+
# Use outerjoin to get all documents even if they don't have indexes
95+
query = (
96+
select(
97+
db_models.Document,
98+
db_models.DocumentIndex.index_type,
99+
db_models.DocumentIndex.status.label("index_status"),
100+
db_models.DocumentIndex.gmt_created.label("index_created_at"),
101+
db_models.DocumentIndex.gmt_updated.label("index_updated_at"),
102+
db_models.DocumentIndex.error_message.label("index_error_message"),
103+
)
104+
.select_from(
105+
outerjoin(
106+
db_models.Document,
107+
db_models.DocumentIndex,
108+
db_models.Document.id == db_models.DocumentIndex.document_id,
109+
)
110+
)
111+
.where(
112+
and_(
113+
db_models.Document.user == user,
114+
db_models.Document.collection_id == collection_id,
115+
db_models.Document.status != db_models.DocumentStatus.DELETED,
116+
)
117+
)
118+
.order_by(db_models.Document.gmt_created.desc())
96119
)
97-
)
98-
indexes = document_indexes.scalars().all()
99-
100-
# Map index states to API response format
101-
index_status = {}
102-
index_updated = {}
103-
104-
# Initialize all types as SKIPPED (when no record exists)
105-
all_types = [
106-
db_models.DocumentIndexType.VECTOR,
107-
db_models.DocumentIndexType.FULLTEXT,
108-
db_models.DocumentIndexType.GRAPH,
109-
]
110-
for index_type in all_types:
111-
index_status[index_type] = "SKIPPED"
112-
113-
# Update with actual states from database
114-
for index in indexes:
115-
index_status[index.index_type] = index.status
116-
index_updated[index.index_type] = index.gmt_updated
120+
121+
# Add document_id filter if provided (for single document query)
122+
if document_id:
123+
query = query.where(db_models.Document.id == document_id)
124+
125+
result = await session.execute(query)
126+
rows = result.fetchall()
127+
128+
# Group results by document and attach all index information
129+
documents_dict = {}
130+
for row in rows:
131+
doc = row.Document
132+
if doc.id not in documents_dict:
133+
documents_dict[doc.id] = doc
134+
# Initialize index information for all types
135+
doc.indexes = {"VECTOR": None, "FULLTEXT": None, "GRAPH": None}
136+
137+
# Add index information if exists
138+
if row.index_type:
139+
doc.indexes[row.index_type] = {
140+
"index_type": row.index_type,
141+
"status": row.index_status,
142+
"created_at": row.index_created_at,
143+
"updated_at": row.index_updated_at,
144+
"error_message": row.index_error_message,
145+
}
146+
147+
return list(documents_dict.values())
148+
149+
return await self.db_ops._execute_query(_execute_query)
150+
151+
async def _build_document_response(self, document: db_models.Document) -> view_models.Document:
152+
"""
153+
Build document response object with all index types information.
154+
"""
155+
# Get all index information if available
156+
indexes = getattr(document, "indexes", {"VECTOR": None, "FULLTEXT": None, "GRAPH": None})
117157

118158
return view_models.Document(
119159
id=document.id,
120160
name=document.name,
121161
status=document.status,
122-
vector_index_status=index_status.get(db_models.DocumentIndexType.VECTOR, "SKIPPED"),
123-
fulltext_index_status=index_status.get(db_models.DocumentIndexType.FULLTEXT, "SKIPPED"),
124-
graph_index_status=index_status.get(db_models.DocumentIndexType.GRAPH, "SKIPPED"),
125-
vector_index_updated=index_updated.get(db_models.DocumentIndexType.VECTOR, None),
126-
fulltext_index_updated=index_updated.get(db_models.DocumentIndexType.FULLTEXT, None),
127-
graph_index_updated=index_updated.get(db_models.DocumentIndexType.GRAPH, None),
162+
# Vector index information
163+
vector_index_status=indexes["VECTOR"]["status"] if indexes["VECTOR"] else "SKIPPED",
164+
vector_index_updated=indexes["VECTOR"]["updated_at"] if indexes["VECTOR"] else None,
165+
# Fulltext index information
166+
fulltext_index_status=indexes["FULLTEXT"]["status"] if indexes["FULLTEXT"] else "SKIPPED",
167+
fulltext_index_updated=indexes["FULLTEXT"]["updated_at"] if indexes["FULLTEXT"] else None,
168+
# Graph index information
169+
graph_index_status=indexes["GRAPH"]["status"] if indexes["GRAPH"] else "SKIPPED",
170+
graph_index_updated=indexes["GRAPH"]["updated_at"] if indexes["GRAPH"] else None,
128171
size=document.size,
129172
created=document.gmt_created,
130173
updated=document.gmt_updated,
@@ -241,19 +284,24 @@ async def _create_documents_atomically(session):
241284
return DocumentList(items=response)
242285

243286
async def list_documents(self, user: str, collection_id: str) -> view_models.DocumentList:
244-
documents = await self.db_ops.query_documents([user], collection_id)
287+
"""List all documents for a user in a collection."""
288+
documents = await self._query_documents_with_indexes(user, collection_id)
289+
245290
response = []
246-
async for session in get_async_session():
247-
for document in documents:
248-
response.append(await self.build_document_response(document, session))
249-
return DocumentList(items=response)
291+
for document in documents:
292+
response.append(await self._build_document_response(document))
293+
294+
return view_models.DocumentList(items=response)
250295

251296
async def get_document(self, user: str, collection_id: str, document_id: str) -> view_models.Document:
252-
document = await self.db_ops.query_document(user, collection_id, document_id)
253-
if document is None:
254-
raise DocumentNotFoundException(document_id)
255-
async for session in get_async_session():
256-
return await self.build_document_response(document, session)
297+
"""Get a specific document by ID."""
298+
documents = await self._query_documents_with_indexes(user, collection_id, document_id)
299+
300+
if not documents:
301+
raise DocumentNotFoundException(f"Document not found: {document_id}")
302+
303+
document = documents[0]
304+
return await self._build_document_response(document)
257305

258306
async def _delete_document(self, session: AsyncSession, user: str, collection_id: str, document_id: str):
259307
"""
@@ -390,7 +438,9 @@ async def get_document_chunks(self, user_id: str, collection_id: str, document_i
390438
"""
391439
Get all chunks of a document.
392440
"""
393-
async for session in get_async_session():
441+
442+
# Use database operations with proper session management
443+
async def _get_document_chunks(session):
394444
# 1. Get the document to verify ownership and get collection_id
395445
stmt = select(db_models.Document).filter(
396446
db_models.Document.id == document_id,
@@ -471,11 +521,16 @@ async def get_document_chunks(self, user_id: str, collection_id: str, document_i
471521
)
472522
raise HTTPException(status_code=500, detail="Failed to retrieve chunks from vector store")
473523

524+
# Execute query with proper session management
525+
return await self.db_ops._execute_query(_get_document_chunks)
526+
474527
async def get_document_preview(self, user_id: str, collection_id: str, document_id: str) -> DocumentPreview:
475528
"""
476529
Get all preview-related information for a document.
477530
"""
478-
async for session in get_async_session():
531+
532+
# Use database operations with proper session management
533+
async def _get_document_preview(session):
479534
# 1. Get document and vector index in one go
480535
doc_stmt = select(db_models.Document).filter(
481536
db_models.Document.id == document_id,
@@ -539,11 +594,16 @@ async def get_document_preview(self, user_id: str, collection_id: str, document_
539594
chunks=chunks,
540595
)
541596

597+
# Execute query with proper session management
598+
return await self.db_ops._execute_query(_get_document_preview)
599+
542600
async def get_document_object(self, user_id: str, collection_id: str, document_id: str, path: str):
543601
"""
544602
Get a file object associated with a document from the object store.
545603
"""
546-
async for session in get_async_session():
604+
605+
# Use database operations with proper session management
606+
async def _get_document_object(session):
547607
# 1. Verify user has access to the document
548608
stmt = select(db_models.Document).filter(
549609
db_models.Document.id == document_id,
@@ -580,6 +640,9 @@ async def get_document_object(self, user_id: str, collection_id: str, document_i
580640
logger.error(f"Failed to get object for document {document_id} at path {full_path}: {e}", exc_info=True)
581641
raise HTTPException(status_code=500, detail="Failed to get object from store")
582642

643+
# Execute query with proper session management
644+
return await self.db_ops._execute_query(_get_document_object)
645+
583646

584647
# Create a global service instance for easy access
585648
# This uses the global db_ops instance and doesn't require session management in views

envs/docker.env.overrides

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@ VECTOR_DB_CONTEXT={"url":"http://aperag-qdrant", "port":6333, "distance":"Cosine
66
ES_HOST=http://aperag-es:9200
77
MEMORY_REDIS_URL=redis://default:password@aperag-redis:6379
88

9+
# Database Connection Pool Settings for Docker deployment
10+
DB_POOL_SIZE=25
11+
DB_MAX_OVERFLOW=50
12+
DB_POOL_TIMEOUT=60
13+
DB_POOL_RECYCLE=3600
14+
DB_POOL_PRE_PING=True
15+
916
# Override for path
1017
TIKTOKEN_CACHE_DIR=/root/.cache/tiktoken
1118

0 commit comments

Comments
 (0)