Skip to content

Commit 94de210

Browse files
committed
Make gateway deletion async
1 parent 8c5f34e commit 94de210

4 files changed

Lines changed: 217 additions & 48 deletions

File tree

src/dstack/_internal/server/background/pipeline_tasks/gateways.py

Lines changed: 147 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
from datetime import timedelta
55
from typing import Optional, Sequence
66

7-
from sqlalchemy import or_, select, update
7+
from sqlalchemy import delete, or_, select, update
88
from sqlalchemy.orm import joinedload, load_only
99

10+
from dstack._internal.core.backends.base.compute import ComputeWithGatewaySupport
1011
from dstack._internal.core.errors import BackendError, BackendNotAvailable
1112
from dstack._internal.core.models.gateways import GatewayStatus
1213
from dstack._internal.server.background.pipeline_tasks.base import (
@@ -24,14 +25,16 @@
2425
GatewayComputeModel,
2526
GatewayModel,
2627
ProjectModel,
28+
UserModel,
2729
)
2830
from dstack._internal.server.services import backends as backends_services
31+
from dstack._internal.server.services import events
2932
from dstack._internal.server.services import gateways as gateways_services
3033
from dstack._internal.server.services.gateways import emit_gateway_status_change_event
3134
from dstack._internal.server.services.gateways.pool import gateway_connections_pool
3235
from dstack._internal.server.services.locking import get_locker
3336
from dstack._internal.server.services.logging import fmt
34-
from dstack._internal.utils.common import get_current_datetime
37+
from dstack._internal.utils.common import get_current_datetime, run_async
3538
from dstack._internal.utils.logging import get_logger
3639

3740
logger = get_logger(__name__)
@@ -40,6 +43,7 @@
4043
@dataclass
class GatewayPipelineItem(PipelineItem):
    """Pipeline work item carrying the gateway fields needed to dispatch processing."""

    # Gateway status as read from the gateway row when the item was fetched.
    status: GatewayStatus
    # Deletion flag as read from the gateway row when the item was fetched.
    # The worker checks it before `status` when choosing how to process the item.
    to_be_deleted: bool
4347

4448

4549
class GatewayPipeline(Pipeline[GatewayPipelineItem]):
@@ -121,8 +125,11 @@ async def fetch(self, limit: int) -> list[GatewayPipelineItem]:
121125
res = await session.execute(
122126
select(GatewayModel)
123127
.where(
124-
GatewayModel.status.in_(
125-
[GatewayStatus.SUBMITTED, GatewayStatus.PROVISIONING]
128+
or_(
129+
GatewayModel.status.in_(
130+
[GatewayStatus.SUBMITTED, GatewayStatus.PROVISIONING]
131+
),
132+
GatewayModel.to_be_deleted == True,
126133
),
127134
or_(
128135
GatewayModel.last_processed_at <= now - self._min_processing_interval,
@@ -143,9 +150,10 @@ async def fetch(self, limit: int) -> list[GatewayPipelineItem]:
143150
.options(
144151
load_only(
145152
GatewayModel.id,
146-
GatewayModel.status,
147153
GatewayModel.lock_token,
148154
GatewayModel.lock_expires_at,
155+
GatewayModel.status,
156+
GatewayModel.to_be_deleted,
149157
)
150158
)
151159
)
@@ -166,6 +174,7 @@ async def fetch(self, limit: int) -> list[GatewayPipelineItem]:
166174
lock_token=lock_token,
167175
prev_lock_expired=prev_lock_expired,
168176
status=gateway_model.status,
177+
to_be_deleted=gateway_model.to_be_deleted,
169178
)
170179
)
171180
await session.commit()
@@ -184,7 +193,9 @@ def __init__(
184193
)
185194

186195
async def process(self, item: GatewayPipelineItem):
187-
if item.status == GatewayStatus.SUBMITTED:
196+
if item.to_be_deleted:
197+
await _process_to_be_deleted_item(item)
198+
elif item.status == GatewayStatus.SUBMITTED:
188199
await _process_submitted_item(item)
189200
elif item.status == GatewayStatus.PROVISIONING:
190201
await _process_provisioning_item(item)
@@ -235,6 +246,7 @@ async def _process_submitted_item(item: GatewayPipelineItem):
235246
item.__tablename__,
236247
item.id,
237248
)
249+
# TODO: Clean up gateway_compute_model.
238250
return
239251
emit_gateway_status_change_event(
240252
session=session,
@@ -407,3 +419,132 @@ async def _process_provisioning_gateway(gateway_model: GatewayModel) -> _Provisi
407419
return _ProvisioningResult(
408420
gateway_update_map={"status": GatewayStatus.RUNNING},
409421
)
422+
423+
424+
async def _process_to_be_deleted_item(item: GatewayPipelineItem):
    """
    Process a pipeline item whose gateway is marked as `to_be_deleted`.

    The gateway is loaded only if it still holds the item's `lock_token`. The actual
    deletion work (`_process_to_be_deleted_gateway`) runs outside of any DB session.
    Afterwards, in a new session:

    * if the result requests deletion, the gateway row is deleted (guarded by the lock token)
      and a "Gateway deleted" event is emitted on behalf of the deleting user,
      or the system actor if no user is recorded;
    * otherwise the gateway row only receives `get_processed_update_map()` values
      (presumably marking the item as processed so that it is retried later – confirm
      against the helper's definition).

    Finally, any requested gateway compute updates are applied.

    Args:
        item: The fetched pipeline item identifying the gateway and the lock token held.
    """
    async with get_session_ctx() as session:
        res = await session.execute(
            select(GatewayModel)
            .where(
                GatewayModel.id == item.id,
                GatewayModel.lock_token == item.lock_token,
            )
            .options(joinedload(GatewayModel.project).joinedload(ProjectModel.backends))
            .options(joinedload(GatewayModel.gateway_compute))
            .options(
                joinedload(GatewayModel.deleted_by_user).load_only(UserModel.id, UserModel.name)
            )
        )
        # `unique()` is needed because of the joined eager load of the backends collection.
        gateway_model = res.unique().scalar_one_or_none()
        if gateway_model is None:
            logger.warning(
                "Failed to process %s item %s: lock_token mismatch."
                " The item is expected to be processed and updated on another fetch iteration.",
                item.__tablename__,
                item.id,
            )
            return

    # Run the (potentially slow) backend call without holding a DB session.
    result = await _process_to_be_deleted_gateway(gateway_model)
    async with get_session_ctx() as session:
        if result.delete_gateway:
            res = await session.execute(
                delete(GatewayModel)
                .where(
                    GatewayModel.id == gateway_model.id,
                    GatewayModel.lock_token == gateway_model.lock_token,
                )
                .returning(GatewayModel.id)
            )
            deleted_ids = list(res.scalars().all())
            if len(deleted_ids) == 0:
                logger.warning(
                    "Failed to delete %s item %s after processing: lock_token changed."
                    " The item is expected to be processed and deleted on another fetch iteration.",
                    item.__tablename__,
                    item.id,
                )
                return
            # Attribute the event to the user who requested deletion, if known.
            actor = events.SystemActor()
            if gateway_model.deleted_by_user is not None:
                actor = events.UserActor.from_user(gateway_model.deleted_by_user)
            events.emit(
                session,
                "Gateway deleted",
                actor=actor,
                targets=[events.Target.from_model(gateway_model)],
            )
        else:
            res = await session.execute(
                update(GatewayModel)
                .where(
                    GatewayModel.id == gateway_model.id,
                    GatewayModel.lock_token == gateway_model.lock_token,
                )
                .values(**get_processed_update_map())
                .returning(GatewayModel.id)
            )
            updated_ids = list(res.scalars().all())
            if len(updated_ids) == 0:
                logger.warning(
                    "Failed to update %s item %s after processing: lock_token changed."
                    " The item is expected to be processed and updated on another fetch iteration.",
                    item.__tablename__,
                    item.id,
                )
                return

        if result.gateway_compute_update_map:
            res = await session.execute(
                update(GatewayComputeModel)
                .where(GatewayComputeModel.id == gateway_model.gateway_compute_id)
                .values(**result.gateway_compute_update_map)
                .returning(GatewayComputeModel.id)
            )
            updated_ids = list(res.scalars().all())
            if len(updated_ids) == 0:
                # Log the compute model id first and the gateway id second,
                # matching the placeholders in the message.
                logger.error(
                    "Failed to update compute model %s for gateway %s."
                    " This is unexpected and may happen only if the compute model was manually deleted.",
                    gateway_model.gateway_compute_id,
                    item.id,
                )
513+
514+
515+
@dataclass
class _DeletedResult:
    """Outcome of processing a gateway that is marked for deletion."""

    # True when the gateway row should be deleted.
    # False when terminating the gateway compute failed.
    delete_gateway: bool
    # Column values to apply to the gateway's compute model; empty means "no update".
    gateway_compute_update_map: UpdateMap = field(default_factory=dict)
519+
520+
521+
async def _process_to_be_deleted_gateway(gateway_model: GatewayModel) -> _DeletedResult:
    """
    Terminate the gateway's compute (if any) and decide whether the gateway can be deleted.

    Args:
        gateway_model: The gateway to delete. Its project, backend, and
            gateway compute attributes are read here.

    Returns:
        A `_DeletedResult` with `delete_gateway=False` if terminating the compute raised,
        otherwise `delete_gateway=True`. When the gateway has a compute model, the result
        also carries an update map marking it inactive and deleted.
    """
    backend = await backends_services.get_project_backend_by_type_or_error(
        project=gateway_model.project, backend_type=gateway_model.backend.type
    )
    compute = backend.compute()
    assert isinstance(compute, ComputeWithGatewaySupport)
    configuration = gateways_services.get_gateway_compute_configuration(gateway_model)

    gateway_compute = gateway_model.gateway_compute
    if gateway_compute is None:
        # Nothing to terminate and no compute model to update.
        return _DeletedResult(delete_gateway=True)

    if configuration is not None:
        logger.info("Deleting gateway compute for %s...", gateway_model.name)
        try:
            await run_async(
                compute.terminate_gateway,
                gateway_compute.instance_id,
                configuration,
                gateway_compute.backend_data,
            )
        except Exception:
            logger.exception(
                "Error when deleting gateway compute for %s",
                gateway_model.name,
            )
            # Keep the gateway so that termination can be attempted again.
            return _DeletedResult(delete_gateway=False)
        logger.info("Deleted gateway compute for %s", gateway_model.name)

    await gateway_connections_pool.remove(gateway_compute.ip_address)
    return _DeletedResult(
        delete_gateway=True,
        gateway_compute_update_map={"active": False, "deleted": True},
    )
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""Add GatewayModel deleted columns
2+
3+
Revision ID: d911914ecf17
4+
Revises: 66c2fdda33c2
5+
Create Date: 2026-02-20 08:40:11.616560+00:00
6+
7+
"""
8+
9+
import sqlalchemy as sa
10+
import sqlalchemy_utils
11+
from alembic import op
12+
13+
# revision identifiers, used by Alembic.
14+
revision = "d911914ecf17"
15+
down_revision = "66c2fdda33c2"
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade() -> None:
    """Add the `to_be_deleted` and `deleted_by_user_id` columns to the `gateways` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table("gateways", schema=None) as batch_op:
        batch_op.add_column(
            # `sa.false()` renders a boolean literal appropriate for the target dialect.
            # A bare integer default (`DEFAULT 0`) is rejected by PostgreSQL for BOOLEAN columns.
            sa.Column("to_be_deleted", sa.Boolean(), server_default=sa.false(), nullable=False)
        )
        batch_op.add_column(
            sa.Column(
                "deleted_by_user_id",
                sqlalchemy_utils.types.uuid.UUIDType(binary=False),
                nullable=True,
            )
        )
        # The reference is set to NULL when the referenced user row is deleted.
        batch_op.create_foreign_key(
            batch_op.f("fk_gateways_deleted_by_user_id_users"),
            "users",
            ["deleted_by_user_id"],
            ["id"],
            ondelete="SET NULL",
        )

    # ### end Alembic commands ###
42+
43+
44+
def downgrade() -> None:
    """Drop the `deleted_by_user_id` foreign key and both columns added by `upgrade`."""
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table("gateways", schema=None) as batch_op:
        # The foreign key is dropped before the column it constrains.
        fk_name = batch_op.f("fk_gateways_deleted_by_user_id_users")
        batch_op.drop_constraint(fk_name, type_="foreignkey")
        for column_name in ("deleted_by_user_id", "to_be_deleted"):
            batch_op.drop_column(column_name)

    # ### end Alembic commands ###

src/dstack/_internal/server/models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,12 @@ class GatewayModel(PipelineModelMixin, BaseModel):
521521

522522
runs: Mapped[List["RunModel"]] = relationship(back_populates="gateway")
523523

524+
to_be_deleted: Mapped[bool] = mapped_column(Boolean, server_default=false())
525+
deleted_by_user_id: Mapped[Optional[uuid.UUID]] = mapped_column(
526+
ForeignKey("users.id", ondelete="SET NULL")
527+
)
528+
deleted_by_user: Mapped[Optional[UserModel]] = relationship(foreign_keys=[deleted_by_user_id])
529+
524530
__table_args__ = (UniqueConstraint("project_id", "name", name="uq_gateways_project_id_name"),)
525531

526532

src/dstack/_internal/server/services/gateways/__init__.py

Lines changed: 11 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import httpx
1111
from sqlalchemy import func, select, update
1212
from sqlalchemy.ext.asyncio import AsyncSession
13-
from sqlalchemy.orm import selectinload
1413

1514
import dstack._internal.utils.random_names as random_names
1615
from dstack._internal.core.backends.base.compute import (
@@ -50,7 +49,6 @@
5049
from dstack._internal.server.services import events
5150
from dstack._internal.server.services.backends import (
5251
check_backend_type_available,
53-
get_project_backend_by_type_or_error,
5452
get_project_backend_with_model_by_type_or_error,
5553
)
5654
from dstack._internal.server.services.gateways.connection import GatewayConnection
@@ -295,53 +293,24 @@ async def delete_gateways(
295293
res = await session.execute(
296294
select(GatewayModel)
297295
.where(
296+
GatewayModel.id.in_(gateways_ids),
298297
GatewayModel.project_id == project.id,
299-
GatewayModel.name.in_(gateways_names),
298+
GatewayModel.lock_expires_at.is_(None),
300299
)
301-
.options(selectinload(GatewayModel.gateway_compute))
302300
.execution_options(populate_existing=True)
303301
.order_by(GatewayModel.id) # take locks in order
304-
.with_for_update(key_share=True)
302+
.with_for_update(key_share=True, nowait=True)
305303
)
306304
gateway_models = res.scalars().all()
307-
for gateway_model in gateway_models:
308-
backend = await get_project_backend_by_type_or_error(
309-
project=project, backend_type=gateway_model.backend.type
310-
)
311-
compute = backend.compute()
312-
assert isinstance(compute, ComputeWithGatewaySupport)
313-
gateway_compute_configuration = get_gateway_compute_configuration(gateway_model)
314-
if (
315-
gateway_model.gateway_compute is not None
316-
and gateway_compute_configuration is not None
317-
):
318-
logger.info("Deleting gateway compute for %s...", gateway_model.name)
319-
try:
320-
await run_async(
321-
compute.terminate_gateway,
322-
gateway_model.gateway_compute.instance_id,
323-
gateway_compute_configuration,
324-
gateway_model.gateway_compute.backend_data,
325-
)
326-
except Exception:
327-
logger.exception(
328-
"Error when deleting gateway compute for %s",
329-
gateway_model.name,
330-
)
331-
continue
332-
logger.info("Deleted gateway compute for %s", gateway_model.name)
333-
if gateway_model.gateway_compute is not None:
334-
await gateway_connections_pool.remove(gateway_model.gateway_compute.ip_address)
335-
gateway_model.gateway_compute.active = False
336-
gateway_model.gateway_compute.deleted = True
337-
session.add(gateway_model.gateway_compute)
338-
await session.delete(gateway_model)
339-
events.emit(
340-
session,
341-
"Gateway deleted",
342-
actor=events.UserActor.from_user(user),
343-
targets=[events.Target.from_model(gateway_model)],
305+
if len(gateway_models) != len(gateways_ids):
306+
# TODO: Make the delete endpoint fully async without lock – put the request in queue and process in background.
307+
raise ServerClientError(
308+
"Failed to delete gateways: gateways are being processed currently. Try again later."
344309
)
310+
for gateway_model in gateway_models:
311+
if not gateway_model.to_be_deleted:
312+
gateway_model.to_be_deleted = True
313+
gateway_model.deleted_by_user_id = user.id
345314
await session.commit()
346315

347316

0 commit comments

Comments
 (0)