|
4 | 4 | from datetime import datetime, timedelta, timezone |
5 | 5 | from pathlib import Path |
6 | 6 | from typing import Optional |
7 | | -from unittest.mock import AsyncMock, MagicMock, Mock, patch |
| 7 | +from unittest.mock import ANY, AsyncMock, MagicMock, Mock, patch |
8 | 8 |
|
9 | 9 | import pytest |
10 | 10 | from freezegun import freeze_time |
|
21 | 21 | ProbeConfig, |
22 | 22 | ServiceConfiguration, |
23 | 23 | ) |
| 24 | +from dstack._internal.core.models.gateways import GatewayStatus |
24 | 25 | from dstack._internal.core.models.instances import InstanceStatus |
25 | 26 | from dstack._internal.core.models.profiles import StartupOrder, UtilizationPolicy |
26 | 27 | from dstack._internal.core.models.runs import ( |
|
52 | 53 | from dstack._internal.server.services.runner.ssh import SSHTunnel |
53 | 54 | from dstack._internal.server.services.volumes import volume_model_to_volume |
54 | 55 | from dstack._internal.server.testing.common import ( |
| 56 | + create_backend, |
| 57 | + create_export, |
| 58 | + create_fleet, |
| 59 | + create_gateway, |
| 60 | + create_gateway_compute, |
55 | 61 | create_instance, |
56 | 62 | create_job, |
57 | 63 | create_job_metrics_point, |
@@ -1635,6 +1641,169 @@ async def test_registers_service_replica_only_after_probes_pass( |
1635 | 1641 | assert not job.registered |
1636 | 1642 | assert not events |
1637 | 1643 |
|
async def test_registers_service_replica_in_gateway(
    self,
    test_db,
    session: AsyncSession,
    worker: JobRunningWorker,
    ssh_tunnel_mock: Mock,
    shim_client_mock: Mock,
    runner_client_mock: Mock,
    mock_gateway_connection: AsyncMock,
):
    """Check that a service job going PULLING -> RUNNING is registered in its gateway.

    The run, the gateway and the instance all live in the same project, so
    ``register_replica`` is expected to be called exactly once with
    ``instance_project_ssh_private_key=None``.
    """
    # Project-level entities plus a running gateway attached to the project.
    owner = await create_user(session=session)
    proj = await create_project(session=session, owner=owner)
    repository = await create_repo(session=session, project_id=proj.id)
    gw_backend = await create_backend(session=session, project_id=proj.id)
    gw_compute = await create_gateway_compute(session=session, backend_id=gw_backend.id)
    gw = await create_gateway(
        session=session,
        project_id=proj.id,
        backend_id=gw_backend.id,
        gateway_compute_id=gw_compute.id,
        status=GatewayStatus.RUNNING,
        name="test-gateway",
        wildcard_domain="example.com",
    )

    # Service run that references the gateway by name.
    service_conf = ServiceConfiguration(port=80, image="ubuntu", gateway="test-gateway")
    spec = get_run_spec(
        run_name="test",
        repo_id=repository.name,
        configuration=service_conf,
    )
    service_run = await create_run(
        session=session,
        project=proj,
        repo=repository,
        user=owner,
        run_spec=spec,
        gateway=gw,
    )

    # Busy instance in the same project, with a job that is still pulling.
    own_fleet = await create_fleet(session=session, project=proj)
    busy_instance = await create_instance(
        session=session,
        project=proj,
        status=InstanceStatus.BUSY,
        fleet=own_fleet,
    )
    provisioning_data = get_job_provisioning_data(dockerized=True)
    replica_job = await create_job(
        session=session,
        run=service_run,
        status=JobStatus.PULLING,
        job_provisioning_data=provisioning_data,
        instance=busy_instance,
        instance_assigned=True,
    )
    # The shim reports the task as running, which drives PULLING -> RUNNING.
    shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING

    await _process_job(session, worker, replica_job)

    await session.refresh(replica_job)
    assert job_status_is_running(replica_job) if False else replica_job.status == JobStatus.RUNNING
    assert replica_job.registered

    recorded_events = await list_events(session)
    expected_messages = {
        "Job status changed PULLING -> RUNNING",
        "Service replica registered to receive requests",
    }
    assert {item.message for item in recorded_events} == expected_messages

    # The gateway client is obtained via `async with connection.client() as client`.
    gateway_client = (
        mock_gateway_connection.return_value.client.return_value.__aenter__.return_value
    )
    gateway_client.register_replica.assert_called_once_with(
        run=ANY,
        job_spec=ANY,
        job_submission=ANY,
        instance_project_ssh_private_key=None,
        ssh_head_proxy=None,
        ssh_head_proxy_private_key=None,
    )
| 1720 | + |
async def test_registers_service_replica_in_gateway_when_running_on_imported_instance(
    self,
    test_db,
    session: AsyncSession,
    worker: JobRunningWorker,
    ssh_tunnel_mock: Mock,
    shim_client_mock: Mock,
    runner_client_mock: Mock,
    mock_gateway_connection: AsyncMock,
):
    """Check gateway registration for a job running on an instance from another project.

    The instance belongs to a fleet of the "exporter" project that is exported
    to the "importer" project, where the run and the gateway live. The test
    expects ``register_replica`` to be called once with the exporter project's
    SSH private key as ``instance_project_ssh_private_key``.
    """
    owner = await create_user(session=session)

    # Exporter side: project with a known SSH key, a fleet and a busy instance.
    exporter = await create_project(
        session=session,
        name="exporter",
        owner=owner,
        ssh_private_key="exporter-private-key",
    )
    importer = await create_project(session=session, name="importer", owner=owner)
    shared_fleet = await create_fleet(session=session, project=exporter)
    imported_instance = await create_instance(
        session=session,
        project=exporter,
        status=InstanceStatus.BUSY,
        fleet=shared_fleet,
    )
    await create_export(
        session=session,
        exporter_project=exporter,
        importer_projects=[importer],
        exported_fleets=[shared_fleet],
    )

    # Importer side: repo, backend and a running gateway.
    repository = await create_repo(session=session, project_id=importer.id)
    gw_backend = await create_backend(session=session, project_id=importer.id)
    gw_compute = await create_gateway_compute(session=session, backend_id=gw_backend.id)
    gw = await create_gateway(
        session=session,
        project_id=importer.id,
        backend_id=gw_backend.id,
        gateway_compute_id=gw_compute.id,
        status=GatewayStatus.RUNNING,
        name="test-gateway",
        wildcard_domain="example.com",
    )

    # Service run in the importer project, placed on the exporter's instance.
    service_conf = ServiceConfiguration(port=80, image="ubuntu", gateway="test-gateway")
    spec = get_run_spec(
        run_name="test",
        repo_id=repository.name,
        configuration=service_conf,
    )
    service_run = await create_run(
        session=session,
        project=importer,
        repo=repository,
        user=owner,
        run_spec=spec,
        gateway=gw,
    )
    provisioning_data = get_job_provisioning_data(dockerized=True)
    replica_job = await create_job(
        session=session,
        run=service_run,
        status=JobStatus.PULLING,
        job_provisioning_data=provisioning_data,
        instance=imported_instance,
        instance_assigned=True,
    )
    # The shim reports the task as running, which drives PULLING -> RUNNING.
    shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING

    await _process_job(session, worker, replica_job)

    await session.refresh(replica_job)
    assert replica_job.status == JobStatus.RUNNING
    assert replica_job.registered

    recorded_events = await list_events(session)
    expected_messages = {
        "Job status changed PULLING -> RUNNING",
        "Service replica registered to receive requests",
    }
    assert {item.message for item in recorded_events} == expected_messages

    # The gateway client is obtained via `async with connection.client() as client`.
    gateway_client = (
        mock_gateway_connection.return_value.client.return_value.__aenter__.return_value
    )
    gateway_client.register_replica.assert_called_once_with(
        run=ANY,
        job_spec=ANY,
        job_submission=ANY,
        instance_project_ssh_private_key="exporter-private-key",
        ssh_head_proxy=None,
        ssh_head_proxy_private_key=None,
    )
| 1806 | + |
1638 | 1807 | async def test_apply_skips_probe_insert_when_lock_token_changes_after_processing( |
1639 | 1808 | self, |
1640 | 1809 | test_db, |
|
0 commit comments