Skip to content

Commit f8425f4

Browse files
committed
Remove implicit gateway.backend load
1 parent 55a96d6 commit f8425f4

5 files changed

Lines changed: 37 additions & 5 deletions

File tree

src/dstack/_internal/server/background/pipeline_tasks/gateways.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
)
2424
from dstack._internal.server.db import get_db, get_session_ctx
2525
from dstack._internal.server.models import (
26+
BackendModel,
2627
GatewayComputeModel,
2728
GatewayModel,
2829
ProjectModel,
@@ -210,6 +211,7 @@ async def _process_submitted_item(item: GatewayPipelineItem):
210211
GatewayModel.lock_token == item.lock_token,
211212
)
212213
.options(joinedload(GatewayModel.project).joinedload(ProjectModel.backends))
214+
.options(joinedload(GatewayModel.backend).load_only(BackendModel.type))
213215
)
214216
gateway_model = res.unique().scalar_one_or_none()
215217
if gateway_model is None:
@@ -431,6 +433,7 @@ async def _process_to_be_deleted_item(item: GatewayPipelineItem):
431433
)
432434
.options(joinedload(GatewayModel.project).joinedload(ProjectModel.backends))
433435
.options(joinedload(GatewayModel.gateway_compute))
436+
.options(joinedload(GatewayModel.backend).load_only(BackendModel.type))
434437
)
435438
gateway_model = res.unique().scalar_one_or_none()
436439
if gateway_model is None:

src/dstack/_internal/server/background/scheduled_tasks/gateways.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77
from dstack._internal.core.errors import BackendError, BackendNotAvailable, SSHError
88
from dstack._internal.core.models.gateways import GatewayStatus
99
from dstack._internal.server.db import get_db, get_session_ctx
10-
from dstack._internal.server.models import GatewayComputeModel, GatewayModel, ProjectModel
10+
from dstack._internal.server.models import (
11+
BackendModel,
12+
GatewayComputeModel,
13+
GatewayModel,
14+
ProjectModel,
15+
)
1116
from dstack._internal.server.services import backends as backends_services
1217
from dstack._internal.server.services import gateways as gateways_services
1318
from dstack._internal.server.services.gateways import (
@@ -109,6 +114,7 @@ async def _process_submitted_gateway(session: AsyncSession, gateway_model: Gatew
109114
select(GatewayModel)
110115
.where(GatewayModel.id == gateway_model.id)
111116
.options(joinedload(GatewayModel.project).joinedload(ProjectModel.backends))
117+
.options(joinedload(GatewayModel.backend).load_only(BackendModel.type))
112118
.execution_options(populate_existing=True)
113119
)
114120
gateway_model = res.unique().scalar_one()

src/dstack/_internal/server/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,7 @@ class GatewayModel(PipelineModelMixin, BaseModel):
513513
project_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE"))
514514
project: Mapped["ProjectModel"] = relationship(foreign_keys=[project_id])
515515
backend_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("backends.id", ondelete="CASCADE"))
516-
backend: Mapped["BackendModel"] = relationship(lazy="selectin")
516+
backend: Mapped["BackendModel"] = relationship()
517517

518518
gateway_compute_id: Mapped[Optional[uuid.UUID]] = mapped_column(
519519
ForeignKey("gateway_computes.id", ondelete="CASCADE")

src/dstack/_internal/server/services/gateways/__init__.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import httpx
1111
from sqlalchemy import func, select, update
1212
from sqlalchemy.ext.asyncio import AsyncSession
13-
from sqlalchemy.orm import joinedload, selectinload
13+
from sqlalchemy.orm import joinedload
1414

1515
import dstack._internal.utils.random_names as random_names
1616
from dstack._internal.core.backends.base.compute import (
@@ -131,6 +131,7 @@ async def list_project_gateways(session: AsyncSession, project: ProjectModel) ->
131131
session=session,
132132
project=project,
133133
load_gateway_compute=True,
134+
load_backend_type=True,
134135
)
135136
return [gateway_model_to_gateway(g) for g in gateways]
136137

@@ -143,6 +144,7 @@ async def get_gateway_by_name(
143144
project=project,
144145
name=name,
145146
load_gateway_compute=True,
147+
load_backend_type=True,
146148
)
147149
if gateway is None:
148150
return None
@@ -254,6 +256,14 @@ async def create_gateway(
254256
session=session, project=project, name=configuration.name, user=user
255257
)
256258
pipeline_hinter.hint_fetch(GatewayModel.__name__)
259+
gateway = await get_project_gateway_model_by_name(
260+
session=session,
261+
project=project,
262+
name=configuration.name,
263+
load_gateway_compute=True,
264+
load_backend_type=True,
265+
)
266+
assert gateway is not None
257267
return gateway_model_to_gateway(gateway)
258268

259269

@@ -392,10 +402,11 @@ async def _delete_gateways_sync(
392402
GatewayModel.project_id == project.id,
393403
GatewayModel.name.in_(gateways_names),
394404
)
395-
.options(selectinload(GatewayModel.gateway_compute))
405+
.options(joinedload(GatewayModel.gateway_compute))
406+
.options(joinedload(GatewayModel.backend).load_only(BackendModel.type))
396407
.execution_options(populate_existing=True)
397408
.order_by(GatewayModel.id) # take locks in order
398-
.with_for_update(key_share=True)
409+
.with_for_update(key_share=True, of=GatewayModel)
399410
)
400411
gateway_models = res.scalars().all()
401412
for gateway_model in gateway_models:
@@ -506,10 +517,13 @@ async def list_project_gateway_models(
506517
session: AsyncSession,
507518
project: ProjectModel,
508519
load_gateway_compute: bool = False,
520+
load_backend_type: bool = False,
509521
) -> Sequence[GatewayModel]:
510522
stmt = select(GatewayModel).where(GatewayModel.project_id == project.id)
511523
if load_gateway_compute:
512524
stmt = stmt.options(joinedload(GatewayModel.gateway_compute))
525+
if load_backend_type:
526+
stmt = stmt.options(joinedload(GatewayModel.backend).load_only(BackendModel.type))
513527
res = await session.execute(stmt)
514528
return res.scalars().all()
515529

@@ -519,13 +533,16 @@ async def get_project_gateway_model_by_name(
519533
project: ProjectModel,
520534
name: str,
521535
load_gateway_compute: bool = False,
536+
load_backend_type: bool = False,
522537
) -> Optional[GatewayModel]:
523538
stmt = select(GatewayModel).where(
524539
GatewayModel.project_id == project.id,
525540
GatewayModel.name == name,
526541
)
527542
if load_gateway_compute:
528543
stmt = stmt.options(joinedload(GatewayModel.gateway_compute))
544+
if load_backend_type:
545+
stmt = stmt.options(joinedload(GatewayModel.backend).load_only(BackendModel.type))
529546
res = await session.execute(stmt)
530547
return res.scalar()
531548

@@ -558,6 +575,7 @@ async def get_project_gateway_model_by_name_for_update(
558575
select(GatewayModel)
559576
.where(GatewayModel.id.in_([gateway_id]), *filters)
560577
.options(joinedload(GatewayModel.gateway_compute))
578+
.options(joinedload(GatewayModel.backend).load_only(BackendModel.type))
561579
.with_for_update(key_share=True, of=GatewayModel)
562580
)
563581
yield res.scalar_one_or_none()
@@ -567,13 +585,16 @@ async def get_project_default_gateway_model(
567585
session: AsyncSession,
568586
project: ProjectModel,
569587
load_gateway_compute: bool = False,
588+
load_backend_type: bool = False,
570589
) -> Optional[GatewayModel]:
571590
stmt = select(GatewayModel).where(
572591
GatewayModel.id == project.default_gateway_id,
573592
GatewayModel.to_be_deleted == False,
574593
)
575594
if load_gateway_compute:
576595
stmt = stmt.options(joinedload(GatewayModel.gateway_compute))
596+
if load_backend_type:
597+
stmt = stmt.options(joinedload(GatewayModel.backend).load_only(BackendModel.type))
577598
res = await session.execute(stmt)
578599
return res.scalar_one_or_none()
579600

src/dstack/_internal/server/services/services/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ async def register_service(session: AsyncSession, run_model: RunModel, run_spec:
9494
project=run_model.project,
9595
name=run_spec.configuration.gateway,
9696
load_gateway_compute=True,
97+
load_backend_type=True,
9798
)
9899
if gateway is None:
99100
raise ResourceNotExistsError(
@@ -110,6 +111,7 @@ async def register_service(session: AsyncSession, run_model: RunModel, run_spec:
110111
session=session,
111112
project=run_model.project,
112113
load_gateway_compute=True,
114+
load_backend_type=True,
113115
)
114116
if gateway is None and run_spec.configuration.gateway == True:
115117
raise ResourceNotExistsError(

0 commit comments

Comments
 (0)