Skip to content

Commit 60ffcb0

Browse files
authored
refactor(db): introduce ContextVar DAO layer and optimize transaction granularity (#678)
* fix(email): support RFC-compliant non-ascii filename encoding for attachments * refactor(db): introduce ContextVar DAO layer and shrink transaction boundaries in auth API * update
1 parent 15ac89c commit 60ffcb0

42 files changed

Lines changed: 2085 additions & 1402 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

backend/app/api/agents.py

Lines changed: 298 additions & 232 deletions
Large diffs are not rendered by default.

backend/app/api/auth.py

Lines changed: 507 additions & 580 deletions
Large diffs are not rendered by default.

backend/app/api/dingtalk.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ async def dingtalk_callback(
449449
pass
450450

451451
# 2. Get DingTalk provider config
452-
auth_provider = await auth_provider_registry.get_provider(db, "dingtalk", str(tenant_id) if tenant_id else None)
452+
auth_provider = await auth_provider_registry.get_provider("dingtalk", str(tenant_id) if tenant_id else None)
453453
if not auth_provider:
454454
return HTMLResponse("Auth failed: DingTalk provider not configured for this tenant")
455455

backend/app/api/google_workspace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ async def _handle_google_sso_callback(
7979
auth_provider = GoogleWorkspaceAuthProvider(provider=provider, config=provider.config or {})
8080
else:
8181
auth_provider = await auth_provider_registry.get_provider(
82-
db, "google_workspace", str(tenant_id) if tenant_id else None
82+
"google_workspace", str(tenant_id) if tenant_id else None
8383
)
8484
if not auth_provider:
8585
return HTMLResponse("Auth failed: Google Workspace provider not configured for this tenant")

backend/app/api/notification.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,6 @@ async def broadcast_notification(
183183
from app.services.system_email_service import (
184184
BroadcastEmailRecipient,
185185
deliver_broadcast_emails,
186-
run_background_email_job,
187186
)
188187

189188
for user in users:
@@ -205,7 +204,7 @@ async def broadcast_notification(
205204

206205
await db.commit()
207206
if email_recipients:
208-
background_tasks.add_task(run_background_email_job, deliver_broadcast_emails, email_recipients)
207+
background_tasks.add_task(deliver_broadcast_emails, email_recipients)
209208
return {
210209
"ok": True,
211210
"users_notified": count_users,

backend/app/api/organization.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ async def admin_update_user(
9797
if "email" in update_data or "primary_mobile" in update_data:
9898
from app.services.registration_service import registration_service
9999
await registration_service.sync_org_member_contact_from_user(
100-
db,
101100
user,
102101
sync_email="email" in update_data,
103102
sync_phone="primary_mobile" in update_data,

backend/app/api/sso.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ async def get_sso_config(sid: uuid.UUID, request: Request, db: AsyncSession = De
125125

126126
elif p.provider_type == "dingtalk":
127127
from app.services.auth_registry import auth_provider_registry
128-
auth_provider = await auth_provider_registry.get_provider(db, "dingtalk", str(session.tenant_id) if session.tenant_id else None)
128+
auth_provider = await auth_provider_registry.get_provider("dingtalk", str(session.tenant_id) if session.tenant_id else None)
129129
if auth_provider:
130130
redir = f"{public_base}/api/auth/dingtalk/callback"
131131
# Use provider's standardized authorization URL
@@ -147,7 +147,7 @@ async def get_sso_config(sid: uuid.UUID, request: Request, db: AsyncSession = De
147147
sign_google_sso_state,
148148
)
149149
auth_provider = await auth_provider_registry.get_provider(
150-
db, "google_workspace", str(session.tenant_id) if session.tenant_id else None
150+
"google_workspace", str(session.tenant_id) if session.tenant_id else None
151151
)
152152
if auth_provider:
153153
redir = await get_google_redirect_uri(db, p, request)

backend/app/api/wecom.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,7 +710,6 @@ async def wecom_callback(
710710
# 2. Extract user info and login/register via RegistrationService
711711
try:
712712
auth_provider = await auth_provider_registry.get_provider(
713-
db,
714713
"wecom",
715714
str(tenant_id) if tenant_id else (str(provider.tenant_id) if provider.tenant_id else None),
716715
)

backend/app/dao/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from app.dao.identity_dao import identity_dao
2+
from app.dao.identity_provider_dao import identity_provider_dao
3+
from app.dao.invitation_code_dao import invitation_code_dao
4+
from app.dao.org_member_dao import org_member_dao
5+
from app.dao.participant_dao import participant_dao
6+
from app.dao.system_setting_dao import system_setting_dao
7+
from app.dao.tenant_dao import tenant_dao
8+
from app.dao.user_dao import user_dao
9+
10+
__all__ = [
11+
"identity_dao",
12+
"identity_provider_dao",
13+
"invitation_code_dao",
14+
"org_member_dao",
15+
"participant_dao",
16+
"system_setting_dao",
17+
"tenant_dao",
18+
"user_dao",
19+
]

backend/app/dao/base.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from collections.abc import AsyncGenerator, Sequence
2+
from contextlib import asynccontextmanager
3+
from typing import Any, Generic, Type, TypeVar
4+
5+
from sqlalchemy import select
6+
from sqlalchemy.ext.asyncio import AsyncSession
7+
8+
from app.database import Base, _session_ctx, async_session
9+
10+
ModelType = TypeVar("ModelType", bound=Base)
11+
12+
13+
class BaseDAO(Generic[ModelType]):
14+
"""Base class for data access objects, managing session context and basic CRUD."""
15+
16+
def __init__(self, model: Type[ModelType]):
17+
self.model = model
18+
19+
@asynccontextmanager
20+
async def session(self) -> AsyncGenerator[AsyncSession, None]:
21+
"""Context manager yielding the active context session or a new one."""
22+
context_session = _session_ctx.get()
23+
if context_session is not None:
24+
yield context_session
25+
else:
26+
async with async_session() as session:
27+
yield session
28+
29+
async def get(self, id: Any) -> ModelType | None:
30+
"""Fetch a single record by its primary key ID."""
31+
async with self.session() as db:
32+
if hasattr(db, "get"):
33+
return await db.get(self.model, id)
34+
# Fallback for custom mock DB clients in tests
35+
stmt = select(self.model).where(self.model.id == id)
36+
result = await db.execute(stmt)
37+
return result.scalar_one_or_none()
38+
39+
async def is_empty(self) -> bool:
40+
"""Check if the table is empty (no records)."""
41+
async with self.session() as db:
42+
stmt = select(self.model.id).limit(1)
43+
result = await db.execute(stmt)
44+
return result.scalar() is None
45+
46+
async def get_all(self, skip: int = 0, limit: int = 100) -> Sequence[ModelType]:
47+
"""Fetch all records with offset and limit."""
48+
async with self.session() as db:
49+
stmt = select(self.model).offset(skip).limit(limit)
50+
result = await db.execute(stmt)
51+
return result.scalars().all()
52+
53+
async def create(self, *, obj_in: dict[str, Any]) -> ModelType:
54+
"""Create a new record."""
55+
async with self.session() as db:
56+
db_obj = self.model(**obj_in)
57+
db.add(db_obj)
58+
await db.flush()
59+
return db_obj
60+
61+
async def update(self, *, db_obj: ModelType, obj_in: dict[str, Any]) -> ModelType:
62+
"""Update an existing record."""
63+
async with self.session() as db:
64+
for field, value in obj_in.items():
65+
if hasattr(db_obj, field):
66+
setattr(db_obj, field, value)
67+
db.add(db_obj)
68+
await db.flush()
69+
return db_obj
70+
71+
async def delete(self, *, id: Any) -> ModelType | None:
72+
"""Delete a record by ID."""
73+
async with self.session() as db:
74+
obj = await self.get(id)
75+
if obj:
76+
if hasattr(db, "delete"):
77+
await db.delete(obj)
78+
await db.flush()
79+
return obj
80+

0 commit comments

Comments
 (0)