|
12 | 12 | from dstack._internal.core.models.common import NetworkMode |
13 | 13 | from dstack._internal.core.models.configurations import DevEnvironmentConfiguration |
14 | 14 | from dstack._internal.core.models.instances import InstanceStatus |
15 | | -from dstack._internal.core.models.profiles import UtilizationPolicy |
| 15 | +from dstack._internal.core.models.profiles import StartupOrder, UtilizationPolicy |
16 | 16 | from dstack._internal.core.models.runs import ( |
17 | 17 | JobRuntimeData, |
18 | 18 | JobStatus, |
@@ -805,3 +805,76 @@ async def test_gpu_utilization( |
805 | 805 | else: |
806 | 806 | assert job.termination_reason is None |
807 | 807 | assert job.termination_reason_message is None |
| 808 | + |
| 809 | + @pytest.mark.asyncio |
| 810 | + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) |
| 811 | + async def test_master_job_waits_for_workers(self, test_db, session: AsyncSession): |
| 812 | + project = await create_project(session=session) |
| 813 | + user = await create_user(session=session) |
| 814 | + repo = await create_repo( |
| 815 | + session=session, |
| 816 | + project_id=project.id, |
| 817 | + ) |
| 818 | + run_spec = get_run_spec( |
| 819 | + run_name="test-run", |
| 820 | + repo_id=repo.name, |
| 821 | + ) |
| 822 | + run_spec.configuration.startup_order = StartupOrder.WORKERS_FIRST |
| 823 | + run = await create_run( |
| 824 | + session=session, |
| 825 | + project=project, |
| 826 | + repo=repo, |
| 827 | + user=user, |
| 828 | + run_spec=run_spec, |
| 829 | + ) |
| 830 | + instance1 = await create_instance( |
| 831 | + session=session, |
| 832 | + project=project, |
| 833 | + status=InstanceStatus.BUSY, |
| 834 | + ) |
| 835 | + instance2 = await create_instance( |
| 836 | + session=session, |
| 837 | + project=project, |
| 838 | + status=InstanceStatus.BUSY, |
| 839 | + ) |
| 840 | + job_provisioning_data = get_job_provisioning_data(dockerized=False) |
| 841 | + master_job = await create_job( |
| 842 | + session=session, |
| 843 | + run=run, |
| 844 | + status=JobStatus.PROVISIONING, |
| 845 | + job_provisioning_data=job_provisioning_data, |
| 846 | + instance_assigned=True, |
| 847 | + instance=instance1, |
| 848 | + job_num=0, |
| 849 | + last_processed_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), |
| 850 | + ) |
| 851 | + worker_job = await create_job( |
| 852 | + session=session, |
| 853 | + run=run, |
| 854 | + status=JobStatus.PROVISIONING, |
| 855 | + job_provisioning_data=job_provisioning_data, |
| 856 | + instance_assigned=True, |
| 857 | + instance=instance2, |
| 858 | + job_num=1, |
| 859 | + last_processed_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc), |
| 860 | + ) |
| 861 | + await process_running_jobs() |
| 862 | + await session.refresh(master_job) |
| 863 | + assert master_job.status == JobStatus.PROVISIONING |
| 864 | + worker_job.status = JobStatus.RUNNING |
| 865 | + # To guarantee master_job is processed next |
| 866 | + master_job.last_processed_at = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc) |
| 867 | + await session.commit() |
| 868 | + with ( |
| 869 | + patch("dstack._internal.server.services.runner.ssh.SSHTunnel"), |
| 870 | + patch( |
| 871 | + "dstack._internal.server.services.runner.client.RunnerClient" |
| 872 | + ) as RunnerClientMock, |
| 873 | + ): |
| 874 | + runner_client_mock = RunnerClientMock.return_value |
| 875 | + runner_client_mock.healthcheck.return_value = HealthcheckResponse( |
| 876 | + service="dstack-runner", version="0.0.1.dev2" |
| 877 | + ) |
| 878 | + await process_running_jobs() |
| 879 | + await session.refresh(master_job) |
| 880 | + assert master_job.status == JobStatus.RUNNING |
0 commit comments