Skip to content

Commit bd79e47

Browse files
committed
Add TestGatewayWorkerDeleted
1 parent 94de210 commit bd79e47

2 files changed

Lines changed: 108 additions & 2 deletions

File tree

src/tests/_internal/server/background/pipeline_tasks/test_gateways.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from unittest.mock import MagicMock, Mock, patch
44

55
import pytest
6+
from sqlalchemy import select
67
from sqlalchemy.ext.asyncio import AsyncSession
78

89
from dstack._internal.core.errors import BackendError
@@ -11,13 +12,15 @@
1112
GatewayPipelineItem,
1213
GatewayWorker,
1314
)
15+
from dstack._internal.server.models import GatewayModel
1416
from dstack._internal.server.testing.common import (
1517
AsyncContextManager,
1618
ComputeMockSpec,
1719
create_backend,
1820
create_gateway,
1921
create_gateway_compute,
2022
create_project,
23+
create_user,
2124
list_events,
2225
)
2326

@@ -27,7 +30,7 @@ def worker() -> GatewayWorker:
2730
return GatewayWorker(queue=Mock(), heartbeater=Mock())
2831

2932

30-
def _gateway_to_pipeline_item(gateway_model) -> GatewayPipelineItem:
33+
def _gateway_to_pipeline_item(gateway_model: GatewayModel) -> GatewayPipelineItem:
3134
assert gateway_model.lock_token is not None
3235
assert gateway_model.lock_expires_at is not None
3336
return GatewayPipelineItem(
@@ -37,6 +40,7 @@ def _gateway_to_pipeline_item(gateway_model) -> GatewayPipelineItem:
3740
lock_expires_at=gateway_model.lock_expires_at,
3841
prev_lock_expired=False,
3942
status=gateway_model.status,
43+
to_be_deleted=gateway_model.to_be_deleted,
4044
)
4145

4246

@@ -182,3 +186,104 @@ async def test_marks_gateway_as_failed_if_fails_to_connect(
182186
events[0].message
183187
== "Gateway status changed PROVISIONING -> FAILED (Failed to connect to gateway)"
184188
)
189+
190+
191+
@pytest.mark.asyncio
192+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
193+
class TestGatewayWorkerDeleted:
194+
async def test_deletes_gateway_and_marks_compute_deleted(
195+
self, test_db, session: AsyncSession, worker: GatewayWorker
196+
):
197+
user = await create_user(session=session)
198+
project = await create_project(session=session)
199+
backend = await create_backend(session=session, project_id=project.id)
200+
gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
201+
gateway = await create_gateway(
202+
session=session,
203+
project_id=project.id,
204+
backend_id=backend.id,
205+
gateway_compute_id=gateway_compute.id,
206+
status=GatewayStatus.RUNNING,
207+
)
208+
gateway.lock_token = uuid.uuid4()
209+
gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
210+
gateway.to_be_deleted = True
211+
gateway.deleted_by_user_id = user.id
212+
await session.commit()
213+
214+
with (
215+
patch(
216+
"dstack._internal.server.services.backends.get_project_backend_by_type_or_error"
217+
) as get_backend_mock,
218+
patch(
219+
"dstack._internal.server.background.pipeline_tasks.gateways.gateway_connections_pool.remove"
220+
) as remove_connection_mock,
221+
):
222+
backend_mock = Mock()
223+
backend_mock.compute.return_value = Mock(spec=ComputeMockSpec)
224+
get_backend_mock.return_value = backend_mock
225+
226+
await worker.process(_gateway_to_pipeline_item(gateway))
227+
228+
get_backend_mock.assert_called_once()
229+
backend_mock.compute.return_value.terminate_gateway.assert_called_once()
230+
remove_connection_mock.assert_called_once_with(gateway_compute.ip_address)
231+
232+
await session.refresh(gateway_compute)
233+
res = await session.execute(select(GatewayModel.id).where(GatewayModel.id == gateway.id))
234+
assert res.scalar_one_or_none() is None
235+
assert gateway_compute.active is False
236+
assert gateway_compute.deleted is True
237+
events = await list_events(session)
238+
assert len(events) == 1
239+
assert events[0].message == "Gateway deleted"
240+
assert events[0].actor_user_id == user.id
241+
242+
async def test_keeps_gateway_if_terminate_fails(
243+
self, test_db, session: AsyncSession, worker: GatewayWorker
244+
):
245+
project = await create_project(session=session)
246+
backend = await create_backend(session=session, project_id=project.id)
247+
gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
248+
gateway = await create_gateway(
249+
session=session,
250+
project_id=project.id,
251+
backend_id=backend.id,
252+
gateway_compute_id=gateway_compute.id,
253+
status=GatewayStatus.RUNNING,
254+
)
255+
gateway.lock_token = uuid.uuid4()
256+
gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
257+
gateway.to_be_deleted = True
258+
original_last_processed_at = gateway.last_processed_at
259+
await session.commit()
260+
261+
with (
262+
patch(
263+
"dstack._internal.server.services.backends.get_project_backend_by_type_or_error"
264+
) as get_backend_mock,
265+
patch(
266+
"dstack._internal.server.background.pipeline_tasks.gateways.gateway_connections_pool.remove"
267+
) as remove_connection_mock,
268+
):
269+
backend_mock = Mock()
270+
backend_mock.compute.return_value = Mock(spec=ComputeMockSpec)
271+
backend_mock.compute.return_value.terminate_gateway.side_effect = BackendError(
272+
"Terminate failed"
273+
)
274+
get_backend_mock.return_value = backend_mock
275+
276+
await worker.process(_gateway_to_pipeline_item(gateway))
277+
278+
get_backend_mock.assert_called_once()
279+
backend_mock.compute.return_value.terminate_gateway.assert_called_once()
280+
remove_connection_mock.assert_not_called()
281+
282+
await session.refresh(gateway)
283+
await session.refresh(gateway_compute)
284+
assert gateway.to_be_deleted is True
285+
assert gateway.last_processed_at > original_last_processed_at
286+
assert gateway_compute.active is True
287+
assert gateway_compute.deleted is False
288+
events = await list_events(session)
289+
assert len(events) == 0

src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from dstack._internal.server.background.pipeline_tasks.base import PipelineItem
99
from dstack._internal.server.background.pipeline_tasks.placement_groups import PlacementGroupWorker
10+
from dstack._internal.server.models import PlacementGroupModel
1011
from dstack._internal.server.testing.common import (
1112
ComputeMockSpec,
1213
create_fleet,
@@ -20,7 +21,7 @@ def worker() -> PlacementGroupWorker:
2021
return PlacementGroupWorker(queue=Mock(), heartbeater=Mock())
2122

2223

23-
def _placement_group_to_pipeline_item(placement_group) -> PipelineItem:
24+
def _placement_group_to_pipeline_item(placement_group: PlacementGroupModel) -> PipelineItem:
2425
assert placement_group.lock_token is not None
2526
assert placement_group.lock_expires_at is not None
2627
return PipelineItem(

0 commit comments

Comments
 (0)