|
13 | 13 | from sqlalchemy.ext.asyncio import AsyncSession |
14 | 14 |
|
15 | 15 | from dstack._internal import settings |
| 16 | +from dstack._internal.core.errors import GatewayError |
16 | 17 | from dstack._internal.core.models.backends.base import BackendType |
17 | 18 | from dstack._internal.core.models.common import ApplyAction |
18 | 19 | from dstack._internal.core.models.configurations import ( |
@@ -2299,13 +2300,13 @@ async def test_returns_400_if_runs_active( |
2299 | 2300 | @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) |
2300 | 2301 | class TestSubmitService: |
2301 | 2302 | @pytest.fixture(autouse=True) |
2302 | | - def mock_gateway_connections(self) -> Generator[None, None, None]: |
| 2303 | + def mock_gateway_connection(self) -> Generator[AsyncMock, None, None]: |
2303 | 2304 | with patch( |
2304 | 2305 | "dstack._internal.server.services.gateways.gateway_connections_pool.get_or_add" |
2305 | 2306 | ) as get_conn_mock: |
2306 | 2307 | get_conn_mock.return_value.client = Mock() |
2307 | 2308 | get_conn_mock.return_value.client.return_value = AsyncMock() |
2308 | | - yield |
| 2309 | + yield get_conn_mock |
2309 | 2310 |
|
2310 | 2311 | @pytest.mark.asyncio |
2311 | 2312 | @pytest.mark.parametrize( |
@@ -2481,3 +2482,54 @@ async def test_return_error_if_specified_gateway_is_true_and_no_gateway_exists( |
2481 | 2482 | } |
2482 | 2483 | ] |
2483 | 2484 | } |
| 2485 | + |
| 2486 | + @pytest.mark.asyncio |
| 2487 | + async def test_unregister_dangling_service( |
| 2488 | + self, |
| 2489 | + test_db, |
| 2490 | + session: AsyncSession, |
| 2491 | + client: AsyncClient, |
| 2492 | + mock_gateway_connection: AsyncMock, |
| 2493 | + ) -> None: |
| 2494 | + user = await create_user(session=session, global_role=GlobalRole.USER) |
| 2495 | + project = await create_project(session=session, owner=user, name="test-project") |
| 2496 | + await add_project_member( |
| 2497 | + session=session, project=project, user=user, project_role=ProjectRole.USER |
| 2498 | + ) |
| 2499 | + repo = await create_repo(session=session, project_id=project.id) |
| 2500 | + backend = await create_backend(session=session, project_id=project.id) |
| 2501 | + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) |
| 2502 | + gateway = await create_gateway( |
| 2503 | + session=session, |
| 2504 | + project_id=project.id, |
| 2505 | + backend_id=backend.id, |
| 2506 | + gateway_compute_id=gateway_compute.id, |
| 2507 | + status=GatewayStatus.RUNNING, |
| 2508 | + wildcard_domain="example.com", |
| 2509 | + ) |
| 2510 | + project.default_gateway_id = gateway.id |
| 2511 | + await session.commit() |
| 2512 | + |
| 2513 | + client_mock = ( |
| 2514 | + mock_gateway_connection.return_value.client.return_value.__aenter__.return_value |
| 2515 | + ) |
| 2516 | + client_mock.register_service.side_effect = [ |
| 2517 | + GatewayError("Service test-project/test-service is already registered"), |
| 2518 | + None, # Second call succeeds |
| 2519 | + ] |
| 2520 | + |
| 2521 | + response = await client.post( |
| 2522 | + "/api/project/test-project/runs/submit", |
| 2523 | + headers=get_auth_headers(user.token), |
| 2524 | + json={"run_spec": get_service_run_spec(repo_id=repo.name, run_name="test-service")}, |
| 2525 | + ) |
| 2526 | + |
| 2527 | + assert response.status_code == 200 |
| 2528 | + assert response.json()["service"]["url"] == "https://test-service.example.com" |
| 2529 | + # Verify that unregister_service was called to clean up the dangling service |
| 2530 | + client_mock.unregister_service.assert_called_once_with( |
| 2531 | + project=project.name, |
| 2532 | + run_name="test-service", |
| 2533 | + ) |
| 2534 | + # Verify that register_service was called twice (first failed, then succeeded) |
| 2535 | + assert client_mock.register_service.call_count == 2 |
0 commit comments