Skip to content

Commit 8aa43ea

Browse files
carlos4sasukaminato0721autofix-ci[bot]
authored andcommitted
refactor: use sessionmaker in small services 2 (langgenius#34696)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
1 parent fde411b commit 8aa43ea

12 files changed

Lines changed: 86 additions & 53 deletions

api/services/account_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from pydantic import BaseModel, TypeAdapter
1111
from sqlalchemy import delete, func, select, update
12-
from sqlalchemy.orm import Session
12+
from sqlalchemy.orm import Session, sessionmaker
1313

1414

1515
class InvitationData(TypedDict):
@@ -1516,7 +1516,7 @@ def invite_new_member(
15161516

15171517
check_workspace_member_invite_permission(tenant.id)
15181518

1519-
with Session(db.engine) as session:
1519+
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
15201520
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
15211521

15221522
if not account:

api/services/async_workflow_service.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from celery.result import AsyncResult
1313
from sqlalchemy import select
14-
from sqlalchemy.orm import Session
14+
from sqlalchemy.orm import Session, sessionmaker
1515

1616
from enums.quota_type import QuotaType
1717
from extensions.ext_database import db
@@ -237,7 +237,7 @@ def get_trigger_log(
237237
Returns:
238238
Trigger log as dictionary or None if not found
239239
"""
240-
with Session(db.engine) as session:
240+
with sessionmaker(db.engine).begin() as session:
241241
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
242242
trigger_log = trigger_log_repo.get_by_id(workflow_trigger_log_id, tenant_id)
243243

@@ -263,7 +263,7 @@ def get_recent_logs(
263263
Returns:
264264
List of trigger logs as dictionaries
265265
"""
266-
with Session(db.engine) as session:
266+
with sessionmaker(db.engine).begin() as session:
267267
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
268268
logs = trigger_log_repo.get_recent_logs(
269269
tenant_id=tenant_id, app_id=app_id, hours=hours, limit=limit, offset=offset
@@ -286,7 +286,7 @@ def get_failed_logs_for_retry(
286286
Returns:
287287
List of failed trigger logs as dictionaries
288288
"""
289-
with Session(db.engine) as session:
289+
with sessionmaker(db.engine).begin() as session:
290290
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
291291
logs = trigger_log_repo.get_failed_for_retry(
292292
tenant_id=tenant_id, max_retry_count=max_retry_count, limit=limit

api/services/clear_free_plan_tenant_expired_logs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def process(cls, days: int, batch: int, tenant_ids: list[str]):
346346
started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
347347
current_time = started_at
348348

349-
with Session(db.engine) as session:
349+
with sessionmaker(db.engine).begin() as session:
350350
total_tenant_count = session.query(Tenant.id).count()
351351

352352
click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))
@@ -398,7 +398,7 @@ def process_tenant(flask_app: Flask, tenant_id: str):
398398
# Initial interval of 1 day, will be dynamically adjusted based on tenant count
399399
interval = datetime.timedelta(days=1)
400400
# Process tenants in this batch
401-
with Session(db.engine) as session:
401+
with sessionmaker(db.engine).begin() as session:
402402
# Calculate tenant count in next batch with current interval
403403
# Try different intervals until we find one with a reasonable tenant count
404404
test_intervals = [

api/services/credit_pool_service.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22

33
from sqlalchemy import select, update
4-
from sqlalchemy.orm import Session
4+
from sqlalchemy.orm import sessionmaker
55

66
from configs import dify_config
77
from core.errors.error import QuotaExceededError
@@ -71,7 +71,7 @@ def check_and_deduct_credits(
7171
actual_credits = min(credits_required, pool.remaining_credits)
7272

7373
try:
74-
with Session(db.engine) as session:
74+
with sessionmaker(db.engine).begin() as session:
7575
stmt = (
7676
update(TenantCreditPool)
7777
.where(
@@ -81,7 +81,6 @@ def check_and_deduct_credits(
8181
.values(quota_used=TenantCreditPool.quota_used + actual_credits)
8282
)
8383
session.execute(stmt)
84-
session.commit()
8584
except Exception:
8685
logger.exception("Failed to deduct credits for tenant %s", tenant_id)
8786
raise QuotaExceededError("Failed to deduct credits")

api/services/dataset_service.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
1616
from redis.exceptions import LockNotOwnedError
1717
from sqlalchemy import delete, exists, func, select, update
18-
from sqlalchemy.orm import Session
18+
from sqlalchemy.orm import Session, sessionmaker
1919
from werkzeug.exceptions import Forbidden, NotFound
2020

2121
from configs import dify_config
@@ -551,22 +551,22 @@ def _update_external_knowledge_binding(dataset_id, external_knowledge_id, extern
551551
external_knowledge_id: External knowledge identifier
552552
external_knowledge_api_id: External knowledge API identifier
553553
"""
554-
with Session(db.engine) as session:
554+
with sessionmaker(db.engine).begin() as session:
555555
external_knowledge_binding = (
556556
session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first()
557557
)
558558

559559
if not external_knowledge_binding:
560560
raise ValueError("External knowledge binding not found.")
561561

562-
# Update binding if values have changed
563-
if (
564-
external_knowledge_binding.external_knowledge_id != external_knowledge_id
565-
or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id
566-
):
567-
external_knowledge_binding.external_knowledge_id = external_knowledge_id
568-
external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id
569-
db.session.add(external_knowledge_binding)
562+
# Update binding if values have changed
563+
if (
564+
external_knowledge_binding.external_knowledge_id != external_knowledge_id
565+
or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id
566+
):
567+
external_knowledge_binding.external_knowledge_id = external_knowledge_id
568+
external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id
569+
session.add(external_knowledge_binding)
570570

571571
@staticmethod
572572
def _update_internal_dataset(dataset, data, user):

api/services/oauth_server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import uuid
33

44
from sqlalchemy import select
5-
from sqlalchemy.orm import Session
5+
from sqlalchemy.orm import sessionmaker
66
from werkzeug.exceptions import BadRequest
77

88
from extensions.ext_database import db
@@ -29,7 +29,7 @@ class OAuthServerService:
2929
def get_oauth_provider_app(client_id: str) -> OAuthProviderApp | None:
3030
query = select(OAuthProviderApp).where(OAuthProviderApp.client_id == client_id)
3131

32-
with Session(db.engine) as session:
32+
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
3333
return session.execute(query).scalar_one_or_none()
3434

3535
@staticmethod

api/services/rag_pipeline/rag_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,7 +1182,7 @@ def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict):
11821182
workflow = db.session.get(Workflow, pipeline.workflow_id)
11831183
if not workflow:
11841184
raise ValueError("Workflow not found")
1185-
with Session(db.engine) as session:
1185+
with sessionmaker(db.engine).begin() as session:
11861186
dataset = pipeline.retrieve_dataset(session=session)
11871187
if not dataset:
11881188
raise ValueError("Dataset not found")
@@ -1209,7 +1209,7 @@ def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict):
12091209

12101210
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
12111211

1212-
with Session(db.engine) as session:
1212+
with sessionmaker(db.engine).begin() as session:
12131213
rag_pipeline_dsl_service = RagPipelineDslService(session)
12141214
dsl = rag_pipeline_dsl_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
12151215
if args.get("icon_info") is None:

api/services/workflow_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -834,7 +834,7 @@ def run_draft_workflow_node(
834834
if workflow_node_execution is None:
835835
raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving")
836836

837-
with Session(db.engine) as session:
837+
with sessionmaker(db.engine).begin() as session:
838838
outputs = workflow_node_execution.load_full_outputs(session, storage)
839839

840840
with Session(bind=db.engine) as session, session.begin():

api/tests/unit_tests/services/test_account_service.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1427,16 +1427,18 @@ def test_invite_new_member_new_account(self, mock_db_dependencies, mock_redis_de
14271427
mock_tenant.name = "Test Workspace"
14281428
mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
14291429

1430-
# Mock database queries - need to mock the Session query
1430+
# Mock database queries - need to mock the sessionmaker query
14311431
mock_session = MagicMock()
14321432
mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account
14331433

1434+
mock_sessionmaker = MagicMock()
1435+
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
1436+
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
1437+
14341438
with (
1435-
patch("services.account_service.Session") as mock_session_class,
1439+
patch("services.account_service.sessionmaker", mock_sessionmaker),
14361440
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
14371441
):
1438-
mock_session_class.return_value.__enter__.return_value = mock_session
1439-
mock_session_class.return_value.__exit__.return_value = None
14401442
mock_lookup.return_value = None
14411443

14421444
# Mock RegisterService.register
@@ -1485,12 +1487,14 @@ def test_invite_new_member_normalizes_new_account_email(
14851487
mixed_email = "Invitee@Example.com"
14861488

14871489
mock_session = MagicMock()
1490+
mock_sessionmaker = MagicMock()
1491+
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
1492+
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
1493+
14881494
with (
1489-
patch("services.account_service.Session") as mock_session_class,
1495+
patch("services.account_service.sessionmaker", mock_sessionmaker),
14901496
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
14911497
):
1492-
mock_session_class.return_value.__enter__.return_value = mock_session
1493-
mock_session_class.return_value.__exit__.return_value = None
14941498
mock_lookup.return_value = None
14951499

14961500
mock_new_account = TestAccountAssociatedDataFactory.create_account_mock(
@@ -1541,16 +1545,18 @@ def test_invite_new_member_existing_account(
15411545
account_id="existing-user-456", email="existing@example.com", status="pending"
15421546
)
15431547

1544-
# Mock database queries - need to mock the Session query
1548+
# Mock database queries - need to mock the sessionmaker query
15451549
mock_session = MagicMock()
15461550
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account
15471551

1552+
mock_sessionmaker = MagicMock()
1553+
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
1554+
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
1555+
15481556
with (
1549-
patch("services.account_service.Session") as mock_session_class,
1557+
patch("services.account_service.sessionmaker", mock_sessionmaker),
15501558
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
15511559
):
1552-
mock_session_class.return_value.__enter__.return_value = mock_session
1553-
mock_session_class.return_value.__exit__.return_value = None
15541560
mock_lookup.return_value = mock_existing_account
15551561

15561562
# Mock scalar for TenantAccountJoin lookup - no existing member

api/tests/unit_tests/services/test_async_workflow_service.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -357,11 +357,12 @@ def test_should_return_trigger_log_dict_or_none(self, repo_result, expected):
357357
mock_session_context.__enter__.return_value = mock_session
358358
mock_session_context.__exit__.return_value = None
359359

360+
mock_sessionmaker = MagicMock()
361+
mock_sessionmaker.return_value.begin.return_value = mock_session_context
362+
360363
with (
361364
patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=fake_engine)),
362-
patch.object(
363-
async_workflow_service_module, "Session", return_value=mock_session_context
364-
) as mock_session_class,
365+
patch.object(async_workflow_service_module, "sessionmaker", mock_sessionmaker),
365366
patch.object(
366367
async_workflow_service_module,
367368
"SQLAlchemyWorkflowTriggerLogRepository",
@@ -373,7 +374,7 @@ def test_should_return_trigger_log_dict_or_none(self, repo_result, expected):
373374

374375
# Assert
375376
assert result == expected
376-
mock_session_class.assert_called_once_with(fake_engine)
377+
mock_sessionmaker.assert_called_once_with(fake_engine)
377378
mock_repo.get_by_id.assert_called_once_with("trigger-log-123", "tenant-123")
378379

379380
def test_should_return_recent_logs_as_dict_list(self):
@@ -391,9 +392,12 @@ def test_should_return_recent_logs_as_dict_list(self):
391392
mock_session_context.__enter__.return_value = mock_session
392393
mock_session_context.__exit__.return_value = None
393394

395+
mock_sessionmaker = MagicMock()
396+
mock_sessionmaker.return_value.begin.return_value = mock_session_context
397+
394398
with (
395399
patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=MagicMock())),
396-
patch.object(async_workflow_service_module, "Session", return_value=mock_session_context),
400+
patch.object(async_workflow_service_module, "sessionmaker", mock_sessionmaker),
397401
patch.object(
398402
async_workflow_service_module,
399403
"SQLAlchemyWorkflowTriggerLogRepository",
@@ -432,9 +436,12 @@ def test_should_return_failed_logs_for_retry_as_dict_list(self):
432436
mock_session_context.__enter__.return_value = mock_session
433437
mock_session_context.__exit__.return_value = None
434438

439+
mock_sessionmaker = MagicMock()
440+
mock_sessionmaker.return_value.begin.return_value = mock_session_context
441+
435442
with (
436443
patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=MagicMock())),
437-
patch.object(async_workflow_service_module, "Session", return_value=mock_session_context),
444+
patch.object(async_workflow_service_module, "sessionmaker", mock_sessionmaker),
438445
patch.object(
439446
async_workflow_service_module,
440447
"SQLAlchemyWorkflowTriggerLogRepository",

0 commit comments

Comments
 (0)