|
1 | | -from unittest.mock import patch |
| 1 | +from unittest.mock import Mock, patch |
2 | 2 |
|
3 | 3 | import pytest |
4 | 4 | from httpx import AsyncClient |
5 | 5 | from sqlalchemy.ext.asyncio import AsyncSession |
6 | 6 |
|
| 7 | +from dstack._internal.core.errors import DstackError |
7 | 8 | from dstack._internal.core.models.backends.base import BackendType |
8 | 9 | from dstack._internal.core.models.users import GlobalRole, ProjectRole |
9 | 10 | from dstack._internal.server.services.projects import add_project_member |
10 | 11 | from dstack._internal.server.testing.common import ( |
| 12 | + ComputeMockSpec, |
11 | 13 | clear_events, |
12 | 14 | create_backend, |
13 | 15 | create_gateway, |
|
18 | 20 | list_events, |
19 | 21 | ) |
20 | 22 | from dstack._internal.server.testing.matchers import SomeUUID4Str |
| 23 | +from dstack._internal.settings import FeatureFlags |
| 24 | + |
| 25 | + |
| 26 | +@pytest.fixture |
| 27 | +def patch_pipeline_processing_flag(monkeypatch: pytest.MonkeyPatch): |
| 28 | + def _apply(enabled: bool): |
| 29 | + monkeypatch.setattr(FeatureFlags, "PIPELINE_PROCESSING_ENABLED", enabled) |
| 30 | + |
| 31 | + return _apply |
21 | 32 |
|
22 | 33 |
|
23 | 34 | class TestListAndGetGateways: |
@@ -453,6 +464,12 @@ async def test_only_admin_can_delete( |
453 | 464 | ) |
454 | 465 | assert response.status_code == 403 |
455 | 466 |
|
| 467 | + |
| 468 | +class TestDeleteGatewayPipelineEnabled: |
| 469 | + @pytest.fixture(autouse=True) |
| 470 | + def _pipeline_processing_enabled(self, patch_pipeline_processing_flag): |
| 471 | + patch_pipeline_processing_flag(True) |
| 472 | + |
456 | 473 | @pytest.mark.asyncio |
457 | 474 | @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) |
458 | 475 | async def test_marks_gateways_to_be_deleted( |
@@ -519,6 +536,110 @@ async def test_marks_gateways_to_be_deleted( |
519 | 536 | assert all(e.actor_user_id == user.id for e in events) |
520 | 537 |
|
521 | 538 |
|
| 539 | +class TestDeleteGatewayPipelineDisabled: |
| 540 | + @pytest.fixture(autouse=True) |
| 541 | + def _pipeline_processing_disabled(self, patch_pipeline_processing_flag): |
| 542 | + patch_pipeline_processing_flag(False) |
| 543 | + |
| 544 | + @pytest.mark.asyncio |
| 545 | + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) |
| 546 | + async def test_deletes_gateways_synchronously( |
| 547 | + self, test_db, session: AsyncSession, client: AsyncClient |
| 548 | + ): |
| 549 | + user = await create_user(session, global_role=GlobalRole.USER) |
| 550 | + project = await create_project(session) |
| 551 | + await add_project_member( |
| 552 | + session=session, project=project, user=user, project_role=ProjectRole.ADMIN |
| 553 | + ) |
| 554 | + backend_aws = await create_backend(session, project.id) |
| 555 | + backend_gcp = await create_backend(session, project.id, backend_type=BackendType.GCP) |
| 556 | + gateway_compute_aws = await create_gateway_compute( |
| 557 | + session=session, |
| 558 | + backend_id=backend_aws.id, |
| 559 | + ) |
| 560 | + gateway_aws = await create_gateway( |
| 561 | + session=session, |
| 562 | + project_id=project.id, |
| 563 | + backend_id=backend_aws.id, |
| 564 | + name="gateway-aws", |
| 565 | + gateway_compute_id=gateway_compute_aws.id, |
| 566 | + ) |
| 567 | + gateway_compute_gcp = await create_gateway_compute( |
| 568 | + session=session, |
| 569 | + backend_id=backend_gcp.id, |
| 570 | + ) |
| 571 | + gateway_gcp = await create_gateway( |
| 572 | + session=session, |
| 573 | + project_id=project.id, |
| 574 | + backend_id=backend_gcp.id, |
| 575 | + name="gateway-gcp", |
| 576 | + gateway_compute_id=gateway_compute_gcp.id, |
| 577 | + ) |
| 578 | + with patch( |
| 579 | + "dstack._internal.server.services.gateways.get_project_backend_by_type_or_error" |
| 580 | + ) as m: |
| 581 | + aws = Mock() |
| 582 | + aws.compute.return_value = Mock(spec=ComputeMockSpec) |
| 583 | + aws.compute.return_value.terminate_gateway.return_value = None # success |
| 584 | + gcp = Mock() |
| 585 | + gcp.compute.return_value = Mock(spec=ComputeMockSpec) |
| 586 | + gcp.compute.return_value.terminate_gateway.side_effect = DstackError() # fail |
| 587 | + |
| 588 | + def get_backend(project, backend_type): |
| 589 | + return {BackendType.AWS: aws, BackendType.GCP: gcp}[backend_type] |
| 590 | + |
| 591 | + m.side_effect = get_backend |
| 592 | + |
| 593 | + response = await client.post( |
| 594 | + f"/api/project/{project.name}/gateways/delete", |
| 595 | + json={"names": [gateway_aws.name, gateway_gcp.name]}, |
| 596 | + headers=get_auth_headers(user.token), |
| 597 | + ) |
| 598 | + aws.compute.return_value.terminate_gateway.assert_called_once() |
| 599 | + gcp.compute.return_value.terminate_gateway.assert_called_once() |
| 600 | + assert response.status_code == 200 |
| 601 | + |
| 602 | + response = await client.post( |
| 603 | + f"/api/project/{project.name}/gateways/list", |
| 604 | + headers=get_auth_headers(user.token), |
| 605 | + ) |
| 606 | + assert response.status_code == 200 |
| 607 | + assert response.json() == [ |
| 608 | + { |
| 609 | + "id": str(gateway_gcp.id), |
| 610 | + "backend": backend_gcp.type.value, |
| 611 | + "created_at": response.json()[0]["created_at"], |
| 612 | + "default": False, |
| 613 | + "status": "submitted", |
| 614 | + "status_message": None, |
| 615 | + "instance_id": gateway_compute_gcp.instance_id, |
| 616 | + "ip_address": gateway_compute_gcp.ip_address, |
| 617 | + "hostname": gateway_compute_gcp.ip_address, |
| 618 | + "name": gateway_gcp.name, |
| 619 | + "region": gateway_gcp.region, |
| 620 | + "wildcard_domain": gateway_gcp.wildcard_domain, |
| 621 | + "configuration": { |
| 622 | + "type": "gateway", |
| 623 | + "name": gateway_gcp.name, |
| 624 | + "backend": backend_gcp.type.value, |
| 625 | + "region": gateway_gcp.region, |
| 626 | + "instance_type": None, |
| 627 | + "router": None, |
| 628 | + "domain": gateway_gcp.wildcard_domain, |
| 629 | + "default": False, |
| 630 | + "public_ip": True, |
| 631 | + "certificate": {"type": "lets-encrypt"}, |
| 632 | + "tags": None, |
| 633 | + }, |
| 634 | + } |
| 635 | + ] |
| 636 | + |
| 637 | + events = await list_events(session) |
| 638 | + assert len(events) == 1 |
| 639 | + assert events[0].message == "Gateway deleted" |
| 640 | + assert events[0].targets[0].entity_name == "gateway-aws" |
| 641 | + |
| 642 | + |
522 | 643 | class TestUpdateGateway: |
523 | 644 | @pytest.mark.asyncio |
524 | 645 | @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) |
|
0 commit comments