Skip to content

Commit 55a96d6

Browse files
committed
Remove implicit gateway_compute load
1 parent 1d2e835 commit 55a96d6

6 files changed

Lines changed: 69 additions & 19 deletions

File tree

src/dstack/_internal/server/background/scheduled_tasks/gateways.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ async def _process_provisioning_gateway(
153153
res = await session.execute(
154154
select(GatewayModel)
155155
.where(GatewayModel.id == gateway_model.id)
156+
.options(joinedload(GatewayModel.gateway_compute))
156157
.execution_options(populate_existing=True)
157158
)
158159
gateway_model = res.unique().scalar_one()

src/dstack/_internal/server/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ class GatewayModel(PipelineModelMixin, BaseModel):
518518
gateway_compute_id: Mapped[Optional[uuid.UUID]] = mapped_column(
519519
ForeignKey("gateway_computes.id", ondelete="CASCADE")
520520
)
521-
gateway_compute: Mapped[Optional["GatewayComputeModel"]] = relationship(lazy="joined")
521+
gateway_compute: Mapped[Optional["GatewayComputeModel"]] = relationship()
522522

523523
runs: Mapped[List["RunModel"]] = relationship(back_populates="gateway")
524524

src/dstack/_internal/server/services/gateways/__init__.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -127,14 +127,23 @@ def get_gateway_status_change_message(
127127

128128

129129
async def list_project_gateways(session: AsyncSession, project: ProjectModel) -> List[Gateway]:
130-
gateways = await list_project_gateway_models(session=session, project=project)
130+
gateways = await list_project_gateway_models(
131+
session=session,
132+
project=project,
133+
load_gateway_compute=True,
134+
)
131135
return [gateway_model_to_gateway(g) for g in gateways]
132136

133137

134138
async def get_gateway_by_name(
135139
session: AsyncSession, project: ProjectModel, name: str
136140
) -> Optional[Gateway]:
137-
gateway = await get_project_gateway_model_by_name(session=session, project=project, name=name)
141+
gateway = await get_project_gateway_model_by_name(
142+
session=session,
143+
project=project,
144+
name=name,
145+
load_gateway_compute=True,
146+
)
138147
if gateway is None:
139148
return None
140149
return gateway_model_to_gateway(gateway)
@@ -494,20 +503,30 @@ async def set_default_gateway(
494503

495504

496505
async def list_project_gateway_models(
497-
session: AsyncSession, project: ProjectModel
506+
session: AsyncSession,
507+
project: ProjectModel,
508+
load_gateway_compute: bool = False,
498509
) -> Sequence[GatewayModel]:
499-
res = await session.execute(select(GatewayModel).where(GatewayModel.project_id == project.id))
510+
stmt = select(GatewayModel).where(GatewayModel.project_id == project.id)
511+
if load_gateway_compute:
512+
stmt = stmt.options(joinedload(GatewayModel.gateway_compute))
513+
res = await session.execute(stmt)
500514
return res.scalars().all()
501515

502516

503517
async def get_project_gateway_model_by_name(
504-
session: AsyncSession, project: ProjectModel, name: str
518+
session: AsyncSession,
519+
project: ProjectModel,
520+
name: str,
521+
load_gateway_compute: bool = False,
505522
) -> Optional[GatewayModel]:
506-
res = await session.execute(
507-
select(GatewayModel).where(
508-
GatewayModel.project_id == project.id, GatewayModel.name == name
509-
)
523+
stmt = select(GatewayModel).where(
524+
GatewayModel.project_id == project.id,
525+
GatewayModel.name == name,
510526
)
527+
if load_gateway_compute:
528+
stmt = stmt.options(joinedload(GatewayModel.gateway_compute))
529+
res = await session.execute(stmt)
511530
return res.scalar()
512531

513532

@@ -538,20 +557,24 @@ async def get_project_gateway_model_by_name_for_update(
538557
res = await session.execute(
539558
select(GatewayModel)
540559
.where(GatewayModel.id.in_([gateway_id]), *filters)
560+
.options(joinedload(GatewayModel.gateway_compute))
541561
.with_for_update(key_share=True, of=GatewayModel)
542562
)
543563
yield res.scalar_one_or_none()
544564

545565

546566
async def get_project_default_gateway_model(
547-
session: AsyncSession, project: ProjectModel
567+
session: AsyncSession,
568+
project: ProjectModel,
569+
load_gateway_compute: bool = False,
548570
) -> Optional[GatewayModel]:
549-
res = await session.execute(
550-
select(GatewayModel).where(
551-
GatewayModel.id == project.default_gateway_id,
552-
GatewayModel.to_be_deleted == False,
553-
)
571+
stmt = select(GatewayModel).where(
572+
GatewayModel.id == project.default_gateway_id,
573+
GatewayModel.to_be_deleted == False,
554574
)
575+
if load_gateway_compute:
576+
stmt = stmt.options(joinedload(GatewayModel.gateway_compute))
577+
res = await session.execute(stmt)
555578
return res.scalar_one_or_none()
556579

557580

@@ -567,7 +590,12 @@ async def generate_gateway_name(session: AsyncSession, project: ProjectModel) ->
567590
async def get_or_add_gateway_connection(
568591
session: AsyncSession, gateway_id: uuid.UUID
569592
) -> tuple[GatewayModel, GatewayConnection]:
570-
gateway = await session.get(GatewayModel, gateway_id)
593+
gateway = await session.get(
594+
GatewayModel,
595+
gateway_id,
596+
options=[joinedload(GatewayModel.gateway_compute)],
597+
populate_existing=True,
598+
)
571599
if gateway is None:
572600
raise GatewayError("Gateway not found")
573601
if gateway.gateway_compute is None:

src/dstack/_internal/server/services/services/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,10 @@ async def register_service(session: AsyncSession, run_model: RunModel, run_spec:
9090

9191
if isinstance(run_spec.configuration.gateway, str):
9292
gateway = await get_project_gateway_model_by_name(
93-
session=session, project=run_model.project, name=run_spec.configuration.gateway
93+
session=session,
94+
project=run_model.project,
95+
name=run_spec.configuration.gateway,
96+
load_gateway_compute=True,
9497
)
9598
if gateway is None:
9699
raise ResourceNotExistsError(
@@ -104,7 +107,9 @@ async def register_service(session: AsyncSession, run_model: RunModel, run_spec:
104107
gateway = None
105108
else:
106109
gateway = await get_project_default_gateway_model(
107-
session=session, project=run_model.project
110+
session=session,
111+
project=run_model.project,
112+
load_gateway_compute=True,
108113
)
109114
if gateway is None and run_spec.configuration.gateway == True:
110115
raise ResourceNotExistsError(

src/tests/_internal/server/background/pipeline_tasks/test_gateways.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66
from sqlalchemy import select
77
from sqlalchemy.ext.asyncio import AsyncSession
8+
from sqlalchemy.orm import joinedload
89

910
from dstack._internal.core.errors import BackendError
1011
from dstack._internal.core.models.gateways import GatewayProvisioningData, GatewayStatus
@@ -77,6 +78,12 @@ async def test_submitted_to_provisioning(
7778
aws.compute.return_value.create_gateway.assert_called_once()
7879

7980
await session.refresh(gateway)
81+
res = await session.execute(
82+
select(GatewayModel)
83+
.where(GatewayModel.id == gateway.id)
84+
.options(joinedload(GatewayModel.gateway_compute))
85+
)
86+
gateway = res.unique().scalar_one()
8087
assert gateway.status == GatewayStatus.PROVISIONING
8188
assert gateway.gateway_compute is not None
8289
assert gateway.gateway_compute.ip_address == "2.2.2.2"

src/tests/_internal/server/background/scheduled_tasks/test_gateways.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from unittest.mock import MagicMock, Mock, patch
22

33
import pytest
4+
from sqlalchemy import select
45
from sqlalchemy.ext.asyncio import AsyncSession
6+
from sqlalchemy.orm import joinedload
57

68
from dstack._internal.core.errors import BackendError
79
from dstack._internal.core.models.gateways import GatewayProvisioningData, GatewayStatus
810
from dstack._internal.server.background.scheduled_tasks.gateways import process_gateways
11+
from dstack._internal.server.models import GatewayModel
912
from dstack._internal.server.testing.common import (
1013
AsyncContextManager,
1114
ComputeMockSpec,
@@ -44,6 +47,12 @@ async def test_submitted_to_provisioning(self, test_db, session: AsyncSession):
4447
m.assert_called_once()
4548
aws.compute.return_value.create_gateway.assert_called_once()
4649
await session.refresh(gateway)
50+
res = await session.execute(
51+
select(GatewayModel)
52+
.where(GatewayModel.id == gateway.id)
53+
.options(joinedload(GatewayModel.gateway_compute))
54+
)
55+
gateway = res.unique().scalar_one()
4756
assert gateway.status == GatewayStatus.PROVISIONING
4857
assert gateway.gateway_compute is not None
4958
assert gateway.gateway_compute.ip_address == "2.2.2.2"

0 commit comments

Comments
 (0)