33from unittest .mock import MagicMock , Mock , patch
44
55import pytest
6+ from sqlalchemy import select
67from sqlalchemy .ext .asyncio import AsyncSession
78
89from dstack ._internal .core .errors import BackendError
1112 GatewayPipelineItem ,
1213 GatewayWorker ,
1314)
15+ from dstack ._internal .server .models import GatewayModel
1416from dstack ._internal .server .testing .common import (
1517 AsyncContextManager ,
1618 ComputeMockSpec ,
1719 create_backend ,
1820 create_gateway ,
1921 create_gateway_compute ,
2022 create_project ,
23+ create_user ,
2124 list_events ,
2225)
2326
@@ -27,7 +30,7 @@ def worker() -> GatewayWorker:
2730 return GatewayWorker (queue = Mock (), heartbeater = Mock ())
2831
2932
30- def _gateway_to_pipeline_item (gateway_model ) -> GatewayPipelineItem :
33+ def _gateway_to_pipeline_item (gateway_model : GatewayModel ) -> GatewayPipelineItem :
3134 assert gateway_model .lock_token is not None
3235 assert gateway_model .lock_expires_at is not None
3336 return GatewayPipelineItem (
@@ -37,6 +40,7 @@ def _gateway_to_pipeline_item(gateway_model) -> GatewayPipelineItem:
3740 lock_expires_at = gateway_model .lock_expires_at ,
3841 prev_lock_expired = False ,
3942 status = gateway_model .status ,
43+ to_be_deleted = gateway_model .to_be_deleted ,
4044 )
4145
4246
@@ -182,3 +186,104 @@ async def test_marks_gateway_as_failed_if_fails_to_connect(
182186 events [0 ].message
183187 == "Gateway status changed PROVISIONING -> FAILED (Failed to connect to gateway)"
184188 )
189+
190+
191+ @pytest .mark .asyncio
192+ @pytest .mark .parametrize ("test_db" , ["sqlite" , "postgres" ], indirect = True )
193+ class TestGatewayWorkerDeleted :
194+ async def test_deletes_gateway_and_marks_compute_deleted (
195+ self , test_db , session : AsyncSession , worker : GatewayWorker
196+ ):
197+ user = await create_user (session = session )
198+ project = await create_project (session = session )
199+ backend = await create_backend (session = session , project_id = project .id )
200+ gateway_compute = await create_gateway_compute (session = session , backend_id = backend .id )
201+ gateway = await create_gateway (
202+ session = session ,
203+ project_id = project .id ,
204+ backend_id = backend .id ,
205+ gateway_compute_id = gateway_compute .id ,
206+ status = GatewayStatus .RUNNING ,
207+ )
208+ gateway .lock_token = uuid .uuid4 ()
209+ gateway .lock_expires_at = datetime (2025 , 1 , 2 , 3 , 4 , tzinfo = timezone .utc )
210+ gateway .to_be_deleted = True
211+ gateway .deleted_by_user_id = user .id
212+ await session .commit ()
213+
214+ with (
215+ patch (
216+ "dstack._internal.server.services.backends.get_project_backend_by_type_or_error"
217+ ) as get_backend_mock ,
218+ patch (
219+ "dstack._internal.server.background.pipeline_tasks.gateways.gateway_connections_pool.remove"
220+ ) as remove_connection_mock ,
221+ ):
222+ backend_mock = Mock ()
223+ backend_mock .compute .return_value = Mock (spec = ComputeMockSpec )
224+ get_backend_mock .return_value = backend_mock
225+
226+ await worker .process (_gateway_to_pipeline_item (gateway ))
227+
228+ get_backend_mock .assert_called_once ()
229+ backend_mock .compute .return_value .terminate_gateway .assert_called_once ()
230+ remove_connection_mock .assert_called_once_with (gateway_compute .ip_address )
231+
232+ await session .refresh (gateway_compute )
233+ res = await session .execute (select (GatewayModel .id ).where (GatewayModel .id == gateway .id ))
234+ assert res .scalar_one_or_none () is None
235+ assert gateway_compute .active is False
236+ assert gateway_compute .deleted is True
237+ events = await list_events (session )
238+ assert len (events ) == 1
239+ assert events [0 ].message == "Gateway deleted"
240+ assert events [0 ].actor_user_id == user .id
241+
242+ async def test_keeps_gateway_if_terminate_fails (
243+ self , test_db , session : AsyncSession , worker : GatewayWorker
244+ ):
245+ project = await create_project (session = session )
246+ backend = await create_backend (session = session , project_id = project .id )
247+ gateway_compute = await create_gateway_compute (session = session , backend_id = backend .id )
248+ gateway = await create_gateway (
249+ session = session ,
250+ project_id = project .id ,
251+ backend_id = backend .id ,
252+ gateway_compute_id = gateway_compute .id ,
253+ status = GatewayStatus .RUNNING ,
254+ )
255+ gateway .lock_token = uuid .uuid4 ()
256+ gateway .lock_expires_at = datetime (2025 , 1 , 2 , 3 , 4 , tzinfo = timezone .utc )
257+ gateway .to_be_deleted = True
258+ original_last_processed_at = gateway .last_processed_at
259+ await session .commit ()
260+
261+ with (
262+ patch (
263+ "dstack._internal.server.services.backends.get_project_backend_by_type_or_error"
264+ ) as get_backend_mock ,
265+ patch (
266+ "dstack._internal.server.background.pipeline_tasks.gateways.gateway_connections_pool.remove"
267+ ) as remove_connection_mock ,
268+ ):
269+ backend_mock = Mock ()
270+ backend_mock .compute .return_value = Mock (spec = ComputeMockSpec )
271+ backend_mock .compute .return_value .terminate_gateway .side_effect = BackendError (
272+ "Terminate failed"
273+ )
274+ get_backend_mock .return_value = backend_mock
275+
276+ await worker .process (_gateway_to_pipeline_item (gateway ))
277+
278+ get_backend_mock .assert_called_once ()
279+ backend_mock .compute .return_value .terminate_gateway .assert_called_once ()
280+ remove_connection_mock .assert_not_called ()
281+
282+ await session .refresh (gateway )
283+ await session .refresh (gateway_compute )
284+ assert gateway .to_be_deleted is True
285+ assert gateway .last_processed_at > original_last_processed_at
286+ assert gateway_compute .active is True
287+ assert gateway_compute .deleted is False
288+ events = await list_events (session )
289+ assert len (events ) == 0
0 commit comments