1010import httpx
1111from sqlalchemy import func , select , update
1212from sqlalchemy .ext .asyncio import AsyncSession
13- from sqlalchemy .orm import joinedload , selectinload
13+ from sqlalchemy .orm import joinedload
1414
1515import dstack ._internal .utils .random_names as random_names
1616from dstack ._internal .core .backends .base .compute import (
@@ -131,6 +131,7 @@ async def list_project_gateways(session: AsyncSession, project: ProjectModel) ->
131131 session = session ,
132132 project = project ,
133133 load_gateway_compute = True ,
134+ load_backend_type = True ,
134135 )
135136 return [gateway_model_to_gateway (g ) for g in gateways ]
136137
@@ -143,6 +144,7 @@ async def get_gateway_by_name(
143144 project = project ,
144145 name = name ,
145146 load_gateway_compute = True ,
147+ load_backend_type = True ,
146148 )
147149 if gateway is None :
148150 return None
@@ -254,6 +256,14 @@ async def create_gateway(
254256 session = session , project = project , name = configuration .name , user = user
255257 )
256258 pipeline_hinter .hint_fetch (GatewayModel .__name__ )
259+ gateway = await get_project_gateway_model_by_name (
260+ session = session ,
261+ project = project ,
262+ name = configuration .name ,
263+ load_gateway_compute = True ,
264+ load_backend_type = True ,
265+ )
266+ assert gateway is not None
257267 return gateway_model_to_gateway (gateway )
258268
259269
@@ -392,10 +402,11 @@ async def _delete_gateways_sync(
392402 GatewayModel .project_id == project .id ,
393403 GatewayModel .name .in_ (gateways_names ),
394404 )
395- .options (selectinload (GatewayModel .gateway_compute ))
405+ .options (joinedload (GatewayModel .gateway_compute ))
406+ .options (joinedload (GatewayModel .backend ).load_only (BackendModel .type ))
396407 .execution_options (populate_existing = True )
397408 .order_by (GatewayModel .id ) # take locks in order
398- .with_for_update (key_share = True )
409+ .with_for_update (key_share = True , of = GatewayModel )
399410 )
400411 gateway_models = res .scalars ().all ()
401412 for gateway_model in gateway_models :
@@ -506,10 +517,13 @@ async def list_project_gateway_models(
506517 session : AsyncSession ,
507518 project : ProjectModel ,
508519 load_gateway_compute : bool = False ,
520+ load_backend_type : bool = False ,
509521) -> Sequence [GatewayModel ]:
510522 stmt = select (GatewayModel ).where (GatewayModel .project_id == project .id )
511523 if load_gateway_compute :
512524 stmt = stmt .options (joinedload (GatewayModel .gateway_compute ))
525+ if load_backend_type :
526+ stmt = stmt .options (joinedload (GatewayModel .backend ).load_only (BackendModel .type ))
513527 res = await session .execute (stmt )
514528 return res .scalars ().all ()
515529
@@ -519,13 +533,16 @@ async def get_project_gateway_model_by_name(
519533 project : ProjectModel ,
520534 name : str ,
521535 load_gateway_compute : bool = False ,
536+ load_backend_type : bool = False ,
522537) -> Optional [GatewayModel ]:
523538 stmt = select (GatewayModel ).where (
524539 GatewayModel .project_id == project .id ,
525540 GatewayModel .name == name ,
526541 )
527542 if load_gateway_compute :
528543 stmt = stmt .options (joinedload (GatewayModel .gateway_compute ))
544+ if load_backend_type :
545+ stmt = stmt .options (joinedload (GatewayModel .backend ).load_only (BackendModel .type ))
529546 res = await session .execute (stmt )
530547 return res .scalar ()
531548
@@ -558,6 +575,7 @@ async def get_project_gateway_model_by_name_for_update(
558575 select (GatewayModel )
559576 .where (GatewayModel .id .in_ ([gateway_id ]), * filters )
560577 .options (joinedload (GatewayModel .gateway_compute ))
578+ .options (joinedload (GatewayModel .backend ).load_only (BackendModel .type ))
561579 .with_for_update (key_share = True , of = GatewayModel )
562580 )
563581 yield res .scalar_one_or_none ()
@@ -567,13 +585,16 @@ async def get_project_default_gateway_model(
567585 session : AsyncSession ,
568586 project : ProjectModel ,
569587 load_gateway_compute : bool = False ,
588+ load_backend_type : bool = False ,
570589) -> Optional [GatewayModel ]:
571590 stmt = select (GatewayModel ).where (
572591 GatewayModel .id == project .default_gateway_id ,
573592 GatewayModel .to_be_deleted == False ,
574593 )
575594 if load_gateway_compute :
576595 stmt = stmt .options (joinedload (GatewayModel .gateway_compute ))
596+ if load_backend_type :
597+ stmt = stmt .options (joinedload (GatewayModel .backend ).load_only (BackendModel .type ))
577598 res = await session .execute (stmt )
578599 return res .scalar_one_or_none ()
579600
0 commit comments