Skip to content

Commit d4a17a8

Browse files
committed
Restore sync gateways delete API
1 parent 5e9e322 commit d4a17a8

3 files changed

Lines changed: 236 additions & 2 deletions

File tree

src/dstack/_internal/server/services/gateways/__init__.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import httpx
1111
from sqlalchemy import func, select, update
1212
from sqlalchemy.ext.asyncio import AsyncSession
13+
from sqlalchemy.orm import selectinload
1314

1415
import dstack._internal.utils.random_names as random_names
1516
from dstack._internal.core.backends.base.compute import (
@@ -49,6 +50,7 @@
4950
from dstack._internal.server.services import events
5051
from dstack._internal.server.services.backends import (
5152
check_backend_type_available,
53+
get_project_backend_by_type_or_error,
5254
get_project_backend_with_model_by_type_or_error,
5355
)
5456
from dstack._internal.server.services.gateways.connection import GatewayConnection
@@ -61,6 +63,7 @@
6163
from dstack._internal.server.services.pipelines import PipelineHinterProtocol
6264
from dstack._internal.server.services.plugins import apply_plugin_policies
6365
from dstack._internal.server.utils.common import gather_map_async
66+
from dstack._internal.settings import FeatureFlags
6467
from dstack._internal.utils.common import get_current_datetime, run_async
6568
from dstack._internal.utils.crypto import generate_rsa_key_pair_bytes
6669
from dstack._internal.utils.logging import get_logger
@@ -278,6 +281,32 @@ async def delete_gateways(
278281
project: ProjectModel,
279282
gateways_names: List[str],
280283
user: UserModel,
284+
):
285+
# Keep both delete code paths while pipeline processing is behind a feature flag:
286+
# - pipeline path marks gateways for async deletion by GatewayPipeline
287+
# - sync path deletes gateway resources inline for non-pipeline processing
288+
# TODO: Drop sync path after pipeline processing is enabled by default.
289+
if FeatureFlags.PIPELINE_PROCESSING_ENABLED:
290+
await _delete_gateways_pipeline(
291+
session=session,
292+
project=project,
293+
gateways_names=gateways_names,
294+
user=user,
295+
)
296+
else:
297+
await _delete_gateways_sync(
298+
session=session,
299+
project=project,
300+
gateways_names=gateways_names,
301+
user=user,
302+
)
303+
304+
305+
async def _delete_gateways_pipeline(
306+
session: AsyncSession,
307+
project: ProjectModel,
308+
gateways_names: List[str],
309+
user: UserModel,
281310
):
282311
res = await session.execute(
283312
select(GatewayModel).where(
@@ -323,6 +352,79 @@ async def delete_gateways(
323352
await session.commit()
324353

325354

355+
async def _delete_gateways_sync(
356+
session: AsyncSession,
357+
project: ProjectModel,
358+
gateways_names: List[str],
359+
user: UserModel,
360+
):
361+
res = await session.execute(
362+
select(GatewayModel).where(
363+
GatewayModel.project_id == project.id,
364+
GatewayModel.name.in_(gateways_names),
365+
)
366+
)
367+
gateway_models = res.scalars().all()
368+
gateways_ids = sorted([g.id for g in gateway_models])
369+
await session.commit()
370+
logger.info("Deleting gateways: %s", [g.name for g in gateway_models])
371+
async with get_locker(get_db().dialect_name).lock_ctx(
372+
GatewayModel.__tablename__, gateways_ids
373+
):
374+
# Refetch after lock
375+
res = await session.execute(
376+
select(GatewayModel)
377+
.where(
378+
GatewayModel.project_id == project.id,
379+
GatewayModel.name.in_(gateways_names),
380+
)
381+
.options(selectinload(GatewayModel.gateway_compute))
382+
.execution_options(populate_existing=True)
383+
.order_by(GatewayModel.id) # take locks in order
384+
.with_for_update(key_share=True)
385+
)
386+
gateway_models = res.scalars().all()
387+
for gateway_model in gateway_models:
388+
backend = await get_project_backend_by_type_or_error(
389+
project=project, backend_type=gateway_model.backend.type
390+
)
391+
compute = backend.compute()
392+
assert isinstance(compute, ComputeWithGatewaySupport)
393+
gateway_compute_configuration = get_gateway_compute_configuration(gateway_model)
394+
if (
395+
gateway_model.gateway_compute is not None
396+
and gateway_compute_configuration is not None
397+
):
398+
logger.info("Deleting gateway compute for %s...", gateway_model.name)
399+
try:
400+
await run_async(
401+
compute.terminate_gateway,
402+
gateway_model.gateway_compute.instance_id,
403+
gateway_compute_configuration,
404+
gateway_model.gateway_compute.backend_data,
405+
)
406+
except Exception:
407+
logger.exception(
408+
"Error when deleting gateway compute for %s",
409+
gateway_model.name,
410+
)
411+
continue
412+
logger.info("Deleted gateway compute for %s", gateway_model.name)
413+
if gateway_model.gateway_compute is not None:
414+
await gateway_connections_pool.remove(gateway_model.gateway_compute.ip_address)
415+
gateway_model.gateway_compute.active = False
416+
gateway_model.gateway_compute.deleted = True
417+
session.add(gateway_model.gateway_compute)
418+
await session.delete(gateway_model)
419+
events.emit(
420+
session,
421+
"Gateway deleted",
422+
actor=events.UserActor.from_user(user),
423+
targets=[events.Target.from_model(gateway_model)],
424+
)
425+
await session.commit()
426+
427+
326428
async def set_gateway_wildcard_domain(
327429
session: AsyncSession,
328430
project: ProjectModel,

src/dstack/_internal/server/services/pipelines.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,20 @@ def hint_fetch(self, model_name: str) -> None:
1111
pass
1212

1313

14+
class _NoopPipelineHinter:
15+
def hint_fetch(self, model_name: str) -> None:
16+
pass
17+
18+
19+
_noop_pipeline_hinter = _NoopPipelineHinter()
20+
21+
1422
def get_pipeline_hinter(request: Request) -> PipelineHinterProtocol:
1523
"""
1624
Returns pipeline hinter that allows hinting replica's pipelines that there are new items for processing.
1725
This can reduce processing latency if the processing happens rarely.
1826
"""
19-
return request.app.state.pipeline_manager.hinter
27+
pipeline_manager = getattr(request.app.state, "pipeline_manager", None)
28+
if pipeline_manager is None:
29+
return _noop_pipeline_hinter
30+
return pipeline_manager.hinter

src/tests/_internal/server/routers/test_gateways.py

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1-
from unittest.mock import patch
1+
from unittest.mock import Mock, patch
22

33
import pytest
44
from httpx import AsyncClient
55
from sqlalchemy.ext.asyncio import AsyncSession
66

7+
from dstack._internal.core.errors import DstackError
78
from dstack._internal.core.models.backends.base import BackendType
89
from dstack._internal.core.models.users import GlobalRole, ProjectRole
910
from dstack._internal.server.services.projects import add_project_member
1011
from dstack._internal.server.testing.common import (
12+
ComputeMockSpec,
1113
clear_events,
1214
create_backend,
1315
create_gateway,
@@ -18,6 +20,15 @@
1820
list_events,
1921
)
2022
from dstack._internal.server.testing.matchers import SomeUUID4Str
23+
from dstack._internal.settings import FeatureFlags
24+
25+
26+
@pytest.fixture
27+
def patch_pipeline_processing_flag(monkeypatch: pytest.MonkeyPatch):
28+
def _apply(enabled: bool):
29+
monkeypatch.setattr(FeatureFlags, "PIPELINE_PROCESSING_ENABLED", enabled)
30+
31+
return _apply
2132

2233

2334
class TestListAndGetGateways:
@@ -453,6 +464,12 @@ async def test_only_admin_can_delete(
453464
)
454465
assert response.status_code == 403
455466

467+
468+
class TestDeleteGatewayPipelineEnabled:
469+
@pytest.fixture(autouse=True)
470+
def _pipeline_processing_enabled(self, patch_pipeline_processing_flag):
471+
patch_pipeline_processing_flag(True)
472+
456473
@pytest.mark.asyncio
457474
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
458475
async def test_marks_gateways_to_be_deleted(
@@ -519,6 +536,110 @@ async def test_marks_gateways_to_be_deleted(
519536
assert all(e.actor_user_id == user.id for e in events)
520537

521538

539+
class TestDeleteGatewayPipelineDisabled:
540+
@pytest.fixture(autouse=True)
541+
def _pipeline_processing_disabled(self, patch_pipeline_processing_flag):
542+
patch_pipeline_processing_flag(False)
543+
544+
@pytest.mark.asyncio
545+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
546+
async def test_deletes_gateways_synchronously(
547+
self, test_db, session: AsyncSession, client: AsyncClient
548+
):
549+
user = await create_user(session, global_role=GlobalRole.USER)
550+
project = await create_project(session)
551+
await add_project_member(
552+
session=session, project=project, user=user, project_role=ProjectRole.ADMIN
553+
)
554+
backend_aws = await create_backend(session, project.id)
555+
backend_gcp = await create_backend(session, project.id, backend_type=BackendType.GCP)
556+
gateway_compute_aws = await create_gateway_compute(
557+
session=session,
558+
backend_id=backend_aws.id,
559+
)
560+
gateway_aws = await create_gateway(
561+
session=session,
562+
project_id=project.id,
563+
backend_id=backend_aws.id,
564+
name="gateway-aws",
565+
gateway_compute_id=gateway_compute_aws.id,
566+
)
567+
gateway_compute_gcp = await create_gateway_compute(
568+
session=session,
569+
backend_id=backend_gcp.id,
570+
)
571+
gateway_gcp = await create_gateway(
572+
session=session,
573+
project_id=project.id,
574+
backend_id=backend_gcp.id,
575+
name="gateway-gcp",
576+
gateway_compute_id=gateway_compute_gcp.id,
577+
)
578+
with patch(
579+
"dstack._internal.server.services.gateways.get_project_backend_by_type_or_error"
580+
) as m:
581+
aws = Mock()
582+
aws.compute.return_value = Mock(spec=ComputeMockSpec)
583+
aws.compute.return_value.terminate_gateway.return_value = None # success
584+
gcp = Mock()
585+
gcp.compute.return_value = Mock(spec=ComputeMockSpec)
586+
gcp.compute.return_value.terminate_gateway.side_effect = DstackError() # fail
587+
588+
def get_backend(project, backend_type):
589+
return {BackendType.AWS: aws, BackendType.GCP: gcp}[backend_type]
590+
591+
m.side_effect = get_backend
592+
593+
response = await client.post(
594+
f"/api/project/{project.name}/gateways/delete",
595+
json={"names": [gateway_aws.name, gateway_gcp.name]},
596+
headers=get_auth_headers(user.token),
597+
)
598+
aws.compute.return_value.terminate_gateway.assert_called_once()
599+
gcp.compute.return_value.terminate_gateway.assert_called_once()
600+
assert response.status_code == 200
601+
602+
response = await client.post(
603+
f"/api/project/{project.name}/gateways/list",
604+
headers=get_auth_headers(user.token),
605+
)
606+
assert response.status_code == 200
607+
assert response.json() == [
608+
{
609+
"id": str(gateway_gcp.id),
610+
"backend": backend_gcp.type.value,
611+
"created_at": response.json()[0]["created_at"],
612+
"default": False,
613+
"status": "submitted",
614+
"status_message": None,
615+
"instance_id": gateway_compute_gcp.instance_id,
616+
"ip_address": gateway_compute_gcp.ip_address,
617+
"hostname": gateway_compute_gcp.ip_address,
618+
"name": gateway_gcp.name,
619+
"region": gateway_gcp.region,
620+
"wildcard_domain": gateway_gcp.wildcard_domain,
621+
"configuration": {
622+
"type": "gateway",
623+
"name": gateway_gcp.name,
624+
"backend": backend_gcp.type.value,
625+
"region": gateway_gcp.region,
626+
"instance_type": None,
627+
"router": None,
628+
"domain": gateway_gcp.wildcard_domain,
629+
"default": False,
630+
"public_ip": True,
631+
"certificate": {"type": "lets-encrypt"},
632+
"tags": None,
633+
},
634+
}
635+
]
636+
637+
events = await list_events(session)
638+
assert len(events) == 1
639+
assert events[0].message == "Gateway deleted"
640+
assert events[0].targets[0].entity_name == "gateway-aws"
641+
642+
522643
class TestUpdateGateway:
523644
@pytest.mark.asyncio
524645
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)

0 commit comments

Comments
 (0)