Skip to content

Commit 69c14ad

Browse files
committed
Do not load ProjectModel.default_gateway by default
1 parent 30016c4 commit 69c14ad

File tree

4 files changed

+15
-45
lines changed

4 files changed

+15
-45
lines changed

src/dstack/_internal/server/services/gateways/__init__.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,6 @@ async def get_gateway_by_name(
8686
return gateway_model_to_gateway(gateway)
8787

8888

89-
async def get_project_default_gateway(
90-
session: AsyncSession, project: ProjectModel
91-
) -> Optional[Gateway]:
92-
gateway: Optional[GatewayModel] = project.default_gateway
93-
if gateway is None:
94-
return None
95-
return gateway_model_to_gateway(gateway)
96-
97-
9889
async def create_gateway_compute(
9990
project_name: str,
10091
backend_compute: Compute,
@@ -181,9 +172,9 @@ async def create_gateway(
181172
session.add(gateway)
182173
await session.commit()
183174

184-
if project.default_gateway is None or configuration.default:
175+
default_gateway = await get_project_default_gateway_model(session=session, project=project)
176+
if default_gateway is None or configuration.default:
185177
await set_default_gateway(session=session, project=project, name=configuration.name)
186-
187178
return gateway_model_to_gateway(gateway)
188179

189180

@@ -349,6 +340,15 @@ async def get_project_gateway_model_by_name(
349340
return res.scalar()
350341

351342

343+
async def get_project_default_gateway_model(
344+
session: AsyncSession, project: ProjectModel
345+
) -> Optional[GatewayModel]:
346+
res = await session.execute(
347+
select(GatewayModel).where(GatewayModel.id == project.default_gateway_id)
348+
)
349+
return res.scalar_one_or_none()
350+
351+
352352
async def generate_gateway_name(session: AsyncSession, project: ProjectModel) -> str:
353353
gateways = await list_project_gateway_models(session=session, project=project)
354354
names = {g.name for g in gateways}

src/dstack/_internal/server/services/projects.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,6 @@ async def get_project_model_by_name(
435435
.where(*filters)
436436
.options(joinedload(ProjectModel.backends))
437437
.options(joinedload(ProjectModel.members))
438-
.options(joinedload(ProjectModel.default_gateway))
439438
)
440439
return res.unique().scalar()
441440

@@ -452,7 +451,6 @@ async def get_project_model_by_name_or_error(
452451
)
453452
.options(joinedload(ProjectModel.backends))
454453
.options(joinedload(ProjectModel.members))
455-
.options(joinedload(ProjectModel.default_gateway))
456454
)
457455
return res.unique().scalar_one()
458456

@@ -469,7 +467,6 @@ async def get_project_model_by_id_or_error(
469467
)
470468
.options(joinedload(ProjectModel.backends))
471469
.options(joinedload(ProjectModel.members))
472-
.options(joinedload(ProjectModel.default_gateway))
473470
)
474471
return res.unique().scalar_one()
475472

src/dstack/_internal/server/services/services/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from dstack._internal.server.services.gateways import (
2929
get_gateway_configuration,
3030
get_or_add_gateway_connection,
31+
get_project_default_gateway_model,
3132
get_project_gateway_model_by_name,
3233
)
3334
from dstack._internal.server.services.logging import fmt
@@ -52,7 +53,9 @@ async def register_service(session: AsyncSession, run_model: RunModel, run_spec:
5253
elif run_spec.configuration.gateway == False:
5354
gateway = None
5455
else:
55-
gateway = run_model.project.default_gateway
56+
gateway = await get_project_default_gateway_model(
57+
session=session, project=run_model.project
58+
)
5659

5760
if gateway is not None:
5861
service_spec = await _register_service_in_gateway(session, run_model, run_spec, gateway)

src/tests/_internal/server/routers/test_gateways.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,6 @@
77
from dstack._internal.core.errors import DstackError
88
from dstack._internal.core.models.backends.base import BackendType
99
from dstack._internal.core.models.users import GlobalRole, ProjectRole
10-
from dstack._internal.server.services.gateways import (
11-
gateway_model_to_gateway,
12-
get_project_default_gateway,
13-
)
1410
from dstack._internal.server.services.projects import add_project_member
1511
from dstack._internal.server.testing.common import (
1612
ComputeMockSpec,
@@ -291,32 +287,6 @@ async def test_create_gateway_missing_backend(
291287

292288

293289
class TestDefaultGateway:
294-
@pytest.mark.asyncio
295-
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
296-
async def test_get_default_gateway(self, test_db, session: AsyncSession, client: AsyncClient):
297-
project = await create_project(session)
298-
backend = await create_backend(session, project.id)
299-
gateway = await create_gateway(session, project.id, backend.id)
300-
async with session.begin():
301-
project.default_gateway_id = gateway.id
302-
session.add(project)
303-
304-
res = await get_project_default_gateway(session, project)
305-
assert res is not None
306-
assert res.dict() == gateway_model_to_gateway(gateway).dict()
307-
308-
@pytest.mark.asyncio
309-
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
310-
async def test_default_gateway_is_missing(
311-
self, test_db, session: AsyncSession, client: AsyncClient
312-
):
313-
project = await create_project(session)
314-
backend = await create_backend(session, project.id)
315-
await create_gateway(session, project.id, backend.id)
316-
317-
res = await get_project_default_gateway(session, project)
318-
assert res is None
319-
320290
@pytest.mark.asyncio
321291
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
322292
async def test_only_admin_can_set_default_gateway(

0 commit comments

Comments
 (0)