|
8 | 8 | from freezegun import freeze_time |
9 | 9 | from sqlalchemy.ext.asyncio import AsyncSession |
10 | 10 |
|
11 | | -from dstack._internal.core.errors import BackendError |
| 11 | +from dstack._internal.core.errors import BackendError, ProvisioningError |
12 | 12 | from dstack._internal.core.models.backends.base import BackendType |
13 | 13 | from dstack._internal.core.models.instances import ( |
14 | 14 | Gpu, |
|
35 | 35 | create_repo, |
36 | 36 | create_run, |
37 | 37 | create_user, |
| 38 | + get_instance_offer_with_availability, |
| 39 | + get_job_provisioning_data, |
38 | 40 | get_remote_connection_info, |
39 | 41 | ) |
40 | 42 | from dstack._internal.utils.common import get_current_datetime |
@@ -557,6 +559,68 @@ async def test_creates_instance( |
557 | 559 | assert instance.total_blocks == expected_blocks |
558 | 560 | assert instance.busy_blocks == 0 |
559 | 561 |
|
| 562 | + @pytest.mark.parametrize("err", [RuntimeError("Unexpected"), ProvisioningError("Expected")]) |
| 563 | + async def test_tries_second_offer_if_first_fails(self, session: AsyncSession, err: Exception): |
| 564 | + project = await create_project(session=session) |
| 565 | + instance = await create_instance( |
| 566 | + session=session, project=project, status=InstanceStatus.PENDING |
| 567 | + ) |
| 568 | + aws_mock = Mock() |
| 569 | + aws_mock.TYPE = BackendType.AWS |
| 570 | + offer = get_instance_offer_with_availability(backend=BackendType.AWS, price=1.0) |
| 571 | + aws_mock.compute.return_value = Mock(spec=ComputeMockSpec) |
| 572 | + aws_mock.compute.return_value.get_offers_cached.return_value = [offer] |
| 573 | + aws_mock.compute.return_value.create_instance.side_effect = err |
| 574 | + gcp_mock = Mock() |
| 575 | + gcp_mock.TYPE = BackendType.GCP |
| 576 | + offer = get_instance_offer_with_availability(backend=BackendType.GCP, price=2.0) |
| 577 | + gcp_mock.compute.return_value = Mock(spec=ComputeMockSpec) |
| 578 | + gcp_mock.compute.return_value.get_offers_cached.return_value = [offer] |
| 579 | + gcp_mock.compute.return_value.create_instance.return_value = get_job_provisioning_data( |
| 580 | + backend=offer.backend, region=offer.region, price=offer.price |
| 581 | + ) |
| 582 | + with patch("dstack._internal.server.services.backends.get_project_backends") as m: |
| 583 | + m.return_value = [aws_mock, gcp_mock] |
| 584 | + await process_instances() |
| 585 | + |
| 586 | + await session.refresh(instance) |
| 587 | + assert instance.status == InstanceStatus.PROVISIONING |
| 588 | + aws_mock.compute.return_value.create_instance.assert_called_once() |
| 589 | + assert instance.backend == BackendType.GCP |
| 590 | + |
| 591 | + @pytest.mark.parametrize("err", [RuntimeError("Unexpected"), ProvisioningError("Expected")]) |
| 592 | + async def test_fails_if_all_offers_fail(self, session: AsyncSession, err: Exception): |
| 593 | + project = await create_project(session=session) |
| 594 | + instance = await create_instance( |
| 595 | + session=session, project=project, status=InstanceStatus.PENDING |
| 596 | + ) |
| 597 | + aws_mock = Mock() |
| 598 | + aws_mock.TYPE = BackendType.AWS |
| 599 | + offer = get_instance_offer_with_availability(backend=BackendType.AWS, price=1.0) |
| 600 | + aws_mock.compute.return_value = Mock(spec=ComputeMockSpec) |
| 601 | + aws_mock.compute.return_value.get_offers_cached.return_value = [offer] |
| 602 | + aws_mock.compute.return_value.create_instance.side_effect = err |
| 603 | + with patch("dstack._internal.server.services.backends.get_project_backends") as m: |
| 604 | + m.return_value = [aws_mock] |
| 605 | + await process_instances() |
| 606 | + |
| 607 | + await session.refresh(instance) |
| 608 | + assert instance.status == InstanceStatus.TERMINATED |
| 609 | + assert instance.termination_reason == "All offers failed" |
| 610 | + |
| 611 | + async def test_fails_if_no_offers(self, session: AsyncSession): |
| 612 | + project = await create_project(session=session) |
| 613 | + instance = await create_instance( |
| 614 | + session=session, project=project, status=InstanceStatus.PENDING |
| 615 | + ) |
| 616 | + with patch("dstack._internal.server.services.backends.get_project_backends") as m: |
| 617 | + m.return_value = [] |
| 618 | + await process_instances() |
| 619 | + |
| 620 | + await session.refresh(instance) |
| 621 | + assert instance.status == InstanceStatus.TERMINATED |
| 622 | + assert instance.termination_reason == "No offers found" |
| 623 | + |
560 | 624 |
|
561 | 625 | @pytest.mark.asyncio |
562 | 626 | @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) |
|
0 commit comments